import torch.nn as nn


def _is_linear_or_conv(m):
    """Return True when the class name of `m` contains 'Linear' or 'Conv'.

    Matching is by substring of the class name. It covers e.g. ``Linear``,
    ``Conv1d``/``Conv2d``/``Conv3d`` and ``ConvTranspose2d``, and also any
    custom module whose class name contains either word.
    """
    classname = m.__class__.__name__
    return 'Linear' in classname or 'Conv' in classname


def _init_layer(m, init_weight):
    """Apply `init_weight` to ``m.weight`` and zero ``m.bias`` in place.

    Modules that are not Linear/Conv layers (judged by class name) are left
    untouched. A parameter that is absent or ``None`` is skipped. For
    example, ``nn.Linear(..., bias=False)`` has ``bias is None``, and
    unconditionally touching it would raise ``AttributeError``.

    :param m: A torch.nn object
    :param init_weight: Callable that initializes a weight tensor in place
    """
    if not _is_linear_or_conv(m):
        return
    weight = getattr(m, 'weight', None)
    if weight is not None:
        # nn.init functions already run under torch.no_grad(), so the
        # Parameter can be passed directly instead of going through `.data`.
        init_weight(weight)
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)


def normal_init(m, scale=0.01):
    """
    Initialize weights of layer `m` with `scale` as standard deviation.
    Biases will be set to 0 (layers without a bias are left as is).

    Only modules whose class name contains 'Linear' or 'Conv' are affected.

    :param m: A torch.nn object
    :param scale: Standard deviation of the zero-mean normal distribution

    *Example*:
    net = nn.Sequential(nn.Linear(1,2), nn.Linear(2,1))
    net.apply(normal_init)
    """
    _init_layer(m, lambda w: nn.init.normal_(w, 0.0, scale))


def glorot_init(m, gain=1):
    """
    Initialize weights of layer `m` via `nn.init.xavier_uniform_(m, gain)`
    and biases with 0 (layers without a bias are left as is).

    Only modules whose class name contains 'Linear' or 'Conv' are affected.

    :param m: A torch.nn object
    :param gain: Scaling factor forwarded to `nn.init.xavier_uniform_`

    *Example*:
    net = nn.Sequential(nn.Linear(1,2), nn.Linear(2,1))
    net.apply(glorot_init)
    """
    # Forward `gain`; it was previously accepted but silently ignored.
    _init_layer(m, lambda w: nn.init.xavier_uniform_(w, gain=gain))