import torch.nn as nn

def normal_init(m, scale=0.01):
    """
    Initialize the weights of layer `m` from a normal distribution
    N(0, scale^2) and set biases to 0.

    Only layers whose class name contains 'Linear' or 'Conv' are
    touched; all other modules are left unchanged, which makes this
    safe to use with `net.apply(...)` on an arbitrary model.

    :param m: A torch.nn module.
    :param scale: Standard deviation of the normal distribution
        used for the weights (default 0.01).

    *Example*:
    net = nn.Sequential(nn.Linear(1,2), nn.Linear(2,1))
    net.apply(normal_init)
    """
    classname = m.__class__.__name__
    # only initialize for Linear or Conv layers
    if classname.find('Linear') != -1 or classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, scale)
        # Layers built with bias=False have m.bias is None;
        # skip them instead of raising AttributeError.
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)


def glorot_init(m, gain=1):
    """
    Initialize the weights of layer `m` with Xavier/Glorot uniform
    initialization (via `nn.init.xavier_uniform_`) scaled by `gain`,
    and set biases to 0.

    Only layers whose class name contains 'Linear' or 'Conv' are
    touched; all other modules are left unchanged, which makes this
    safe to use with `net.apply(...)` on an arbitrary model.

    :param m: A torch.nn module.
    :param gain: Scaling factor passed to `nn.init.xavier_uniform_`
        (default 1).

    *Example*:
    net = nn.Sequential(nn.Linear(1,2), nn.Linear(2,1))
    net.apply(glorot_init)
    """
    classname = m.__class__.__name__
    # only initialize for Linear or Conv layers
    if classname.find('Linear') != -1 or classname.find('Conv') != -1:
        # Bug fix: `gain` was previously accepted but never forwarded,
        # so custom gains were silently ignored.
        nn.init.xavier_uniform_(m.weight.data, gain=gain)
        # Layers built with bias=False have m.bias is None;
        # skip them instead of raising AttributeError.
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)