Commit 45053b57 authored by Matej Choma's avatar Matej Choma

training in progress

parent 19187d3d
......@@ -14,6 +14,8 @@ class BaseModule(pl.LightningModule):
self.range_in = range_in
self.range_out = range_out
self.range_base = range_base
self.test_prefix = 'test_epoch'
def predict(self, x, ret_np=True):
_x = torch.tensor(x, device=torch.device('cuda' if self.on_gpu else 'cpu'))[None, ...].float()
......@@ -34,7 +36,7 @@ class BaseModule(pl.LightningModule):
hr_hat = self(lr)
loss = self.loss(hr_hat, hr)
self.log('train_loss', loss)
self.log('train/loss', loss)
return loss
......@@ -44,7 +46,7 @@ class BaseModule(pl.LightningModule):
loss += o['loss']
loss /= len(outputs)
self.logger.log_metrics({'train_epoch_loss': loss}, step=self.current_epoch)
self.logger.log_metrics({'train/epoch_loss': loss}, step=self.current_epoch)
# --
......@@ -71,7 +73,7 @@ class BaseModule(pl.LightningModule):
loss += o[key]
loss /= len(outputs)
res[f'val_epoch_{key}'] = loss
res[f'val_epoch/{key}'] = loss
self.logger.log_metrics(res, step=self.current_epoch)
......@@ -97,8 +99,8 @@ class BaseModule(pl.LightningModule):
for o in outputs:
loss += o[key]
loss /= len(outputs)
res[f'test_epoch_{key}'] = loss
res[f'{self.test_prefix}/{key}'] = loss
self.logger.log_metrics(res, step=self.current_epoch)
......
import torch
from torch import nn
from core.base import base_conv2d, BaseModule
from core.utils.helpers import torch_scale
import math
class ResBlock(nn.Module):
    """Residual block: two convolutions with an identity skip connection."""

    def __init__(self, n_feat=64, kernel_size=3, act=nn.PReLU(), bn=True):
        super().__init__()
        # activation only after the first conv (SRResNet-style residual block)
        layers = [
            base_conv2d(n_feat, n_feat, kernel_size, bn=bn, act=act),
            base_conv2d(n_feat, n_feat, kernel_size, bn=bn, act=None),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x) + x
class SRGen(nn.Module):
    """SRGAN generator: head conv, residual trunk with a long skip,
    PixelShuffle upsampling stages, and a tail conv back to 3 channels."""

    def __init__(self, n_feat=64, kernel_size=3, n_blocks=16, lr_scale=4, act=nn.PReLU(), bn=True):
        super().__init__()
        # wide receptive field on the input/output convolutions (3x trunk kernel)
        self.head = base_conv2d(3, n_feat, kernel_size * 3, act=act)

        blocks = [ResBlock(n_feat=n_feat, kernel_size=kernel_size, act=act, bn=bn)
                  for _ in range(n_blocks)]
        self.res_blocks = nn.ModuleList(blocks)
        self.res_conv = base_conv2d(n_feat, n_feat, kernel_size, bn=bn)

        # each stage doubles spatial resolution: conv to 4x channels, then PixelShuffle(2)
        n_stages = int(math.log(lr_scale, 2))
        stages = [nn.Sequential(
            base_conv2d(n_feat, n_feat * 4, kernel_size),
            nn.PixelShuffle(2),
            act,
        ) for _ in range(n_stages)]
        self.upsample = nn.Sequential(*stages)

        self.tail = base_conv2d(n_feat, 3, kernel_size * 3)

    def forward(self, x):
        skip = out = self.head(x)
        for block in self.res_blocks:
            out = block(out)
        out = self.res_conv(out)
        # long skip connection around the whole residual trunk
        out = out + skip
        return self.tail(self.upsample(out))
