Commit 19187d3d authored by Matej Choma's avatar Matej Choma
Browse files

wip finishing preparation

parent ad9e70f0
...@@ -40,7 +40,7 @@ class SRDataset(data.Dataset): ...@@ -40,7 +40,7 @@ class SRDataset(data.Dataset):
def _random_crop(self, lr, hr): def _random_crop(self, lr, hr):
seed(int(time())) seed(int(time()))
h = hr.shape[1] - self.crop[0] h = hr.shape[1] - self.crop[0]
w = hr.shape[2] - self.crop[1] w = hr.shape[2] - self.crop[1]
......
import torch import torch
from torch import nn from torch import nn
from core.base import base_conv2d, base_upsample, ResidualCatConv from core.base import BaseModule, base_conv2d, base_upsample, ResidualCatConv
import math import math
class ChannelAttention(nn.Module): class ChannelAttention(nn.Module):
......
This diff is collapsed.
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
import torch import torch
from torch import nn from torch import nn
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data import random_split from torch.utils.data import random_split
   
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers import TensorBoardLogger
   
from core import datasets from core import datasets
from core.utils import visualization from core.utils import visualization
from core.utils.helpers import torch_scale from core.utils.helpers import torch_scale
from core.metrics import ssim, psnr from core.metrics import ssim, psnr
   
from core.base import BaseModule from core.base import BaseModule
from core.modules import UnetBT, BaseUpsampler from core.modules import UnetBT, BaseUpsampler
from core.module_unet import UnetCascade
   
import numpy as np import numpy as np
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
HPARAMS = { HPARAMS = {
'batch_size': 16, 'batch_size': 16,
'loss': nn.MSELoss(), 'loss': nn.MSELoss(),
'val_loss': { 'val_loss': {
'ssim': ssim, 'ssim': ssim,
'psnr': psnr 'psnr': psnr
} }
} }
   
# logger # logger
logger = TensorBoardLogger('lightning_logs/', name='srcnn', default_hp_metric=False) logger = TensorBoardLogger('lightning_logs/', name='srcnn', default_hp_metric=False)
   
# data # data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128)) dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160]) d_train, d_val = random_split(dataset, [640, 160])
   
train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4) train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4) val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)
   
# model # model
model = SRCNN(**HPARAMS) model = SRCNN(**HPARAMS)
model.float() ; model.float() ;
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
   
####################################### #######################################
# RCAN # RCAN
####################################### #######################################
HPARAMS = { HPARAMS = {
'n_rg': 10, 'n_rg': 10,
'n_rcab': 20, 'n_rcab': 20,
'n_feat': 64, 'n_feat': 64,
'kernel_size': 3, 'kernel_size': 3,
'ca_reduction': 16, 'ca_reduction': 16,
'act': nn.ReLU(inplace=True), 'act': nn.ReLU(inplace=True),
'lr_scale': 4, 'lr_scale': 4,
   
'loss': nn.L1Loss(), 'loss': nn.L1Loss(),
'val_loss': { 'val_loss': {
'ssim': ssim, 'ssim': ssim,
'psnr': psnr 'psnr': psnr
}, },
'batch_size': 16, 'batch_size': 16,
   
'range_in': (0, 1), 'range_in': (0, 1),
'range_out': (0, 1), 'range_out': (0, 1),
} }
   
# logger # logger
logger = TensorBoardLogger('lightning_logs/', name='rcan', default_hp_metric=False) logger = TensorBoardLogger('lightning_logs/', name='rcan', default_hp_metric=False)
   
# data # data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128)) dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160]) d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/') d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')
   
train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4) train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4) val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4) val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)
   
# model # model
model = RCAN(**HPARAMS) model = RCAN(**HPARAMS)
model.float() ; model.float() ;
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
####################################### #######################################
# BASE UPSAMPLER # BASE UPSAMPLER
####################################### #######################################
HPARAMS = { HPARAMS = {
'lr_scale': 4, 'lr_scale': 4,
'upsample_mode': 'bicubic', 'upsample_mode': 'bicubic',
   
'loss': nn.L1Loss(), 'loss': nn.L1Loss(),
'val_loss': { 'val_loss': {
'ssim': ssim, 'ssim': ssim,
'psnr': psnr 'psnr': psnr
}, },
   
'range_in': (0, 1), 'range_in': (0, 1),
'range_out': (0, 1), 'range_out': (0, 1),
'range_base': (0, 255), 'range_base': (0, 255),
} }
   
