Newer
Older
from turtle import forward
import torch
from torch import nn
from torch.utils.data import DataLoader
class UNet(nn.Module):
    """Residual U-Net for single-channel images.

    Block layout as built in ``__init__`` (``a->b`` = in/out channels)::

        input -------------------------------------------------(+)--> output
          block_d1 (1->32)     --cat-->  block_u1 (64->32, 1x1 conv ->1)
            v pool                          ^ upconv
          block_d2 (32->64)    --cat-->  block_u2 (128->64, upconv ->32)
            v pool                          ^ upconv
          block_d3 (64->128)   --cat-->  block_u3 (256->128, upconv ->64)
            v pool                          ^ upconv
          block_d4 (128->256)  --cat-->  block_u4 (512->256, upconv ->128)
            v pool                          ^ upconv
          block_b (256->512, upconv ->256)

    Encoder/decoder skips are channel concatenations (``torch.cat`` on
    dim 1); the outermost skip adds the decoder output to the input.

    Shape requirements: the input must be a 4-D ``(N, 1, H, W)`` batch
    (``BatchNorm2d`` rejects unbatched input), and with the default
    ``max_pool_kernel_size=2`` both ``H`` and ``W`` must be divisible by 16
    so every transposed convolution restores exactly the size of the
    matching encoder output. ``conv_kernel_size`` is assumed odd so that the
    default ``kernel_size // 2`` padding preserves spatial size.
    """

    def __init__(self,
                 conv_kernel_size: int = 3,
                 conv_padding: int = None,
                 conv_weight: tuple = (0, 1),
                 conv_bias: float = 0,
                 batch_norm_weight: float = 1,
                 batch_norm_bias: float = 0,
                 max_pool_kernel_size: int = 2
                 ) -> None:
        """Build every encoder, bottleneck and decoder block.

        Args:
            conv_kernel_size: Kernel size of the regular and transposed
                convolutions (the final projection is always 1x1).
            conv_padding: Padding for those convolutions. ``None`` means
                ``conv_kernel_size // 2``; an explicit ``0`` is honoured.
            conv_weight: ``(mean, std)`` of the normal distribution used to
                initialise conv / transposed-conv weights.
            conv_bias: Constant initial value for conv biases.
            batch_norm_weight: Constant initial BatchNorm scale.
            batch_norm_bias: Constant initial BatchNorm shift.
            max_pool_kernel_size: Kernel (and hence stride) of the
                down-sampling ``MaxPool2d`` layers.
        """
        super().__init__()
        self.conv_weight = conv_weight
        self.conv_bias = conv_bias
        self.batch_norm_weight = batch_norm_weight
        self.batch_norm_bias = batch_norm_bias
        self.conv_kernel_size = conv_kernel_size
        self.max_pool_kernel_size = max_pool_kernel_size
        # `is None` instead of a falsy test so an explicit padding of 0 is kept.
        if conv_padding is None:
            conv_padding = conv_kernel_size // 2
        self.conv_padding = conv_padding

        # Encoder level 1: three convolutions at full resolution, no pooling.
        block_d1 = self.block_2ch(in_channels=1, out_channels=32, return_list=True, max_pool=False)
        block_d1.extend(self.conv_block(in_channels=32, out_channels=32, kernel_size=conv_kernel_size, stride=1,
                                        padding=conv_padding, batch_norm=True, relu=True))
        self.block_d1 = nn.Sequential(*block_d1)
        self.block_d2 = self.block_2ch(in_channels=32, out_channels=64)
        self.block_d3 = self.block_2ch(in_channels=64, out_channels=128)
        self.block_d4 = self.block_2ch(in_channels=128, out_channels=256)

        # Bottleneck: pool, two 512-channel convs, then upsample back to 256.
        block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        block_b.extend(self.block_3ch_t(in_channels=256, out_channels=256, block_channels=512, return_list=True))
        # Must be registered as a submodule: forward() calls self.block_b.
        self.block_b = nn.Sequential(*block_b)

        # Decoder: each block's input is the concatenation of the skip tensor
        # and the upsampled tensor from the level below.
        self.block_u4 = self.block_3ch_t(in_channels=512, out_channels=128, block_channels=256)
        self.block_u3 = self.block_3ch_t(in_channels=256, out_channels=64, block_channels=128)
        self.block_u2 = self.block_3ch_t(in_channels=128, out_channels=32, block_channels=64)
        block_u1 = self.block_2ch(in_channels=64, out_channels=32, return_list=True, max_pool=False)
        # 1x1 projection to one channel. padding=0 keeps H and W unchanged so
        # the output can be added to the input; no BatchNorm/ReLU so the
        # residual correction may take either sign.
        block_u1.extend(self.conv_block(in_channels=32, out_channels=1, kernel_size=1, stride=1,
                                        padding=0))
        self.block_u1 = nn.Sequential(*block_u1)

    def forward(self, input, verbose: bool = False):
        """Run the network and return ``block_u1(...) + input``.

        Args:
            input: Batch shaped ``(N, 1, H, W)``; see the class docstring
                for the size constraints.
            verbose: If true, print the size of every intermediate tensor
                (debugging aid).

        Returns:
            Tensor with the same shape as ``input``.
        """
        d1_out = self.block_d1(input)
        d2_out = self.block_d2(d1_out)
        d3_out = self.block_d3(d2_out)
        d4_out = self.block_d4(d3_out)
        b_out = self.block_b(d4_out)
        u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
        u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
        u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
        u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
        result = u1_out + input
        if verbose:
            for label, tensor in (('input', input), ('d1_out', d1_out), ('d2_out', d2_out),
                                  ('d3_out', d3_out), ('d4_out', d4_out), ('b_out', b_out),
                                  ('u4_out', u4_out), ('u3_out', u3_out), ('u2_out', u2_out),
                                  ('u1_out', u1_out), ('result', result)):
                print(f'{label}: {tensor.size()}')
        return result

    def _norm_act(self, num_features: int, batch_norm: bool, relu: bool) -> list:
        """Return the optional BatchNorm2d / ReLU layers that follow a convolution."""
        layers = []
        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=num_features)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            layers.append(batch_norm_layer)
        if relu:
            layers.append(nn.ReLU())
        return layers

    def conv_block(self, in_channels: int,
                   out_channels: int,
                   kernel_size: int,
                   stride: int,
                   padding: int,
                   batch_norm: bool = False,
                   relu: bool = False
                   ) -> list:
        """Return ``[Conv2d, (BatchNorm2d), (ReLU)]`` as a plain list.

        Weights are drawn from ``normal_(*self.conv_weight)`` and biases set
        to ``self.conv_bias``.
        """
        conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)
        nn.init.normal_(conv_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_layer.bias, self.conv_bias)
        return [conv_layer, *self._norm_act(out_channels, batch_norm, relu)]

    def conv_transpose_block(self, in_channels: int,
                             out_channels: int,
                             kernel_size: int,
                             stride: int,
                             padding: int,
                             output_padding: int,
                             batch_norm: bool = False,
                             relu: bool = False
                             ) -> list:
        """Return ``[ConvTranspose2d, (BatchNorm2d), (ReLU)]`` as a plain list.

        Initialised the same way as :meth:`conv_block`.
        """
        conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride, padding=padding, output_padding=output_padding)
        nn.init.normal_(conv_transpose_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_transpose_layer.bias, self.conv_bias)
        return [conv_transpose_layer, *self._norm_act(out_channels, batch_norm, relu)]

    def block_3ch_t(self,
                    in_channels: int,
                    out_channels: int,
                    block_channels: int = None,
                    return_list=False
                    ) -> "list | nn.Sequential":
        """Two convs (in->block, block->block) then a stride-2 transposed conv (block->out).

        Every layer is followed by BatchNorm2d and ReLU. With an odd kernel
        and padding ``kernel_size // 2``, ``output_padding=1`` makes the
        transposed conv exactly double H and W.

        Args:
            block_channels: Width of the two inner convs; ``None`` means
                ``in_channels // 2``.
            return_list: Return the raw layer list (for further extension)
                instead of an ``nn.Sequential``.
        """
        if block_channels is None:
            block_channels = in_channels // 2
        block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=block_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(self.conv_transpose_block(in_channels=block_channels, out_channels=out_channels,
                                               kernel_size=self.conv_kernel_size, stride=2,
                                               padding=self.conv_padding, output_padding=1,
                                               batch_norm=True, relu=True))
        return block if return_list else nn.Sequential(*block)

    def block_2ch(self,
                  in_channels: int,
                  out_channels: int,
                  return_list=False,
                  max_pool=True,
                  ) -> "list | nn.Sequential":
        """Optional MaxPool2d, then two conv+BatchNorm2d+ReLU layers (in->out, out->out).

        Args:
            return_list: Return the raw layer list instead of an ``nn.Sequential``.
            max_pool: Prepend ``MaxPool2d(self.max_pool_kernel_size)``.
        """
        block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)] if max_pool else []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=out_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        return block if return_list else nn.Sequential(*block)
import torch.nn.functional as F

# Scratch check: zero-pad a random single-channel 362x362 image to 368x368.
# 368 = 16 * 23 — presumably chosen so the size survives the model's four
# 2x down-samplings; confirm.
img = torch.rand((1, 362, 362))
# F.pad lists the last dimension first: (left, right) for W, then
# (top, bottom) for H.
new_data = F.pad(
    img,
    (0, 6,   # W: nothing on the left, 6 zero columns on the right
     6, 0),  # H: 6 zero rows on top, nothing at the bottom
    mode='constant',
    value=0,
)