import torch
from torch import nn


class UNet(nn.Module):
    """
    input              -skip(+)->  output
    block_d1 (1-64)    -skip(+)->  block_u1 (128-1)
        v pool                         ^ upconv 
    block_d2 (64-128)  -skip(+)->  block_u2 (256-128)
        v pool                         ^ upconv 
    block_d3 (128-256) -skip(+)->  block_u3 (512-256)
        v pool                         ^ upconv 
    block_d4 (256-512) -skip(+)->  block_u4 (1024-512)
        v pool                         ^ upconv 
                    block_b (512-1024)
    """

    def __init__(self,
                 conv_kernel_size: int = 3,
                 conv_padding: int | None = None,
                 conv_weight: tuple = (0, 1),  # (mean, std) for the normal init of conv weights
                 conv_bias: float = 0,
                 batch_norm_weight: float = 1,
                 batch_norm_bias: float = 0,
                 max_pool_kernel_size: int = 2
                 ) -> None:
        super().__init__()
        self.conv_weight = conv_weight
        self.conv_bias = conv_bias
        self.batch_norm_weight = batch_norm_weight
        self.batch_norm_bias = batch_norm_bias
        self.conv_kernel_size = conv_kernel_size
        self.max_pool_kernel_size = max_pool_kernel_size

        if conv_padding is None:
            conv_padding = conv_kernel_size // 2
        self.conv_padding = conv_padding

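        # encoder (contracting path)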
        block_d1 = self.block_2ch(in_channels=1, out_channels=64, return_list=True, max_pool=False)
        block_d1.extend(self.conv_block(in_channels=64, out_channels=64, kernel_size=conv_kernel_size, stride=1,
                                        padding=conv_padding, batch_norm=True, relu=True))
        self.block_d1 = nn.Sequential(*block_d1)

        self.block_d2 = self.block_2ch(in_channels=64, out_channels=128)

        self.block_d3 = self.block_2ch(in_channels=128, out_channels=256)

        self.block_d4 = self.block_2ch(in_channels=256, out_channels=512)

        # bottleneck
        block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        block_b.extend(self.block_3ch_t(in_channels=512, out_channels=512, block_channels=1024, return_list=True))
        self.block_b = nn.Sequential(*block_b)

        # decoder (expanding path): each block consumes the channel-concatenation of a skip and the upsampled features
        self.block_u4 = self.block_3ch_t(in_channels=1024, out_channels=256, block_channels=512)

        self.block_u3 = self.block_3ch_t(in_channels=512, out_channels=128, block_channels=256)

        self.block_u2 = self.block_3ch_t(in_channels=256, out_channels=64, block_channels=128)

        block_u1 = self.block_2ch(in_channels=128, out_channels=64, return_list=True, max_pool=False)
        block_u1.extend(self.conv_block(in_channels=64, out_channels=1, kernel_size=1, stride=1,
                                        padding=0, batch_norm=False, relu=False))
        self.block_u1 = nn.Sequential(*block_u1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
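        # encoder path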
        d1_out = self.block_d1(input)
        d2_out = self.block_d2(d1_out)
        d3_out = self.block_d3(d2_out)
        d4_out = self.block_d4(d3_out)
        b_out = self.block_b(d4_out)
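
        # decoder path: concatenate the matching encoder features (skip) with the
        # upsampled features along the channel dimension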
        u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
        u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
        u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
        u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
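
        # global residual connection: add the predicted refinement to the input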
        result = u1_out + input
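
        # debug output: print intermediate tensor sizes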
        print(f'input: {input.size()}')
        print(f'd1_out: {d1_out.size()}')
        print(f'd2_out: {d2_out.size()}')
        print(f'd3_out: {d3_out.size()}')
        print(f'd4_out: {d4_out.size()}')
        print(f'b_out: {b_out.size()}')
        print(f'u4_out: {u4_out.size()}')
        print(f'u3_out: {u3_out.size()}')
        print(f'u2_out: {u2_out.size()}')
        print(f'u1_out: {u1_out.size()}')
        print(f'result: {result.size()}')
        return result

    def conv_block(self, in_channels: int,
                   out_channels: int,
                   kernel_size: int,
                   stride: int,
                   padding: int,
                   batch_norm: bool = False,
                   relu: bool = False
                   ) -> list:
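        """Build a Conv2d (+ optional BatchNorm2d and ReLU) block as a list of layers, with explicit weight init."""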
        block = []
        conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)
        nn.init.normal_(conv_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_layer.bias, self.conv_bias)
        block.append(conv_layer)

        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)

        if relu:
            block.append(nn.ReLU())

        return block

    def conv_transpose_block(self, in_channels: int,
                             out_channels: int,
                             kernel_size: int,
                             stride: int,
                             padding: int,
                             output_padding: int,
                             batch_norm: bool = False,
                             relu: bool = False
                             ) -> list:
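        """Build a ConvTranspose2d (+ optional BatchNorm2d and ReLU) block as a list of layers, with explicit weight init."""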
        block = []
        conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride, padding=padding, output_padding=output_padding)
        nn.init.normal_(conv_transpose_layer.weight, *self.conv_weight)
        nn.init.constant_(conv_transpose_layer.bias, self.conv_bias)
        block.append(conv_transpose_layer)

        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            block.append(batch_norm_layer)

        if relu:
            block.append(nn.ReLU())

        return block

    def block_3ch_t(self,
                    in_channels: int,
                    out_channels: int,
                    block_channels: int | None = None,
                    return_list: bool = False
                    ) -> list | nn.Sequential:
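        """Two conv blocks followed by an up-convolution (stride-2 transposed conv); used for the bottleneck and decoder."""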
        if block_channels is None:
            block_channels = in_channels // 2
        block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=block_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(self.conv_transpose_block(in_channels=block_channels, out_channels=out_channels,
                                               kernel_size=self.conv_kernel_size, stride=2,
                                               padding=self.conv_padding, output_padding=1, batch_norm=True, relu=True))
        if return_list:
            return block
        else:
            return nn.Sequential(*block)

    def block_2ch(self,
                  in_channels: int,
                  out_channels: int,
                  return_list: bool = False,
                  max_pool: bool = True,
                  ) -> list | nn.Sequential:
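        """Optional max-pool followed by two conv blocks; used for the encoder blocks and the final decoder block."""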
        if max_pool:
            block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)]
        else:
            block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=out_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        if return_list:
            return block
        else:
            return nn.Sequential(*block)


if __name__ == '__main__':
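    # quick smoke test: push a random single-channel 512x512 image through the model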
    img = torch.rand((1, 1, 512, 512))
    model = UNet()
    model(img)
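
    # minimal sanity check: the residual output should have the same shape as the input
    out = model(img)
    assert out.shape == img.shape, f'unexpected output shape: {out.shape}'
    print(f'parameter count: {sum(p.numel() for p in model.parameters())}')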