Commit fe7e9cb3 authored by Matej Choma's avatar Matej Choma

preparation for the first training

parent 9a08a317
from torch.utils import data
import numpy as np
import os
from pathlib import Path
from PIL import Image
from random import randint, seed
from time import time
class SRDataset(data.Dataset):
"""SRDataset"""
def __init__(self, dir_hr, dir_lr, lr_scale=4, crop=None):
"""init"""
super().__init__()
self.dir_hr = Path(dir_hr)
self.dir_lr = Path(dir_lr)
self.hr = sorted(self.dir_hr.glob('*.png'))
self.lr = sorted(self.dir_lr.glob('*.png'))
# it is up to user to assure that the images in directories are corresponding
assert len(self.hr) == len(self.lr), f'hr directory has {len(self.hr)} files, while lr has {len(self.lr)}'
self.lr_scale = lr_scale
assert crop == None or (isinstance(crop, tuple) and len(crop) == 2),\
'Crop should be tuple (H, W).'
self.crop = crop
def __len__(self):
""" Denotes the total number of samples.
"""
return len(self.hr)
@staticmethod
def _load_img(path):
img = np.array(Image.open(path).convert('RGB'))
img = img.transpose(2, 0, 1)
return img
def _random_crop(self, lr, hr):
seed(int(time()))
h = hr.shape[1] - self.crop[0]
w = hr.shape[2] - self.crop[1]
y, x = randint(0, h - 1)//4*4, randint(0, w - 1)//4*4
_hr = hr[..., y:y+self.crop[0], x:x+self.crop[1]]
y, x, c0, c1 = y//self.lr_scale, x//self.lr_scale, self.crop[0]//self.lr_scale, self.crop[1]//self.lr_scale
_lr = lr[..., y:y+c0, x:x+c1]
return _lr, _hr
def __getitem__(self, index):
""" Generates one sample of data.
"""
assert index < self.__len__(), f'Index {index} out of range {{0,...,{self.__len__() - 1}}}.'
LR = self._load_img(self.lr[index])
HR = self._load_img(self.hr[index])
if self.crop is not None:
LR, HR = self._random_crop(LR, HR)
return LR, HR
import numpy as np
import torch
def torch_scale(x, from_, to):
assert isinstance(from_, tuple) and len(from_) == 2
assert isinstance(to, tuple) and len(to) == 2
res = torch.clip(x, *from_)
return (res - from_[0])/(from_[1] - from_[0])*(to[1] - to[0]) + to[0]
......@@ -39,6 +39,7 @@ def show(x, size=(15, 10), title=None, vmax=None, cmap=plt.cm.nipy_spectral, **k
if (x.ndim == 2):
return plt.imshow(x, cmap=cmap, vmax=vmax, **kwargs)
else:
return plt.imshow(x, **kwargs)
_x = x.transpose(1, 2, 0)
return plt.imshow(_x, **kwargs)
plt.tight_layout()
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment