import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary


class UNet(nn.Module):
    """
    U-Net style encoder/decoder with a residual skip from input to output.

    Channel counts are (in -> out) per block. The skips feeding the u-blocks
    are channel-wise concatenations; the outermost input/output skip is an
    element-wise addition.

    input ----------------(+)----------------> output
    block_d1 (1 -> 32)    --cat--> block_u1 (64 -> 1)
      v pool                         ^ upconv
    block_d2 (32 -> 64)   --cat--> block_u2 (128 -> 32)
      v pool                         ^ upconv
    block_d3 (64 -> 128)  --cat--> block_u3 (256 -> 64)
      v pool                         ^ upconv
    block_d4 (128 -> 256) --cat--> block_u4 (512 -> 128)
      v pool                         ^ upconv
              block_b (256 -> 512 -> 256)
    """

    def __init__(self,
                 conv_kernel_size: int = 3,
                 conv_padding: int = None,
                 conv_weight: tuple = (0, 1),  # (mean, std) for normal init of conv weights
                 conv_bias: float = 0,
                 batch_norm_weight: float = 1,
                 batch_norm_bias: float = 0,
                 max_pool_kernel_size: int = 2
                 ) -> None:
        super().__init__()
        self.conv_weight = conv_weight
        self.conv_bias = conv_bias
        self.batch_norm_weight = batch_norm_weight
        self.batch_norm_bias = batch_norm_bias
        self.conv_kernel_size = conv_kernel_size
        self.max_pool_kernel_size = max_pool_kernel_size
        if conv_padding is None:
            # 'same' padding for odd kernel sizes
            conv_padding = conv_kernel_size // 2
        self.conv_padding = conv_padding

        # Encoder (contracting) path.
        block_d1 = self.block_2ch(in_channels=1, out_channels=32,
                                  return_list=True, max_pool=False)
        block_d1.extend(self.conv_block(in_channels=32, out_channels=32,
                                        kernel_size=conv_kernel_size, stride=1,
                                        padding=conv_padding,
                                        batch_norm=True, relu=True))
        self.block_d1 = nn.Sequential(*block_d1)
        self.block_d2 = self.block_2ch(in_channels=32, out_channels=64)
        self.block_d3 = self.block_2ch(in_channels=64, out_channels=128)
        self.block_d4 = self.block_2ch(in_channels=128, out_channels=256)

        # Bottleneck: pool, two convs at 512 channels, then upsample back to 256.
        block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        block_b.extend(self.block_3ch_t(in_channels=256, out_channels=256,
                                        block_channels=512, return_list=True))
        self.block_b = nn.Sequential(*block_b)

        # Decoder (expanding) path; each u-block consumes skip + upsampled features.
        self.block_u4 = self.block_3ch_t(in_channels=512, out_channels=128, block_channels=256)
        self.block_u3 = self.block_3ch_t(in_channels=256, out_channels=64, block_channels=128)
        self.block_u2 = self.block_3ch_t(in_channels=128, out_channels=32, block_channels=64)
        block_u1 = self.block_2ch(in_channels=64, out_channels=32,
                                  return_list=True, max_pool=False)
        block_u1.extend(self.conv_block(in_channels=32, out_channels=1,
                                        kernel_size=1, stride=1, padding=0,
                                        batch_norm=False, relu=False))
        self.block_u1 = nn.Sequential(*block_u1)

    def forward(self, input):
        # Contracting path.
        d1_out = self.block_d1(input)
        d2_out = self.block_d2(d1_out)
        d3_out = self.block_d3(d2_out)
        d4_out = self.block_d4(d3_out)
        # Bottleneck (already upsampled back to the d4 resolution).
        b_out = self.block_b(d4_out)
        # Expanding path: concatenate skip and upsampled features along channels.
        u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
        u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
        u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
        u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
        # Residual connection from the input to the output.
        result = u1_out + input

        print(f'input: {input.size()}')
        print(f'd1_out: {d1_out.size()}')
        print(f'd2_out: {d2_out.size()}')
        print(f'd3_out: {d3_out.size()}')
        print(f'd4_out: {d4_out.size()}')
        print(f'b_out: {b_out.size()}')
        print(f'u4_out: {u4_out.size()}')
        print(f'u3_out: {u3_out.size()}')
        print(f'u2_out: {u2_out.size()}')
        print(f'u1_out: {u1_out.size()}')
        print(f'result: {result.size()}')
        return result

    def conv_block(self,
                   in_channels: int,
                   out_channels: int,
                   kernel_size: int,
                   stride: int,
                   padding: int,
                   batch_norm: bool = False,
                   relu: bool = False
                   ) -> list:
        """Conv2d (+ optional BatchNorm2d and ReLU) as a list of layers."""
        block = []
        conv_layer = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               padding=padding)
        nn.init.normal_(conv_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_layer.bias, self.conv_bias)
        block.append(conv_layer)
        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)
        if relu:
            block.append(nn.ReLU())
        return block

    def conv_transpose_block(self,
                             in_channels: int,
                             out_channels: int,
                             kernel_size: int,
                             stride: int,
                             padding: int,
                             output_padding: int,
                             batch_norm: bool = False,
                             relu: bool = False
                             ) -> list:
        """ConvTranspose2d (+ optional BatchNorm2d and ReLU) as a list of layers."""
        block = []
        conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=padding,
                                                  output_padding=output_padding)
        nn.init.normal_(conv_transpose_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_transpose_layer.bias, self.conv_bias)
        block.append(conv_transpose_layer)
        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)
        if relu:
            block.append(nn.ReLU())
        return block

    def block_3ch_t(self,
                    in_channels: int,
                    out_channels: int,
                    block_channels: int = None,
                    return_list=False
                    ):
        """Two convs at block_channels, then a stride-2 transposed conv that
        doubles the spatial resolution and maps to out_channels."""
        if block_channels is None:
            block_channels = in_channels // 2
        block = []
        block.extend(self.conv_block(in_channels=in_channels,
                                     out_channels=block_channels,
                                     kernel_size=self.conv_kernel_size,
                                     stride=1,
                                     padding=self.conv_padding,
                                     batch_norm=True, relu=True))
        block.extend(self.conv_block(in_channels=block_channels,
                                     out_channels=block_channels,
                                     kernel_size=self.conv_kernel_size,
                                     stride=1,
                                     padding=self.conv_padding,
                                     batch_norm=True, relu=True))
        block.extend(self.conv_transpose_block(in_channels=block_channels,
                                               out_channels=out_channels,
                                               kernel_size=self.conv_kernel_size,
                                               stride=2,
                                               padding=self.conv_padding,
                                               output_padding=1,
                                               batch_norm=True, relu=True))
        if return_list:
            return block
        return nn.Sequential(*block)

    def block_2ch(self,
                  in_channels: int,
                  out_channels: int,
                  return_list=False,
                  max_pool=True,
                  ):
        """Optional 2x max-pool followed by two conv(+BN+ReLU) stages."""
        if max_pool:
            block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)]
        else:
            block = []
        block.extend(self.conv_block(in_channels=in_channels,
                                     out_channels=out_channels,
                                     kernel_size=self.conv_kernel_size,
                                     stride=1,
                                     padding=self.conv_padding,
                                     batch_norm=True, relu=True))
        block.extend(self.conv_block(in_channels=out_channels,
                                     out_channels=out_channels,
                                     kernel_size=self.conv_kernel_size,
                                     stride=1,
                                     padding=self.conv_padding,
                                     batch_norm=True, relu=True))
        if return_list:
            return block
        return nn.Sequential(*block)
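

# --- Not part of the original model: an illustrative helper, added as a sketch. ---
# The network halves the spatial resolution four times (three pooled d-blocks plus
# the pooled bottleneck), so H and W must be multiples of 2**4 = 16 for the skip
# concatenations to line up. The __main__ block below pads 362x362 to 368x368 by
# hand; this hypothetical helper computes that padding for an arbitrary input.
def pad_to_multiple(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Zero-pad the last two dims of x up to the next multiple of `multiple`."""
    import torch.nn.functional as F
    pad_h = (-x.shape[-2]) % multiple  # extra rows needed
    pad_w = (-x.shape[-1]) % multiple  # extra columns needed
    # F.pad orders the last two dims as (left, right, top, bottom); this mirrors
    # the manual pad=(0, 6, 6, 0) used in __main__ for a 362x362 image.
    return F.pad(x, (0, pad_w, pad_h, 0), mode='constant', value=0)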


if __name__ == '__main__':
    import torch.nn.functional as F

    # NCHW input: a batch of one single-channel image. The spatial size must be
    # a multiple of 16 (four pool/upsample stages), so pad 362x362 to 368x368.
    img = torch.rand((1, 1, 362, 362))
    new_data = F.pad(input=img, pad=(0, 6, 6, 0), mode='constant', value=0)
    model = UNet()
    model(new_data)
    summary(model, input_data=new_data)
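

# --- Not part of the original script: a minimal sketch of batched inference. ---
# DataLoader is imported at the top of the file but never used; this shows one way
# it could drive the model. The random TensorDataset, batch size, and eval/no_grad
# usage are illustrative assumptions, not the author's intended pipeline.
def batched_inference_sketch(model: UNet, num_images: int = 4, batch_size: int = 2) -> list:
    from torch.utils.data import TensorDataset

    # Random single-channel 368x368 images (368 is already a multiple of 16).
    images = torch.rand((num_images, 1, 368, 368))
    loader = DataLoader(TensorDataset(images), batch_size=batch_size)

    model.eval()  # BatchNorm2d uses its running statistics outside training
    outputs = []
    with torch.no_grad():
        for (batch,) in loader:
            outputs.append(model(batch))
    return outputs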