from typing import Optional, Union
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary


class UNet(nn.Module):
    """
    input              -skip(+)->  output
    block_d1 (1-64)    -skip(+)->  block_u1 (128-1)
        v pool                         ^ upconv 
    block_d2 (64-128)  -skip(+)->  block_u2 (256-128)
        v pool                         ^ upconv 
    block_d3 (128-256) -skip(+)->  block_u3 (512-256)
        v pool                         ^ upconv 
    block_d4 (256-512) -skip(+)->  block_u4 (1024-512)
        v pool                         ^ upconv 
                    block_b (512-1024)
    """

    def __init__(self,
                 conv_kernel_size: int = 3,
                 conv_padding: Optional[int] = None,
                 conv_weight: tuple = (0, 1),
                 conv_bias: float = 0,
                 batch_norm_weight: float = 1,
                 batch_norm_bias: float = 0,
                 max_pool_kernel_size: int = 2
                 ) -> None:
        super().__init__()
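        # conv_weight is the (mean, std) used for normal weight initialization of conv layers;
        # conv biases and batch-norm parameters are initialized to the given constants.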
        self.conv_weight = conv_weight
        self.conv_bias = conv_bias
        self.batch_norm_weight = batch_norm_weight
        self.batch_norm_bias = batch_norm_bias
        self.conv_kernel_size = conv_kernel_size
        self.max_pool_kernel_size = max_pool_kernel_size

        if conv_padding is None:
            conv_padding = conv_kernel_size // 2
        self.conv_padding = conv_padding

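        # Encoder path: block_d1 works at full resolution; each following block max-pools and doubles the channel count.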
        block_d1 = self.block_2ch(in_channels=1, out_channels=32, return_list=True, max_pool=False)
        block_d1.extend(self.conv_block(in_channels=32, out_channels=32, kernel_size=conv_kernel_size, stride=1,
                                        padding=conv_padding, batch_norm=True, relu=True))
        self.block_d1 = nn.Sequential(*block_d1)

        self.block_d2 = self.block_2ch(in_channels=32, out_channels=64)

        self.block_d3 = self.block_2ch(in_channels=64, out_channels=128)

        self.block_d4 = self.block_2ch(in_channels=128, out_channels=256)

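        # Bottleneck: max-pool, widen to 512 channels, then upsample back to 256 channels via a transposed convolution.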
        block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        block_b.extend(self.block_3ch_t(in_channels=256, out_channels=256, block_channels=512, return_list=True))
        self.block_b = nn.Sequential(*block_b)

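        # Decoder path: each block takes the channel-wise concatenation of skip and upsampled features
        # and ends with a stride-2 transposed convolution.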
        self.block_u4 = self.block_3ch_t(in_channels=512, out_channels=128, block_channels=256)

        self.block_u3 = self.block_3ch_t(in_channels=256, out_channels=64, block_channels=128)

        self.block_u2 = self.block_3ch_t(in_channels=128, out_channels=32, block_channels=64)

        block_u1 = self.block_2ch(in_channels=64, out_channels=32, return_list=True, max_pool=False)
        block_u1.extend(self.conv_block(in_channels=32, out_channels=1, kernel_size=1, stride=1,
                                        padding=0, batch_norm=False, relu=False))
        self.block_u1 = nn.Sequential(*block_u1)

    def forward(self, input):
        d1_out = self.block_d1(input)
        d2_out = self.block_d2(d1_out)
        d3_out = self.block_d3(d2_out)
        d4_out = self.block_d4(d3_out)
        b_out = self.block_b(d4_out)
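        # Concatenate each decoder input with the matching encoder output along the channel dimension.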
        u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
        u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
        u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
        u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
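        # Residual connection: the network output is added to its input.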
        result = u1_out + input
        return result

    def conv_block(self, in_channels: int,
                   out_channels: int,
                   kernel_size: int,
                   stride: int,
                   padding: int,
                   batch_norm: bool = False,
                   relu: bool = False
                   ) -> list:
        block = []
        conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)
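        # Weights drawn from a normal distribution with (mean, std) = self.conv_weight; bias set to a constant.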
        nn.init.normal_(conv_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_layer.bias, self.conv_bias)
        block.append(conv_layer)

        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)

        if relu:
            block.append(nn.ReLU())

        return block

    def conv_transpose_block(self, in_channels: int,
                             out_channels: int,
                             kernel_size: int,
                             stride: int,
                             padding: int,
                             output_padding: int,
                             batch_norm: bool = False,
                             relu: bool = False
                             ) -> list:
        block = []
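        # Transposed convolution for learned upsampling; weights and biases are initialized like the regular conv layers.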
        conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride, padding=padding, output_padding=output_padding)
        nn.init.normal_(conv_transpose_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_transpose_layer.bias, self.conv_bias)
        block.append(conv_transpose_layer)

        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)

        if relu:
            block.append(nn.ReLU())

        return block

    def block_3ch_t(self,
                    in_channels: int,
                    out_channels: int,
                    block_channels: Optional[int] = None,
                    return_list: bool = False
                    ) -> Union[list, nn.Sequential]:
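        # Two conv + BN + ReLU layers followed by a stride-2 transposed convolution;
        # used for the bottleneck and the decoder blocks.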
        if block_channels is None:
            block_channels = in_channels // 2
        block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=block_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(self.conv_transpose_block(in_channels=block_channels, out_channels=out_channels,
                                               kernel_size=self.conv_kernel_size, stride=2,
                                               padding=self.conv_padding, output_padding=1, batch_norm=True, relu=True))
        if return_list:
            return block
        else:
            return nn.Sequential(*block)

    def block_2ch(self,
                  in_channels: int,
                  out_channels: int,
                  return_list: bool = False,
                  max_pool: bool = True,
                  ) -> Union[list, nn.Sequential]:
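        # Optional 2x2 max-pooling followed by two conv + BN + ReLU layers;
        # used for the encoder blocks and block_u1.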
        if max_pool:
            block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)]
        else:
            block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=out_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        if return_list:
            return block
        else:
            return nn.Sequential(*block)


if __name__ == '__main__':
    import torch.nn.functional as F
    # Quick smoke test: pad a 362x362 image to 368x368 so that four 2x2
    # poolings divide the spatial size evenly; the batch dimension is needed
    # because BatchNorm2d expects 4D input (N, C, H, W).
    img = torch.rand((1, 1, 362, 362))
    new_data = F.pad(input=img, pad=(0, 6, 6, 0), mode='constant', value=0)

    model = UNet()
    model(new_data)
    summary(model, input_data=new_data)