class SRDisc(nn.Module):
    """SRGAN discriminator: conv stack with stride-2 downsampling stages,
    followed by a fully-connected head emitting a real/fake probability."""

    def __init__(self, W=128, n_feat=64, kernel_size=3, act=nn.LeakyReLU(0.2), bn=True):
        super().__init__()
        layers = [
            base_conv2d(3, n_feat, kernel_size, act=act),
            base_conv2d(n_feat, n_feat, kernel_size, act=act, bn=bn),
        ]
        # three stages: double the channels, then halve the spatial size
        for i in range(3):
            c_in, c_out = n_feat * 2 ** i, n_feat * 2 ** (i + 1)
            layers.append(base_conv2d(c_in, c_out, kernel_size, act=act, bn=bn))
            layers.append(base_conv2d(c_out, c_out, kernel_size, stride=2, act=act, bn=bn))
        # squeeze channels back down before the linear head ("sad GPU conv")
        layers.append(base_conv2d(n_feat * 8, n_feat, kernel_size, act=act, bn=bn))
        layers += [
            nn.Flatten(),
            nn.Linear((W // 8) ** 2 * n_feat, 1024),
            act,
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class SRResNet(BaseModule):
    """Non-adversarial SRResNet: the SRGAN generator trained alone
    with the pixel loss configured in BaseModule."""

    def __init__(self, n_feat=64, kernel_size=3, n_blocks=16, lr_scale=4, act=nn.PReLU(), bn=True, **kwargs):
        super().__init__(**kwargs)
        self.gen = SRGen(n_feat, kernel_size, n_blocks, lr_scale, act, bn)
        self.save_hyperparameters()

    def forward(self, x):
        return self.gen(x)

    def configure_optimizers(self):
        # single Adam optimizer over all parameters, fixed learning rate
        return torch.optim.Adam(self.parameters(), lr=1e-4)
class SRGAN(BaseModule):
    """SRGAN super-resolution model: SRGen generator + SRDisc discriminator,
    trained adversarially with two alternating optimizers (Lightning's
    ``optimizer_idx`` dispatch in ``training_step``)."""

    # Clamp floor for log() arguments: the discriminator's sigmoid output can
    # saturate to exactly 0/1 and log(0) = -inf would poison training with NaNs.
    _LOG_EPS = 1e-8

    def __init__(self, W=128, gen_init=None, n_feat=64, kernel_size=3, n_blocks=16, lr_scale=4, act=nn.PReLU(), bn=True,
                 loss=None, **kwargs):
        super().__init__(**kwargs)
        self.gen = SRGen(n_feat, kernel_size, n_blocks, lr_scale, act, bn)
        self.disc = SRDisc(W=W, n_feat=n_feat, kernel_size=kernel_size, bn=bn)

        # Optionally warm-start the generator from a pretrained SRResNet checkpoint.
        if gen_init is not None:
            tmp_srresnet = SRResNet.load_from_checkpoint(gen_init)
            self.gen.load_state_dict(tmp_srresnet.gen.state_dict())

        self.mse = nn.MSELoss()
        self.save_hyperparameters()

    def forward(self, x):
        return self.gen(x)

    def gen_loss(self, lr, hr):
        """Generator loss = MSE content loss + 1e-3-weighted adversarial loss."""
        hr_hat = self(lr)
        # content loss
        loss = self.mse(hr_hat, hr)
        # adversarial loss (non-saturating form); clamp avoids log(0) -> -inf
        probs = self.disc(hr_hat)
        loss += 1e-3 * torch.sum(-torch.log(probs.clamp(min=self._LOG_EPS)))
        return loss

    def disc_loss(self, lr, hr):
        """Discriminator loss on real HR images and generated SR images."""
        # train on real
        y_real = self.disc(hr)
        real_loss = -torch.sum(torch.log(y_real.clamp(min=self._LOG_EPS)))
        # train on generated; detach so the discriminator step does not
        # backpropagate through (and waste compute on) the generator
        hr_hat = self(lr).detach()
        y_gen = self.disc(hr_hat)
        gen_loss = -torch.sum(torch.log((1 - y_gen).clamp(min=self._LOG_EPS)))
        return real_loss + gen_loss

    def training_step(self, train_batch, batch_idx, optimizer_idx):
        lr, hr = train_batch
        # scale inputs/targets from the dataset range to the network ranges
        lr, hr = torch_scale(lr.float(), self.range_base, self.range_in), torch_scale(hr.float(), self.range_base, self.range_out)

        result = None
        # optimizer 0 updates the generator, optimizer 1 the discriminator
        if optimizer_idx == 0:
            result = self.gen_loss(lr, hr)
            self.log('train/g_loss', result)
        if optimizer_idx == 1:
            result = self.disc_loss(lr, hr)
            self.log('train/d_loss', result)
        return result

    def training_epoch_end(self, outputs):
        # intentionally overrides BaseModule's epoch aggregation: with two
        # optimizers the outputs structure differs, so nothing is aggregated here
        return

    def validation_step(self, val_batch, batch_idx):
        lr, hr = val_batch
        lr, hr = torch_scale(lr.float(), self.range_base, self.range_in), torch_scale(hr.float(), self.range_base, self.range_out)

        hr_hat = self(lr)
        # validation tracks plain MSE, not the adversarial loss
        res = {'loss': self.mse(hr_hat, hr)}  # {'loss': self.gen_loss(lr, hr)}

        # metrics in self.val_loss are computed back in the dataset's base range
        hr_hat, hr = torch_scale(hr_hat, self.range_out, self.range_base), torch_scale(hr, self.range_out, self.range_base)
        for key, loss in self.val_loss.items():
            res[key] = loss(hr_hat, hr)

        self.log('val_loss', res['loss'])
        return res

    def configure_optimizers(self):
        # two optimizers -> training_step receives optimizer_idx 0 (gen) / 1 (disc)
        opt_g = torch.optim.Adam(self.gen.parameters(), lr=1e-4)
        opt_d = torch.optim.Adam(self.disc.parameters(), lr=1e-4)
        return [opt_g, opt_d], []
\ No newline at end of file
......@@ -22,13 +22,13 @@ class ChannelAttention(nn.Module):
return x*y
class RCAB(nn.Module):
def __init__(self, n_feat, kernel_size, ca_reduction=16, act=nn.ReLU(inplace=True)):
def __init__(self, n_feat, kernel_size, ca_reduction=16, act=nn.ReLU(inplace=True), bn=False):
super().__init__()
self.conv = nn.Sequential(
base_conv2d(n_feat, n_feat, kernel_size),
base_conv2d(n_feat, n_feat, kernel_size, bn=bn),
act,
base_conv2d(n_feat, n_feat, kernel_size),
base_conv2d(n_feat, n_feat, kernel_size, bn=bn),
)
self.ca = ChannelAttention(n_feat, ca_reduction)
......@@ -112,28 +112,28 @@ class Cascades(nn.Module):
class UnetCascade(BaseModule):
def __init__(self,
unet_kernel_size=5, n_casc=4,
rcab_feat=128, rcab_kernel_size=3, rcab_reduction=16, act=nn.PReLU(),
lr_scale=4, loss=nn.L1Loss(), val_loss=None, batch_size=1, **kwargs):
rcab_feat=256, rcab_kernel_size=3, rcab_reduction=16, act=nn.PReLU(), bn=False,
lr_scale=4, loss=nn.L1Loss(), val_loss=None, **kwargs):
super().__init__(loss, val_loss, **kwargs)
self.first_upsample = base_upsample(lr_scale=lr_scale, upsample_mode='bicubic')
self.conv1 = base_conv2d(3, rcab_feat//4, unet_kernel_size, stride=1, act=act)
self.conv2 = base_conv2d(rcab_feat//4, rcab_feat//2, unet_kernel_size, stride=2, act=act)
self.conv3 = base_conv2d(rcab_feat//2, rcab_feat, unet_kernel_size, stride=2, act=act)
self.conv1 = base_conv2d(3, rcab_feat//4, unet_kernel_size, stride=1, act=act, bn=bn)
self.conv2 = base_conv2d(rcab_feat//4, rcab_feat//2, unet_kernel_size, stride=2, act=act, bn=bn)
self.conv3 = base_conv2d(rcab_feat//2, rcab_feat, unet_kernel_size, stride=2, act=act, bn=bn)
def _get_rcab():
return RCAB(rcab_feat, rcab_kernel_size, rcab_reduction, act)
return RCAB(rcab_feat, rcab_kernel_size, rcab_reduction, act, bn=bn)
def _get_rcab_casc():
rcabs = [_get_rcab() for i in range(n_casc)]
return Cascades(rcab_feat, rcabs)
cascs = [_get_rcab_casc() for i in range(n_casc)]
self.main_net = Cascades(rcab_feat, cascs)
self.upsample2 = self._upsample_block(rcab_feat, unet_kernel_size, act)
self.upsample2 = self._upsample_block(rcab_feat, unet_kernel_size, act, bn=bn)
self.catconv2 = ResidualCatConv(rcab_feat, rcab_feat//2)
self.upsample1 = self._upsample_block(rcab_feat//2, unet_kernel_size, act)
self.upsample1 = self._upsample_block(rcab_feat//2, unet_kernel_size, act, bn=bn)
self.catconv1 = ResidualCatConv(rcab_feat//2, rcab_feat//4)
self.last_conv = base_conv2d(rcab_feat//4, 3, unet_kernel_size, stride=1)
......@@ -163,9 +163,9 @@ class UnetCascade(BaseModule):
return [optimizer], [scheduler]
@staticmethod
def _upsample_block(in_channels, kernel, act):
def _upsample_block(in_channels, kernel, act, bn=False):
net = nn.Sequential(
nn.PixelShuffle(2),
base_conv2d(in_channels//4, in_channels//2, kernel, stride=1, act=act),
base_conv2d(in_channels//4, in_channels//2, kernel, stride=1, act=act, bn=bn),
)
return net
\ No newline at end of file
import json
import numpy as np
import torch
......@@ -16,3 +17,7 @@ def torch_scale(x, from_, to):
res = torch.clip(x, *from_)
return (res - from_[0])/(from_[1] - from_[0])*(to[1] - to[0]) + to[0]
def dump_json(d, file):
    """Serialize *d* to *file* as pretty-printed (indent=4) JSON.

    Opens the file with an explicit UTF-8 encoding so the output does not
    depend on the platform's default text encoding.
    """
    with open(file, 'w', encoding='utf-8') as f:
        json.dump(d, f, indent=4)
\ No newline at end of file
......@@ -117,7 +117,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
......@@ -193,19 +193,19 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 800/800 [00:11<00:00, 72.05it/s]\n"
"100%|██████████| 800/800 [00:10<00:00, 76.99it/s]\n"
]
}
],
"source": [
"crop_folder('data/datasets/DIV2K/DIV2K_train_LR_unknown/X4/', 'data/datasets/DIV2K/train_LR_120_unknown/', 120, 60)"
"crop_folder('data/datasets/DIV2K/DIV2K_train_LR_bicubic/X4/', 'data/datasets/DIV2K/train_LR_120_bicubic/', 120, 60)"
]
},
{
......@@ -227,19 +227,19 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 100/100 [00:02<00:00, 48.23it/s]\n"
"100%|██████████| 100/100 [00:01<00:00, 54.41it/s]\n"
]
}
],
"source": [
"crop_center_folder('data/datasets/DIV2K/DIV2K_valid_LR_unknown/X4/', 'data/datasets/DIV2K/valid_LR_250_unknown/', 250)"
"crop_center_folder('data/datasets/DIV2K/DIV2K_valid_LR_bicubic/X4/', 'data/datasets/DIV2K/valid_LR_250_bicubic/', 250)"
]
},
{
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment