from turtle import forward
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary


class UNet(nn.Module):
    """
    U-Net style encoder/decoder with a global residual connection (output = net(input) + input).

    Channel layout, with C = ``in_channels`` (default 1) and c = ``base_channels`` (default 32,
    i.e. 32/64/128/256/512 channels per level). Skip connections are channel-wise
    concatenations (``torch.cat(dim=1)``); only the outermost connection is an addition::

        input (C)             ------------ (+) residual ------------>  output (C)
        block_d1 (C -> c)     -- concat -->  block_u1 (2c -> c, 1x1 conv c -> C)
            v pool                                ^ upconv
        block_d2 (c -> 2c)    -- concat -->  block_u2 (4c -> 2c, upconv -> c)
            v pool                                ^ upconv
        block_d3 (2c -> 4c)   -- concat -->  block_u3 (8c -> 4c, upconv -> 2c)
            v pool                                ^ upconv
        block_d4 (4c -> 8c)   -- concat -->  block_u4 (16c -> 8c, upconv -> 4c)
            v pool                                ^ upconv
                        block_b (8c -> 16c, upconv -> 8c)

    With the default kernel sizes the encoder halves H and W four times and the decoder doubles
    them four times, so H and W must be divisible by 16 for the concatenations to line up.
    """

    def __init__(self,
                 conv_kernel_size: int = 3,
                 conv_padding: "int | None" = None,
                 conv_weight: tuple = (0, 1),
                 conv_bias: float = 0,
                 batch_norm_weight: float = 1,
                 batch_norm_bias: float = 0,
                 max_pool_kernel_size: int = 2,
                 in_channels: int = 1,
                 base_channels: int = 32,
                 verbose: bool = False,
                 ) -> None:
        """
        Build all encoder, bottleneck and decoder blocks.

        :param conv_kernel_size: kernel size of every 3x3-style conv and of the transposed convs.
            With the default padding, odd sizes preserve H and W and make the stride-2
            transposed convs exactly double them.
        :param conv_padding: padding of those convs; ``None`` means ``conv_kernel_size // 2``.
            An explicit ``0`` is honoured.
        :param conv_weight: ``(mean, std)`` passed to ``nn.init.normal_`` for conv weights.
        :param conv_bias: constant initial value of conv biases.
        :param batch_norm_weight: constant initial BatchNorm scale.
        :param batch_norm_bias: constant initial BatchNorm shift.
        :param max_pool_kernel_size: kernel of the encoder ``nn.MaxPool2d`` layers. The decoder
            always upsamples by 2 (stride-2 transposed convs), so other values will not line up.
        :param in_channels: channels of the input image; the output has the same count so the
            residual addition is well defined.
        :param base_channels: width of the first level; deeper levels use 2x, 4x, 8x, 16x.
        :param verbose: if True, ``forward`` prints the size of every intermediate tensor.
        """
        super().__init__()
        self.conv_weight = conv_weight
        self.conv_bias = conv_bias
        self.batch_norm_weight = batch_norm_weight
        self.batch_norm_bias = batch_norm_bias
        self.conv_kernel_size = conv_kernel_size
        self.max_pool_kernel_size = max_pool_kernel_size
        self.verbose = verbose

        # `is None` rather than truthiness so that an explicit padding of 0 is not overridden.
        if conv_padding is None:
            conv_padding = conv_kernel_size // 2
        self.conv_padding = conv_padding

        c = base_channels

        # First level: three convs at full resolution, no pooling in front.
        block_d1 = self.block_2ch(in_channels=in_channels, out_channels=c, return_list=True, max_pool=False)
        block_d1.extend(self.conv_block(in_channels=c, out_channels=c, kernel_size=conv_kernel_size, stride=1,
                                        padding=conv_padding, batch_norm=True, relu=True))
        self.block_d1 = nn.Sequential(*block_d1)

        self.block_d2 = self.block_2ch(in_channels=c, out_channels=2 * c)

        self.block_d3 = self.block_2ch(in_channels=2 * c, out_channels=4 * c)

        self.block_d4 = self.block_2ch(in_channels=4 * c, out_channels=8 * c)

        # Bottleneck: pool, widen to 16c, then upsample back to the d4 resolution with 8c channels.
        block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
        block_b.extend(self.block_3ch_t(in_channels=8 * c, out_channels=8 * c, block_channels=16 * c,
                                        return_list=True))
        self.block_b = nn.Sequential(*block_b)

        # Decoder inputs are concatenations of the skip tensor and the upsampled tensor,
        # hence twice the channels of the matching encoder level.
        self.block_u4 = self.block_3ch_t(in_channels=16 * c, out_channels=4 * c, block_channels=8 * c)

        self.block_u3 = self.block_3ch_t(in_channels=8 * c, out_channels=2 * c, block_channels=4 * c)

        self.block_u2 = self.block_3ch_t(in_channels=4 * c, out_channels=c, block_channels=2 * c)

        # Output head: two convs then a plain 1x1 projection back to the input channel count
        # (no BatchNorm/ReLU, so the residual can take any sign).
        block_u1 = self.block_2ch(in_channels=2 * c, out_channels=c, return_list=True, max_pool=False)
        block_u1.extend(self.conv_block(in_channels=c, out_channels=in_channels, kernel_size=1, stride=1,
                                        padding=0, batch_norm=False, relu=False))
        self.block_u1 = nn.Sequential(*block_u1)

    def forward(self, input):
        """
        Run the network and add the input back as a global residual.

        :param input: batched tensor ``(N, in_channels, H, W)``; ``nn.BatchNorm2d`` rejects
            non-4-D input. With default settings H and W must be divisible by 16.
        :return: tensor with the same shape as ``input``.
        """
        d1_out = self.block_d1(input)
        d2_out = self.block_d2(d1_out)
        d3_out = self.block_d3(d2_out)
        d4_out = self.block_d4(d3_out)
        b_out = self.block_b(d4_out)
        u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
        u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
        u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
        u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
        result = u1_out + input
        # Debug shape dump is opt-in; getattr keeps instances pickled before `verbose` existed working.
        if getattr(self, 'verbose', False):
            for name, tensor in (('input', input), ('d1_out', d1_out), ('d2_out', d2_out),
                                 ('d3_out', d3_out), ('d4_out', d4_out), ('b_out', b_out),
                                 ('u4_out', u4_out), ('u3_out', u3_out), ('u2_out', u2_out),
                                 ('u1_out', u1_out), ('result', result)):
                print(f'{name}: {tensor.size()}')
        return result

    def _init_conv(self, layer: nn.Module) -> None:
        """Initialise a (transposed) conv layer: weight ~ N(*conv_weight), bias = conv_bias."""
        nn.init.normal_(layer.weight, *self.conv_weight)
        nn.init.constant_(layer.bias, self.conv_bias)

    def _norm_act(self, out_channels: int, batch_norm: bool, relu: bool) -> list:
        """Return the optional BatchNorm2d and ReLU layers that follow a conv layer."""
        layers = []
        if batch_norm:
            batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
            nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
            nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
            layers.append(batch_norm_layer)
        if relu:
            layers.append(nn.ReLU())
        return layers

    def conv_block(self, in_channels: int,
                   out_channels: int,
                   kernel_size: int,
                   stride: int,
                   padding: int,
                   batch_norm: bool = False,
                   relu: bool = False
                   ) -> list:
        """
        Build ``Conv2d`` [+ ``BatchNorm2d``] [+ ``ReLU``] as a list of initialised layers.

        :return: list of modules, ready to be unpacked into ``nn.Sequential``.
        """
        conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)
        self._init_conv(conv_layer)
        return [conv_layer] + self._norm_act(out_channels, batch_norm, relu)

    def conv_transpose_block(self, in_channels: int,
                             out_channels: int,
                             kernel_size: int,
                             stride: int,
                             padding: int,
                             output_padding: int,
                             batch_norm: bool = False,
                             relu: bool = False
                             ) -> list:
        """
        Build ``ConvTranspose2d`` [+ ``BatchNorm2d``] [+ ``ReLU``] as a list of initialised layers.

        :return: list of modules, ready to be unpacked into ``nn.Sequential``.
        """
        conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride, padding=padding, output_padding=output_padding)
        self._init_conv(conv_transpose_layer)
        return [conv_transpose_layer] + self._norm_act(out_channels, batch_norm, relu)

    def block_3ch_t(self,
                    in_channels: int,
                    out_channels: int,
                    block_channels: "int | None" = None,
                    return_list=False
                    ) -> "list | nn.Sequential":
        """
        Two conv+BN+ReLU layers followed by a stride-2 transposed conv+BN+ReLU (2x upsampling).

        :param block_channels: width of the two convs; ``None`` means ``in_channels // 2``.
        :param return_list: return the raw layer list instead of an ``nn.Sequential``.
        """
        if block_channels is None:
            block_channels = in_channels // 2
        block = []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=block_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        # output_padding=1 makes (H - 1) * 2 - 2p + k + 1 == 2H for odd k with p = k // 2.
        block.extend(self.conv_transpose_block(in_channels=block_channels, out_channels=out_channels,
                                               kernel_size=self.conv_kernel_size, stride=2,
                                               padding=self.conv_padding, output_padding=1, batch_norm=True, relu=True))
        return block if return_list else nn.Sequential(*block)

    def block_2ch(self,
                  in_channels: int,
                  out_channels: int,
                  return_list=False,
                  max_pool=True,
                  ) -> "list | nn.Sequential":
        """
        Optional max-pool followed by two conv+BN+ReLU layers.

        :param return_list: return the raw layer list instead of an ``nn.Sequential``.
        :param max_pool: prepend ``nn.MaxPool2d(max_pool_kernel_size)`` (downsampling).
        """
        block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)] if max_pool else []
        block.extend(
            self.conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        block.extend(
            self.conv_block(in_channels=out_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
                            stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
        return block if return_list else nn.Sequential(*block)


if __name__ == '__main__':
    import torch.nn.functional as F

    # The model needs a batched (N, C, H, W) tensor: nn.BatchNorm2d raises on the
    # unbatched 3-D (C, H, W) shape, so keep an explicit batch dimension of 1.
    img = torch.rand((1, 1, 362, 362))

    # Four 2x poolings followed by four 2x up-convolutions only line up when H and W are
    # multiples of 16, so pad up to the next multiple (362 -> 368).
    multiple = 16
    pad_h = -img.size(-2) % multiple
    pad_w = -img.size(-1) % multiple
    # F.pad order is (left, right, top, bottom) for the last two dims: columns go on the
    # right, rows on the top, matching the original fixed (0, 6, 6, 0) padding.
    new_data = F.pad(input=img, pad=(0, pad_w, pad_h, 0), mode='constant', value=0)

    model = UNet()
    # Plain smoke-test forward; no gradients are needed here.
    with torch.no_grad():
        model(new_data)
    summary(model, input_data=new_data)