# datasets.py
from torch.utils import data

import numpy as np
import os
from pathlib import Path
from PIL import Image
from random import randint, seed
from time import time

class SRDataset(data.Dataset):
    """Paired super-resolution dataset of high-/low-resolution PNG images.

    Pairs are formed by sorting the ``*.png`` files of each directory by path
    and matching them position by position; it is up to the user to ensure
    the two directories contain corresponding images.

    Each sample is an ``(lr, hr)`` tuple of uint8 numpy arrays in
    channel-first ``(3, H, W)`` layout, optionally randomly cropped.

    Args:
        dir_hr: directory containing the high-resolution images.
        dir_lr: directory containing the low-resolution images.
        lr_scale: integer factor used to map HR crop coordinates and sizes
            onto the LR image (presumably LR images are HR images downscaled
            by this factor -- confirm against the data preparation).
        crop: optional ``(H, W)`` size of the random HR crop. The matching LR
            crop is ``(H // lr_scale, W // lr_scale)``, so both values should
            be divisible by ``lr_scale`` for the pair to stay the same area.

    Raises:
        ValueError: if the directories hold different numbers of PNG files,
            or ``crop`` is neither ``None`` nor a 2-element tuple/list.
    """

    def __init__(self, dir_hr, dir_lr, lr_scale=4, crop=None):
        super().__init__()

        self.dir_hr = Path(dir_hr)
        self.dir_lr = Path(dir_lr)

        self.hr = sorted(self.dir_hr.glob('*.png'))
        self.lr = sorted(self.dir_lr.glob('*.png'))
        # It is up to the user to assure that the images in the directories
        # correspond; only the counts are checked. Raise instead of assert so
        # the check survives ``python -O``.
        if len(self.hr) != len(self.lr):
            raise ValueError(f'hr directory has {len(self.hr)} files, while lr has {len(self.lr)}')

        self.lr_scale = lr_scale

        if crop is not None:
            if not (isinstance(crop, (tuple, list)) and len(crop) == 2):
                raise ValueError('Crop should be tuple (H, W).')
            crop = tuple(crop)
        self.crop = crop

    def __len__(self):
        """Return the total number of (lr, hr) samples."""
        return len(self.hr)

    @staticmethod
    def _load_img(path):
        """Load an image file as a channel-first ``(3, H, W)`` uint8 RGB array."""
        # The context manager closes the underlying file handle immediately
        # instead of leaving it open until garbage collection.
        with Image.open(path) as img:
            return np.array(img.convert('RGB')).transpose(2, 0, 1)

    def _random_crop(self, lr, hr):
        """Cut spatially aligned random crops out of an LR/HR pair.

        The HR crop has size ``self.crop`` and its top-left corner is snapped
        to a multiple of ``self.lr_scale`` so it maps exactly onto an LR
        pixel; the LR crop is the corresponding ``crop // lr_scale`` window.

        Offsets are drawn from the global ``random`` module without
        re-seeding it: seeding from the wall clock on every call made all
        crops within the same second identical and clobbered the caller's
        RNG state.

        Args:
            lr: channel-first LR array, ``(C, h, w)``.
            hr: channel-first HR array, ``(C, H, W)``.

        Returns:
            ``(lr_crop, hr_crop)`` slices (views) of the input arrays.

        Raises:
            ValueError: if the HR image is smaller than the crop size.
        """
        crop_h, crop_w = self.crop
        scale = self.lr_scale

        # Largest valid top-left offsets, inclusive: y + crop_h <= H.
        max_y = hr.shape[1] - crop_h
        max_x = hr.shape[2] - crop_w
        if max_y < 0 or max_x < 0:
            raise ValueError(
                f'Image of size {tuple(hr.shape[1:])} is smaller than crop {self.crop}.')

        # Snap to the LR grid (never past max_*) so both crops cover the same area.
        y = randint(0, max_y) // scale * scale
        x = randint(0, max_x) // scale * scale
        hr_crop = hr[..., y:y + crop_h, x:x + crop_w]

        lr_y, lr_x = y // scale, x // scale
        lr_crop = lr[..., lr_y:lr_y + crop_h // scale, lr_x:lr_x + crop_w // scale]

        return lr_crop, hr_crop

    def __getitem__(self, index):
        """Return the ``index``-th ``(lr, hr)`` pair, randomly cropped if ``crop`` is set.

        Raises:
            IndexError: if ``index`` is past the end. Using IndexError (not an
                assertion) lets plain ``for`` iteration over the dataset stop
                cleanly via the sequence protocol.
        """
        if index >= len(self):
            raise IndexError(f'Index {index} out of range {{0,...,{len(self) - 1}}}.')

        lr = self._load_img(self.lr[index])
        hr = self._load_img(self.hr[index])

        if self.crop is not None:
            lr, hr = self._random_crop(lr, hr)

        return lr, hr