# logger # logger
logger = TensorBoardLogger('lightning_logs/', name='base_upsampler', default_hp_metric=False) logger = TensorBoardLogger('lightning_logs/', name='base_upsampler', default_hp_metric=False)
   
# data # data
d_train = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_unknown/X4/', crop=(128, 128)) d_train = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_unknown/X4/', crop=(128, 128))
d_val = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/') d_val = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')
   
train_loader = DataLoader(d_train, batch_size=1, num_workers=4) train_loader = DataLoader(d_train, batch_size=1, num_workers=4)
val_loader = DataLoader(d_val, batch_size=1, num_workers=4) val_loader = DataLoader(d_val, batch_size=1, num_workers=4)
   
# model # model
model = BaseUpsampler(**HPARAMS) model = BaseUpsampler(**HPARAMS)
model.float() ; model.float() ;
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
####################################### #######################################
# U-Net Cascade # U-Net Cascade
####################################### #######################################
HPARAMS = { HPARAMS = {
'lr_scale': 4, 'lr_scale': 4,
   
'loss': nn.L1Loss(), 'loss': nn.L1Loss(),
'val_loss': { 'val_loss': {
'ssim': ssim, 'ssim': ssim,
'psnr': psnr 'psnr': psnr
}, },
'batch_size': 16, 'batch_size': 16,
   
'range_in': (0, 1), 'range_in': (0, 1),
'range_out': (0, 1), 'range_out': (0, 1),
'range_base': (0, 256) 'range_base': (0, 256)
} }
   
# logger # logger
logger = TensorBoardLogger('lightning_logs/', name='unetcascade', default_hp_metric=False) logger = TensorBoardLogger('lightning_logs/', name='unetcascade', default_hp_metric=False)
   
# data # data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128)) dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160]) d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/') d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')
   
train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4) train_loader = DataLoader(d_train, batch_size=HPARAMS['batch_size'], num_workers=4)
val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4) val_loader = DataLoader(d_val, batch_size=HPARAMS['batch_size'], num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4) val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)
   
# model # model
model = UnetCascade(**HPARAMS) model = UnetCascade(**HPARAMS)
model.float() ; model.float() ;
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
####################################### #######################################
# SRResNet # SRResNet
####################################### #######################################
HPARAMS = { HPARAMS = {
'lr_scale': 4, 'lr_scale': 4,
   
'loss': nn.MSELoss(), 'loss': nn.MSELoss(),
'val_loss': { 'val_loss': {
'ssim': ssim, 'ssim': ssim,
'psnr': psnr 'psnr': psnr
}, },
   
'range_in': (0, 1), 'range_in': (0, 1),
'range_out': (0, 1), 'range_out': (0, 1),
'range_base': (0, 256) 'range_base': (0, 256)
} }
BATCH_SIZE = 16 BATCH_SIZE = 16
   
# logger # logger
logger = TensorBoardLogger('lightning_logs/', name='srresnet', default_hp_metric=False) logger = TensorBoardLogger('lightning_logs/', name='srresnet', default_hp_metric=False)
   
# data # data
dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128)) dataset = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_train_HR/', dir_lr='data/datasets/DIV2K/DIV2K_train_LR_mild/', crop=(128, 128))
d_train, d_val = random_split(dataset, [640, 160]) d_train, d_val = random_split(dataset, [640, 160])
d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/') d_val_full = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/')
   
train_loader = DataLoader(d_train, batch_size=BATCH_SIZE, num_workers=4) train_loader = DataLoader(d_train, batch_size=BATCH_SIZE, num_workers=4)
val_loader = DataLoader(d_val, batch_size=BATCH_SIZE, num_workers=4) val_loader = DataLoader(d_val, batch_size=BATCH_SIZE, num_workers=4)
val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4) val_full_loader = DataLoader(d_val_full, batch_size=1, num_workers=4)
   
# model # model
model = SRResNet(**HPARAMS) model = SRResNet(**HPARAMS)
model.float() ; model.float() ;
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=1, gpus=1,
auto_select_gpus=True, auto_select_gpus=True,
logger=logger, logger=logger,
progress_bar_refresh_rate=10 progress_bar_refresh_rate=10
) )
trainer.test(model, val_full_loader, ckpt_path='lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt') trainer.test(model, val_full_loader, ckpt_path='lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt')
``` ```
   
%% Output %% Output
   
