Commit 19187d3d authored by Matej Choma's avatar Matej Choma
Browse files

wip finishing preparation

parent ad9e70f0
......@@ -40,7 +40,7 @@ class SRDataset(data.Dataset):
def _random_crop(self, lr, hr):
seed(int(time()))
h = hr.shape[1] - self.crop[0]
w = hr.shape[2] - self.crop[1]
......
import torch
from torch import nn
from core.base import base_conv2d, base_upsample, ResidualCatConv
from core.base import BaseModule, base_conv2d, base_upsample, ResidualCatConv
import math
class ChannelAttention(nn.Module):
......
%% Cell type:code id: tags:
 
``` python
from core import datasets
from core.utils import visualization
 
from torch.utils.data import DataLoader
```
 
%% Cell type:code id: tags:
 
``` python
d = datasets.SRDataset(dir_hr='../datasets/DIV2K/DIV2K_train_HR/', dir_lr='../datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
dl = DataLoader(d, batch_size=8, num_workers=4)
d = datasets.SRDataset(dir_hr='data/datasets/DIV2K/train_HR_480/', dir_lr='data/datasets/DIV2K/train_LR_120_unknown/', crop=(128, 128))
# d = datasets.SRDataset(dir_hr='data/datasets/DIV2K/valid_HR_1000/', dir_lr='data/datasets/DIV2K/valid_LR_250_unknown/')
```
 
%% Cell type:code id: tags:
 
``` python
for i,o in dl:
print(i.shape)
break
i, o = d[19]
visualization.show(i)
visualization.show(o)
```
 
%% Output
 
torch.Size([8, 3, 32, 32])
<matplotlib.image.AxesImage at 0x7f8679597a90>
 
%% Cell type:code id: tags:
 
``` python
i, o = df[0]
visualization.show(i)
visualization.show(o)
```
 
%% Output
 
<matplotlib.image.AxesImage at 0x7f9670138210>
<matplotlib.image.AxesImage at 0x7f1c02bdcc90>