Commit 19187d3d authored by Matej Choma

wip finishing preparation

parent ad9e70f0
@@ -40,7 +40,7 @@ class SRDataset(data.Dataset):
    def _random_crop(self, lr, hr):
        seed(int(time()))
        h = hr.shape[1] - self.crop[0]
        w = hr.shape[2] - self.crop[1]
...
import torch
from torch import nn
-from core.base import base_conv2d, base_upsample, ResidualCatConv
+from core.base import BaseModule, base_conv2d, base_upsample, ResidualCatConv
import math
class ChannelAttention(nn.Module):
...
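
The hunk above shows only the opening lines of `SRDataset._random_crop`. For orientation, here is a minimal sketch of how such an aligned LR/HR random crop is commonly completed; the offset handling, the `lr_scale` attribute and the return order are assumptions, not the repository's actual implementation.

``` python
from random import randint, seed
from time import time

def _random_crop(self, lr, hr):
    # hr is C x H x W; self.crop gives the crop size in HR pixels, e.g. (128, 128).
    seed(int(time()))
    h = hr.shape[1] - self.crop[0]
    w = hr.shape[2] - self.crop[1]

    # Keep the HR offsets divisible by the scale so the LR window stays aligned
    # (self.lr_scale is an assumed attribute: HR/LR size ratio, e.g. 4).
    s = self.lr_scale
    y = randint(0, h // s) * s
    x = randint(0, w // s) * s

    hr_crop = hr[:, y:y + self.crop[0], x:x + self.crop[1]]
    lr_crop = lr[:, y // s:(y + self.crop[0]) // s, x // s:(x + self.crop[1]) // s]
    return lr_crop, hr_crop
```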
%% Cell type:code id: tags:
 
``` python
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data import random_split
 
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
 
from core import datasets
from core.utils import visualization
from core.utils.helpers import torch_scale
from core.metrics import ssim, psnr
 
from core.base import BaseModule
from core.modules import UnetBT, BaseUpsampler
from core.module_unet import UnetCascade
 
import numpy as np

# NOTE: SRCNN, RCAN and SRResNet are instantiated in later cells but are not
# imported here; they are assumed to be defined elsewhere in the project.
```
 
%% Cell type:code id: tags:
 
``` python
```
 
%% Cell type:code id: tags:
 
``` python
#######################################
# SRCNN
#######################################
HPARAMS = {
    'batch_size': 16,
    'loss': nn.MSELoss(),
    'val_loss': {
        'ssim': ssim,
        'psnr': psnr
    }
}

# logger
logger = TensorBoardLogger('lightning_logs/', name='srcnn', default_hp_metric=False)

# data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160])

train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)

# model
model = SRCNN(**HPARAMS)
model.float();
```
 
%% Cell type:code id: tags:
 
``` python
 
#######################################
# RCAN
#######################################
HPARAMS = {
    'n_rg': 10,
    'n_rcab': 20,
    'n_feat': 64,
    'kernel_size': 3,
    'ca_reduction': 16,
    'act': nn.ReLU(inplace=True),
    'lr_scale': 4,

    'loss': nn.L1Loss(),
    'val_loss': {
        'ssim': ssim,
        'psnr': psnr
    },
    'batch_size': 16,

    'range_in': (0, 1),
    'range_out': (0, 1),
}

# logger
logger = TensorBoardLogger('lightning_logs/', name='rcan', default_hp_metric=False)

# data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')

train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)

# model
model = RCAN(**HPARAMS)
model.float();
```
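
%% Cell type:markdown id: tags:

The diff above touches a `ChannelAttention` module, and the RCAN configuration exposes a `ca_reduction` hyperparameter. For reference, the sketch below shows the standard RCAN-style channel attention (global average pooling followed by a squeeze-and-excitation bottleneck of 1x1 convolutions); it illustrates the technique and is not the repository's actual `ChannelAttention` implementation.

%% Cell type:code id: tags:

``` python
import torch
from torch import nn

class ChannelAttentionSketch(nn.Module):
    """RCAN-style channel attention: rescales feature maps channel-wise."""
    def __init__(self, n_feat=64, reduction=16):
        super().__init__()
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),                    # squeeze: B x C x 1 x 1
            nn.Conv2d(n_feat, n_feat // reduction, 1),  # bottleneck, controlled by ca_reduction
            nn.ReLU(inplace=True),
            nn.Conv2d(n_feat // reduction, n_feat, 1),  # expand back to n_feat channels
            nn.Sigmoid(),                               # per-channel weights in (0, 1)
        )

    def forward(self, x):
        # Reweight each channel of x by its attention score.
        return x * self.attention(x)

# quick shape check
y = ChannelAttentionSketch()(torch.randn(1, 64, 32, 32))
assert y.shape == (1, 64, 32, 32)
```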
 
%% Cell type:code id: tags:
 
``` python
#######################################
# BASE UPSAMPLER
#######################################
HPARAMS = {
    'lr_scale': 4,
    'upsample_mode': 'bicubic',

    'loss': nn.L1Loss(),
    'val_loss': {
        'ssim': ssim,
        'psnr': psnr
    },

    'range_in': (0, 1),
    'range_out': (0, 1),
    'range_base': (0, 255),
}

# logger
logger = TensorBoardLogger('lightning_logs/', name='base_upsampler', default_hp_metric=False)

# data
d_train = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_unknown/X4/', crop=(128, 128))
d_val = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')

train_loader = DataLoader(d_train, batch_size=1, num_workers=4)
val_loader = DataLoader(d_val, batch_size=1, num_workers=4)

# model
model = BaseUpsampler(**HPARAMS)
model.float();
```
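
%% Cell type:markdown id: tags:

The `BaseUpsampler` configuration above only specifies `lr_scale` and `upsample_mode` besides the loss and value ranges. A minimal sketch of what a learning-free bicubic baseline of this kind could look like, assuming it simply wraps `torch.nn.functional.interpolate`, is given below; the repository's `BaseUpsampler` may differ.

%% Cell type:code id: tags:

``` python
import torch
from torch import nn
import torch.nn.functional as F

class BicubicBaseline(nn.Module):
    """Learning-free SR baseline: interpolate the LR input up by a fixed factor."""
    def __init__(self, lr_scale=4, upsample_mode='bicubic'):
        super().__init__()
        self.lr_scale = lr_scale
        self.upsample_mode = upsample_mode

    def forward(self, x):
        # x is B x C x h x w; output is B x C x (h * lr_scale) x (w * lr_scale).
        return F.interpolate(x, scale_factor=self.lr_scale,
                             mode=self.upsample_mode, align_corners=False)

lr = torch.rand(1, 3, 32, 32)
print(BicubicBaseline()(lr).shape)  # torch.Size([1, 3, 128, 128])
```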
 
%% Cell type:code id: tags:
 
``` python
#######################################
# U-Net Cascade
#######################################
HPARAMS = {
    'lr_scale': 4,

    'loss': nn.L1Loss(),
    'val_loss': {
        'ssim': ssim,
        'psnr': psnr
    },
    'batch_size': 16,

    'range_in': (0, 1),
    'range_out': (0, 1),
    'range_base': (0, 256)
}

# logger
logger = TensorBoardLogger('lightning_logs/', name='unetcascade', default_hp_metric=False)

# data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')

train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)

# model
model = UnetCascade(**HPARAMS)
model.float();
```
 
%% Cell type:code id: tags:
 
``` python
#######################################
# SRResNet
#######################################
HPARAMS = {
    'lr_scale': 4,

    'loss': nn.MSELoss(),
    'val_loss': {
        'ssim': ssim,
        'psnr': psnr
    },

    'range_in': (0, 1),
    'range_out': (0, 1),
    'range_base': (0, 256)
}
BATCH_SIZE = 16

# logger
logger = TensorBoardLogger('lightning_logs/', name='srresnet', default_hp_metric=False)

# data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')

train_loader = DataLoader(d_train, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(d_val, batch_size=BATCH_SIZE, num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)

# model
model = SRResNet(**HPARAMS)
model.float();
```
 
%% Cell type:code id: tags:
 
``` python
trainer = pl.Trainer(
    gpus=1,
    auto_select_gpus=True,
    logger=logger,
    progress_bar_refresh_rate=10
)
trainer.test(model, val_full_loader, ckpt_path='lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt')
```
 
%% Output
 
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
 
Testing: 100%|██████████| 100/100 [02:25<00:00, 1.46s/it]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.4379, device='cuda:0')}
--------------------------------------------------------------------------------
 
[{'test_loss': 0.4379383325576782}]
 
%% Cell type:code id: tags:
 
``` python
# checkpoints
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename='{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)

# early stopping
# early_stop_callback = EarlyStopping(
#     monitor='train_loss',
#     min_delta=0.0000,
#     patience=3,
#     verbose=False,
#     mode='min'
# )

# training
trainer = pl.Trainer(
    gpus=1,
    auto_select_gpus=True,
    callbacks=[checkpoint_callback],  # early_stop_callback],
    # resume_from_checkpoint='lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt',
    logger=logger,
    progress_bar_refresh_rate=10
)
trainer.fit(model, train_loader, val_loader)
```
 
%% Output
 
GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: lightning_logs/srresnet
| Name | Type | Params
---------------------------------
0 | loss | MSELoss | 0
1 | net | SRGen | 1.5 M
---------------------------------
1.5 M Trainable params
0 Non-trainable params
1.5 M Total params
 
Epoch 0: 80%|████████ | 40/50 [00:12<00:03, 3.09it/s, loss=0.0319, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 0: 100%|██████████| 50/50 [00:16<00:00, 2.97it/s, loss=0.0319, v_num=0]
Epoch 1: 80%|████████ | 40/50 [00:12<00:03, 3.08it/s, loss=0.0237, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 1: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0237, v_num=0]
Epoch 2: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0259, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 2: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0259, v_num=0]
Epoch 3: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0215, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 3: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0215, v_num=0]
Epoch 4: 80%|████████ | 40/50 [00:13<00:03, 3.05it/s, loss=0.0182, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 4: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0182, v_num=0]
Epoch 5: 80%|████████ | 40/50 [00:13<00:03, 3.03it/s, loss=0.0185, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 5: 100%|██████████| 50/50 [00:17<00:00, 2.90it/s, loss=0.0185, v_num=0]
Epoch 6: 80%|████████ | 40/50 [00:13<00:03, 3.04it/s, loss=0.0174, v_num=0]
Validating: 0it [00:00, ?it/s]
Epoch 6: 100%|██████████| 50/50 [00:17<00:00, 2.92it/s, loss=0.0174, v_num=0]
Epoch 7: 60%|██████ | 30/50 [00:11<00:07, 2.54it/s, loss=0.0176, v_num=0]
 
/home/matej/miniconda/envs/fitmvi/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...
warnings.warn(*args, **kwargs)
 
1
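
%% Cell type:markdown id: tags:

After an interrupted fit like the one above, the best checkpoint written by `ModelCheckpoint` can be recovered programmatically instead of hard-coding a path. A small sketch using the standard PyTorch Lightning attributes is shown below; it assumes `SRResNet` stores its hyperparameters in the checkpoint, as the `load_from_checkpoint` calls further below suggest.

%% Cell type:code id: tags:

``` python
# Best checkpoint tracked by the ModelCheckpoint callback during the (interrupted) fit.
best_ckpt = checkpoint_callback.best_model_path
print(best_ckpt, checkpoint_callback.best_model_score)

# Reload the weights for evaluation; if the hyperparameters were not saved,
# they would have to be passed explicitly as keyword arguments.
restored = SRResNet.load_from_checkpoint(best_ckpt)
restored.eval();
```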
 
%% Cell type:code id: tags:
 
``` python
d = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_4_bicubic/', crop=(256, 256))
i, o = d[6]
```
 
%% Cell type:code id: tags:
 
``` python
# t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_12/checkpoints/epoch=23-val_loss=0.17.ckpt', val_loss=None)
# t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt')
 
# t = SRCNN.load_from_checkpoint('lightning_logs/srcnn/version_1/checkpoints/epoch=18-val_loss=0.06.ckpt')
 
# t = RCAN.load_from_checkpoint('lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt')
 
t = UnetCascade.load_from_checkpoint('lightning_logs/unetcascade/version_16/checkpoints/epoch=22-val_loss=0.07.ckpt')
p = t.predict(i)
```
 
%% Cell type:code id: tags:
 
``` python
import matplotlib.pyplot as plt
from skimage.transform import resize
 
plt.figure(figsize=(22,22)) # figsize is in inches
 
plt.subplot(221)
visualization.show(i, fig=False, title='Input')
 
plt.subplot(222)
visualization.show(o, fig=False, title='Output')
 
plt.subplot(223)
p = t.predict(i)
visualization.show(p, fig=False, title='Prediction')
 
plt.subplot(224)
b = BaseUpsampler(upsample_mode='bicubic').predict(i)
visualization.show(b, fig=False, title='Bicubic')
```
 
%% Output
 
<matplotlib.image.AxesImage at 0x7f2206dc2940>
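
%% Cell type:markdown id: tags:

The cell above only compares the crops visually. A quick quantitative check with scikit-image's reference metrics is sketched below; it assumes `o`, `p` and `b` are CxHxW tensors in the (0, 1) range (matching `range_out` above) and is independent of the `core.metrics` implementations used during training.

%% Cell type:code id: tags:

``` python
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def to_hwc(t):
    # CxHxW torch tensor -> HxWxC float numpy array
    return t.detach().cpu().numpy().transpose(1, 2, 0)

gt = to_hwc(o)
for name, img in [('prediction', p), ('bicubic', b)]:
    est = np.clip(to_hwc(img), 0, 1)
    psnr_val = peak_signal_noise_ratio(gt, est, data_range=1.0)
    # channel_axis needs scikit-image >= 0.19; older versions use multichannel=True
    ssim_val = structural_similarity(gt, est, data_range=1.0, channel_axis=-1)
    print(f'{name}: PSNR {psnr_val:.2f} dB, SSIM {ssim_val:.4f}')
```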