Commit ad9e70f0 authored by Matej Choma's avatar Matej Choma

wip gans

parent 9b826c0f
......@@ -15,13 +15,16 @@ class BaseModule(pl.LightningModule):
self.range_out = range_out
self.range_base = range_base
def predict(self, x):
def predict(self, x, ret_np=True):
_x = torch.tensor(x, device=torch.device('cuda' if self.on_gpu else 'cpu'))[None, ...].float()
_x = torch_scale(_x, self.range_base, self.range_in)
pred = self(_x)
pred = torch_scale(pred, self.range_out, self.range_base).cpu().detach().numpy()[0]
return pred.astype('uint8')
if ret_np is True:
pred = torch_scale(pred, self.range_out, self.range_base).cpu().detach().numpy()[0]
pred = pred.astype('uint8')
return pred
# --
......@@ -80,7 +83,7 @@ class BaseModule(pl.LightningModule):
hr_hat = self(lr)
res = {'loss': self.loss(hr_hat, hr)}
res = {}
hr_hat, hr = torch_scale(hr_hat, self.range_out, self.range_base), torch_scale(hr, self.range_out, self.range_base)
for key, loss in self.val_loss.items():
res[key] = loss(hr_hat, hr)
......@@ -101,9 +104,11 @@ class BaseModule(pl.LightningModule):
# -----------------------------------------------------------------------------
def base_conv2d(in_channels, out_channels, kernel_size, stride=1, bias=True, act=None):
def base_conv2d(in_channels, out_channels, kernel_size, stride=1, bias=True, act=None, bn=False):
module_list = [nn.Conv2d(in_channels, out_channels, kernel_size,
stride=stride, padding=(kernel_size//2), bias=bias)]
if bn is True:
if act is not None:
import torch
from torch import nn
from core.base import base_conv2d, base_upsample, ResidualCatConv
import math
class ChannelAttention(nn.Module):
def __init__(self, channel, reduction=16):
# global average pooling: feature --> point
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# feature channel downscale and upscale --> channel weight
self.conv_du = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
def forward(self, x):
y = self.avg_pool(x)
y = self.conv_du(y)
return x*y
class RCAB(nn.Module):
def __init__(self, n_feat, kernel_size, ca_reduction=16, act=nn.ReLU(inplace=True)):
self.conv = nn.Sequential(
base_conv2d(n_feat, n_feat, kernel_size),
base_conv2d(n_feat, n_feat, kernel_size),
) = ChannelAttention(n_feat, ca_reduction)
def forward(self, x):
y = self.conv(x)
y =
return y + x
class ResidualGroup(nn.Module):
def __init__(self, n_feat, kernel_size, ca_reduction, act, n_rcab):
module_list = [RCAB(n_feat, kernel_size, ca_reduction, act) for _ in range(n_rcab)]
module_list.append(base_conv2d(n_feat, n_feat, kernel_size))
self.rg = nn.Sequential(*module_list)
def forward(self, x):
y = self.rg(x)
return y + x
class RCAN(BaseModule):
def __init__(self, n_rg, n_rcab, n_feat, kernel_size=3, ca_reduction=16,
act=nn.ReLU(inplace=True), lr_scale=4,
loss=nn.L1Loss(), val_loss=None, batch_size=1, **kwargs):
super().__init__(loss, val_loss, **kwargs)
self.conv_first = base_conv2d(3, n_feat, kernel_size)
# residual in residual
module_list = [ResidualGroup(n_feat, kernel_size, ca_reduction, act, n_rcab)
for _ in range(n_rg)]
module_list.append(base_conv2d(n_feat, n_feat, kernel_size))
self.rir = nn.Sequential(*module_list)
# upsample module
if (lr_scale & (lr_scale - 1)) != 0: raise NotImplementedError # lr_scale != 2^n
upsample_list = []
for _ in range(int(math.log(lr_scale, 2))):
upsample_list.append(base_conv2d(n_feat, n_feat*4, 3))
upsample_list.append(base_conv2d(n_feat, 3, kernel_size))
self.upsample = nn.Sequential(*upsample_list)
def forward(self, x):
f0 = self.conv_first(x)
fdf = self.rir(f0) + f0
y = self.upsample(fdf)
return y
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10, gamma=0.5, verbose=True) # !!! step vs epoch?
return [optimizer], [scheduler]
class Cascades(nn.Module):
def __init__(self, n_feat, mods):
catconvs = []
self.mods = nn.ModuleList(mods)
for i in range(2, len(mods)+2):
catconvs.append(ResidualCatConv(i*n_feat, n_feat))
self.catconvs = nn.ModuleList(catconvs)
def forward(self, x):
inp = [x]
out = [x]
for mod, cat in zip(self.mods, self.catconvs):
return out[-1]
class UnetCascade(BaseModule):
def __init__(self,
unet_kernel_size=5, n_casc=4,
rcab_feat=128, rcab_kernel_size=3, rcab_reduction=16, act=nn.PReLU(),
lr_scale=4, loss=nn.L1Loss(), val_loss=None, batch_size=1, **kwargs):
super().__init__(loss, val_loss, **kwargs)
self.first_upsample = base_upsample(lr_scale=lr_scale, upsample_mode='bicubic')
self.conv1 = base_conv2d(3, rcab_feat//4, unet_kernel_size, stride=1, act=act)
self.conv2 = base_conv2d(rcab_feat//4, rcab_feat//2, unet_kernel_size, stride=2, act=act)
self.conv3 = base_conv2d(rcab_feat//2, rcab_feat, unet_kernel_size, stride=2, act=act)
def _get_rcab():
return RCAB(rcab_feat, rcab_kernel_size, rcab_reduction, act)
def _get_rcab_casc():
rcabs = [_get_rcab() for i in range(n_casc)]
return Cascades(rcab_feat, rcabs)
cascs = [_get_rcab_casc() for i in range(n_casc)]
self.main_net = Cascades(rcab_feat, cascs)
self.upsample2 = self._upsample_block(rcab_feat, unet_kernel_size, act)
self.catconv2 = ResidualCatConv(rcab_feat, rcab_feat//2)
self.upsample1 = self._upsample_block(rcab_feat//2, unet_kernel_size, act)
self.catconv1 = ResidualCatConv(rcab_feat//2, rcab_feat//4)
self.last_conv = base_conv2d(rcab_feat//4, 3, unet_kernel_size, stride=1)
def forward(self, x):
u = self.first_upsample(x)
t1 = self.conv1(u)
t2 = self.conv2(t1)
res = self.conv3(t2)
res = self.main_net(res)
res = self.upsample2(res)
res = self.catconv2(res, t2)
res = self.upsample1(res)
res = self.catconv1(res, t1)
return self.last_conv(res)
def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50, gamma=0.5)
return [optimizer], [scheduler]
def _upsample_block(in_channels, kernel, act):
net = nn.Sequential(
base_conv2d(in_channels//4, in_channels//2, kernel, stride=1, act=act),
return net
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment