Newer
Older
from turtle import forward
import torch
from torch import nn
from torch.utils.data import DataLoader
class UNet(nn.Module):
"""
input -skip(+)-> output
block_d1 (1-64) -skip(+)-> block_u1 (128-1)
v pool ^ upconv
block_d2 (64-128) -skip(+)-> block_u2 (256-128)
v pool ^ upconv
block_d3 (128-256) -skip(+)-> block_u3 (512-256)
v pool ^ upconv
block_d4 (256-512) -skip(+)-> block_u4 (1024-512)
v pool ^ upconv
block_b (512-1024)
"""
def __init__(self,
conv_kernel_size: int = 3,
conv_padding: int = None,
conv_weight: tuple = (0, 1),
conv_bias: float = 0,
batch_norm_weight: float = 1,
batch_norm_bias: float = 0,
max_pool_kernel_size: int = 2
) -> None:
super().__init__()
self.conv_weight = conv_weight
self.conv_bias = conv_bias
self.batch_norm_weight = batch_norm_weight
self.batch_norm_bias = batch_norm_bias
self.conv_kernel_size = conv_kernel_size
self.max_pool_kernel_size = max_pool_kernel_size
if not conv_padding:
conv_padding = conv_kernel_size // 2
self.conv_padding = conv_padding
block_d1 = self.block_2ch(in_channels=1, out_channels=64, return_list=True, max_pool=False)
block_d1.extend(self.conv_block(in_channels=64, out_channels=64, kernel_size=conv_kernel_size, stride=1,
padding=conv_padding, batch_norm=True, relu=True))
self.block_d1 = nn.Sequential(*block_d1)
self.block_d2 = self.block_2ch(in_channels=64, out_channels=128)
self.block_d3 = self.block_2ch(in_channels=128, out_channels=256)
self.block_d4 = self.block_2ch(in_channels=256, out_channels=512)
block_b = [nn.MaxPool2d(kernel_size=max_pool_kernel_size)]
block_b.extend(self.block_3ch_t(in_channels=512, out_channels=512, block_channels=1024, return_list=True))
self.block_b = nn.Sequential(*block_b)
self.block_u4 = self.block_3ch_t(in_channels=1024, out_channels=256, block_channels=512)
self.block_u3 = self.block_3ch_t(in_channels=512, out_channels=128, block_channels=256)
self.block_u2 = self.block_3ch_t(in_channels=256, out_channels=64, block_channels=128)
block_u1 = self.block_2ch(in_channels=128, out_channels=64, return_list=True, max_pool=False)
block_u1.extend(self.conv_block(in_channels=64, out_channels=1, kernel_size=1, stride=1,
self.block_u1 = nn.Sequential(*block_u1)
def forward(self, input):
d1_out = self.block_d1(input)
d2_out = self.block_d2(d1_out)
d3_out = self.block_d3(d2_out)
d4_out = self.block_d4(d3_out)
b_out = self.block_b(d4_out)
u4_out = self.block_u4(torch.cat((d4_out, b_out), dim=1))
u3_out = self.block_u3(torch.cat((d3_out, u4_out), dim=1))
u2_out = self.block_u2(torch.cat((d2_out, u3_out), dim=1))
u1_out = self.block_u1(torch.cat((d1_out, u2_out), dim=1))
result = u1_out + input
print(f'input: {input.size()}')
print(f'd1_out: {d1_out.size()}')
print(f'd2_out: {d2_out.size()}')
print(f'd3_out: {d3_out.size()}')
print(f'd4_out: {d4_out.size()}')
print(f'b_out: {b_out.size()}')
print(f'u4_out: {u4_out.size()}')
print(f'u3_out: {u3_out.size()}')
print(f'u2_out: {u2_out.size()}')
print(f'u1_out: {u1_out.size()}')
print(f'result: {result.size()}')
out_channels: int,
kernel_size: int,
stride: int,
padding: int,
batch_norm: bool = False,
relu: bool = False
) -> list:
block = []
conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
nn.init.normal_(conv_layer.weight, *self.conv_weight)
nn.init.constant_(conv_layer.bias, self.conv_bias)
block.append(conv_layer)
if batch_norm:
batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
block.append(batch_norm_layer)
if relu:
block.append(nn.ReLU())
return block
def conv_transpose_block(self, in_channels: int,
out_channels: int,
kernel_size: int,
stride: int,
padding: int,
output_padding: int,
batch_norm: bool = False,
relu: bool = False
) -> list:
conv_transpose_layer = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels,
kernel_size=kernel_size,
stride=stride, padding=padding, output_padding=output_padding)
nn.init.normal_(conv_transpose_layer.weight, *self.conv_weight)
nn.init.constant_(conv_transpose_layer.bias, self.conv_bias)
block.append(conv_transpose_layer)
if batch_norm:
batch_norm_layer = nn.BatchNorm2d(num_features=out_channels)
nn.init.constant_(batch_norm_layer.weight, self.batch_norm_weight)
nn.init.constant_(batch_norm_layer.bias, self.batch_norm_bias)
block.append(batch_norm_layer)
if relu:
block.append(nn.ReLU())
return block
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def block_3ch_t(self,
in_channels: int,
out_channels: int,
block_channels: int = None,
return_list=False
) -> list | nn.Sequential:
if not block_channels:
block_channels = in_channels // 2
block = []
block.extend(
self.conv_block(in_channels=in_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
block.extend(
self.conv_block(in_channels=block_channels, out_channels=block_channels, kernel_size=self.conv_kernel_size,
stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
block.extend(self.conv_transpose_block(in_channels=block_channels, out_channels=out_channels,
kernel_size=self.conv_kernel_size, stride=2,
padding=self.conv_padding, output_padding=1, batch_norm=True, relu=True))
if return_list:
return block
else:
return nn.Sequential(*block)
def block_2ch(self,
in_channels: int,
out_channels: int,
return_list=False,
max_pool=True,
) -> list | nn.Sequential:
if max_pool:
block = [nn.MaxPool2d(kernel_size=self.max_pool_kernel_size)]
else:
block = []
block.extend(
self.conv_block(in_channels=in_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
block.extend(
self.conv_block(in_channels=out_channels, out_channels=out_channels, kernel_size=self.conv_kernel_size,
stride=1, padding=self.conv_padding, batch_norm=True, relu=True))
if return_list:
return block
else:
return nn.Sequential(*block)
if __name__ == '__main__':
    # Smoke run: push a single random 1-channel 512x512 image through a freshly
    # built network. The image is sampled before the model is built, so the
    # RNG draws happen in the same order as before.
    sample_image = torch.rand(1, 1, 512, 512)
    unet = UNet()
    unet(sample_image)