GPU available: True, used: True GPU available: True, used: True
TPU available: None, using: 0 TPU cores TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
   
Testing: 100%|██████████| 100/100 [02:25<00:00, 1.46s/it] Testing: 100%|██████████| 100/100 [02:25<00:00, 1.46s/it]
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS DATALOADER:0 TEST RESULTS
{'test_loss': tensor(0.4379, device='cuda:0')} {'test_loss': tensor(0.4379, device='cuda:0')}
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
   
[{'test_loss': 0.4379383325576782}] [{'test_loss': 0.4379383325576782}]
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
# checkpoints # checkpoints
checkpoint_callback = ModelCheckpoint( checkpoint_callback = ModelCheckpoint(
monitor='val_loss', monitor='val_loss',
filename='{epoch:02d}-{val_loss:.2f}', filename='{epoch:02d}-{val_loss:.2f}',
save_top_k=3, save_top_k=3,
mode='min', mode='min',
) )
   
# early stopping # early stopping
# early_stop_callback = EarlyStopping( # early_stop_callback = EarlyStopping(
# monitor='train_loss', # monitor='train_loss',
# min_delta=0.0000, # min_delta=0.0000,
# patience=3, # patience=3,
# verbose=False, # verbose=False,
# mode='min' # mode='min'
# ) # )
   
# training # training
trainer = pl.Trainer( trainer = pl.Trainer(
gpus=1, gpus=1,
auto_select_gpus=True, auto_select_gpus=True,
callbacks=[checkpoint_callback], # early_stop_callback], callbacks=[checkpoint_callback], # early_stop_callback],
# resume_from_checkpoint='lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt', # resume_from_checkpoint='lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt',
logger=logger, logger=logger,
progress_bar_refresh_rate=10 progress_bar_refresh_rate=10
) )
trainer.fit(model, train_loader, val_loader) trainer.fit(model, train_loader, val_loader)
``` ```
   
%% Output %% Output
   
GPU available: True, used: True GPU available: True, used: True
TPU available: None, using: 0 TPU cores TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0] LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Missing logger folder: lightning_logs/srresnet Missing logger folder: lightning_logs/srresnet
| Name | Type | Params | Name | Type | Params
--------------------------------- ---------------------------------
0 | loss | MSELoss | 0 0 | loss | MSELoss | 0
1 | net | SRGen | 1.5 M 1 | net | SRGen | 1.5 M
--------------------------------- ---------------------------------
1.5 M Trainable params 1.5 M Trainable params
0 Non-trainable params 0 Non-trainable params
1.5 M Total params 1.5 M Total params
   
Epoch 0: 80%|████████ | 40/50 [00:12<00:03, 3.09it/s, loss=0.0319, v_num=0] Epoch 0: 80%|████████ | 40/50 [00:12<00:03, 3.09it/s, loss=0.0319, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 0: 100%|██████████| 50/50 [00:16<00:00, 2.97it/s, loss=0.0319, v_num=0] Epoch 0: 100%|██████████| 50/50 [00:16<00:00, 2.97it/s, loss=0.0319, v_num=0]
Epoch 1: 80%|████████ | 40/50 [00:12<00:03, 3.08it/s, loss=0.0237, v_num=0] Epoch 1: 80%|████████ | 40/50 [00:12<00:03, 3.08it/s, loss=0.0237, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 1: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0237, v_num=0] Epoch 1: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0237, v_num=0]
Epoch 2: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0259, v_num=0] Epoch 2: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0259, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 2: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0259, v_num=0] Epoch 2: 100%|██████████| 50/50 [00:16<00:00, 2.95it/s, loss=0.0259, v_num=0]
Epoch 3: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0215, v_num=0] Epoch 3: 80%|████████ | 40/50 [00:13<00:03, 3.06it/s, loss=0.0215, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 3: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0215, v_num=0] Epoch 3: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0215, v_num=0]
Epoch 4: 80%|████████ | 40/50 [00:13<00:03, 3.05it/s, loss=0.0182, v_num=0] Epoch 4: 80%|████████ | 40/50 [00:13<00:03, 3.05it/s, loss=0.0182, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 4: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0182, v_num=0] Epoch 4: 100%|██████████| 50/50 [00:17<00:00, 2.94it/s, loss=0.0182, v_num=0]
Epoch 5: 80%|████████ | 40/50 [00:13<00:03, 3.03it/s, loss=0.0185, v_num=0] Epoch 5: 80%|████████ | 40/50 [00:13<00:03, 3.03it/s, loss=0.0185, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 5: 100%|██████████| 50/50 [00:17<00:00, 2.90it/s, loss=0.0185, v_num=0] Epoch 5: 100%|██████████| 50/50 [00:17<00:00, 2.90it/s, loss=0.0185, v_num=0]
Epoch 6: 80%|████████ | 40/50 [00:13<00:03, 3.04it/s, loss=0.0174, v_num=0] Epoch 6: 80%|████████ | 40/50 [00:13<00:03, 3.04it/s, loss=0.0174, v_num=0]
Validating: 0it [00:00, ?it/s] Validating: 0it [00:00, ?it/s]
Epoch 6: 100%|██████████| 50/50 [00:17<00:00, 2.92it/s, loss=0.0174, v_num=0] Epoch 6: 100%|██████████| 50/50 [00:17<00:00, 2.92it/s, loss=0.0174, v_num=0]
Epoch 7: 60%|██████ | 30/50 [00:11<00:07, 2.54it/s, loss=0.0176, v_num=0] Epoch 7: 60%|██████ | 30/50 [00:11<00:07, 2.54it/s, loss=0.0176, v_num=0]
   
/home/matej/miniconda/envs/fitmvi/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown... /home/matej/miniconda/envs/fitmvi/lib/python3.8/site-packages/pytorch_lightning/utilities/distributed.py:49: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...
warnings.warn(*args, **kwargs) warnings.warn(*args, **kwargs)
   
1 1
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
d = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_4_bicubic/', crop=(256, 256)) d = datasets.SRDataset(dir_hr='data/datasets/DIV2K/DIV2K_valid_HR/', dir_lr='data/datasets/DIV2K/DIV2K_valid_LR_4_bicubic/', crop=(256, 256))
i, o = d[6] i, o = d[6]
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
# t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_12/checkpoints/epoch=23-val_loss=0.17.ckpt', val_loss=None) # t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_12/checkpoints/epoch=23-val_loss=0.17.ckpt', val_loss=None)
# t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt') # t = UnetBT.load_from_checkpoint('lightning_logs/unetbt_128/version_4/checkpoints/epoch=30-val_loss=0.15.ckpt')
   
# t = SRCNN.load_from_checkpoint('lightning_logs/srcnn/version_1/checkpoints/epoch=18-val_loss=0.06.ckpt') # t = SRCNN.load_from_checkpoint('lightning_logs/srcnn/version_1/checkpoints/epoch=18-val_loss=0.06.ckpt')
   
# t = RCAN.load_from_checkpoint('lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt') # t = RCAN.load_from_checkpoint('lightning_logs/rcan/version_3/checkpoints/epoch=03-val_loss=0.10.ckpt')
   
t = UnetCascade.load_from_checkpoint('lightning_logs/unetcascade/version_16/checkpoints/epoch=22-val_loss=0.07.ckpt') t = UnetCascade.load_from_checkpoint('lightning_logs/unetcascade/version_16/checkpoints/epoch=22-val_loss=0.07.ckpt')
p = t.predict(i) p = t.predict(i)
``` ```
   
%% Cell type:code id: tags: %% Cell type:code id: tags:
   
``` python ``` python
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from skimage.transform import resize from skimage.transform import resize
   
plt.figure(figsize=(22,22)) # figsize is in inches plt.figure(figsize=(22,22)) # figsize is in inches
   
plt.subplot(221) plt.subplot(221)
visualization.show(i, fig=False, title='Input') visualization.show(i, fig=False, title='Input')
   
plt.subplot(222) plt.subplot(222)
visualization.show(o, fig=False, title='Output') visualization.show(o, fig=False, title='Output')
   
plt.subplot(223) plt.subplot(223)
p = t.predict(i) p = t.predict(i)
visualization.show(p, fig=False, title='Prediction') visualization.show(p, fig=False, title='Prediction')
   
plt.subplot(224) plt.subplot(224)
b = BaseUpsampler(upsample_mode='bicubic').predict(i) b = BaseUpsampler(upsample_mode='bicubic').predict(i)
visualization.show(b, fig=False, title='Bicubic') visualization.show(b, fig=False, title='Bicubic')
``` ```
   
%% Output %% Output
   
<matplotlib.image.AxesImage at 0x7f2206dc2940> <matplotlib.image.AxesImage at 0x7f2206dc2940>