"""
Collection of functions to manipulate tensors
"""
import torch

def get_normalization(t):
    """
    Compute the normalization statistics of the tensor `t`.

    Both statistics are reduced over dimension 0 with the reduced dimension
    kept (size 1), so they broadcast directly against `t`. The standard
    deviation uses `torch.std`'s default correction.

    :param t: A torch.tensor
    :returns: A tuple ``(mean, std)`` of tensors
    """
    reduce_kwargs = {"dim": 0, "keepdim": True}
    mean = torch.mean(t, **reduce_kwargs)
    std = torch.std(t, **reduce_kwargs)
    return mean, std

def normalize_tensor(t, mean_std):
    """
    Standardize the tensor `t` using precomputed statistics.

    :param t: A torch.tensor
    :param mean_std: An indexable pair whose element 0 is the mean and
        element 1 is the standard deviation (e.g. the tuple returned by
        `get_normalization`); both must broadcast against `t`.
    :returns: ``(t - mean) / std``
    """
    center = mean_std[0]
    scale = mean_std[1]
    shifted = t - center
    return shifted / scale


def add_noise(tensor_list, noise_strength_list, seed_list, normalize=True,
        normalization_list=None):
    """
    Add seeded Gaussian noise to each tensor and optionally normalize.

    For every index ``i`` the noisy tensor is
    ``tensor_list[i] + noise_strength_list[i] * randn(tensor_list[i].shape)``,
    where the standard-normal sample is drawn from a fresh CPU
    ``torch.Generator`` seeded with ``seed_list[i]``. The same seed therefore
    always yields the same noise.

    If `normalize` is True (default), the mean and standard deviation are
    computed from the noisy tensor itself, or from ``normalization_list[i]``
    when `normalization_list` is given. Those same statistics are applied to
    both the noisy and the unnoisy tensor. If `normalize` is False, no
    normalization is performed, `normalization_list` is not used for
    normalization, and the second returned list holds the original tensors
    from `tensor_list`.

    :param tensor_list: A list of torch.tensors
    :param noise_strength_list: A list of noise standard deviations (expected
        to be positive floats), one per tensor in `tensor_list`.
    :param seed_list: A list of integer seeds, one per tensor in
        `tensor_list`.
    :param normalize: A Boolean, defaults to True.
    :param normalization_list: Either None (default) or a list of tensors,
        one per tensor in `tensor_list`, whose statistics are used for
        normalization. Only consulted when `normalize` is True.
    :returns: noisy_tensor_list, unnoisy_tensor_list (both normalized when
        `normalize` is True)
    :raises ValueError: if `noise_strength_list`, `seed_list` or (when given)
        `normalization_list` does not have the same length as `tensor_list`.
    """
    n_tensors = len(tensor_list)
    # zip() below would silently drop trailing tensors if a list were short,
    # so reject mismatched lengths explicitly (an `assert` is stripped by -O).
    if len(noise_strength_list) != n_tensors or len(seed_list) != n_tensors:
        raise ValueError(
            "tensor_list, noise_strength_list and seed_list must have the "
            f"same length, got {n_tensors}, {len(noise_strength_list)} and "
            f"{len(seed_list)}")
    if normalization_list is not None and len(normalization_list) != n_tensors:
        raise ValueError(
            "normalization_list must have the same length as tensor_list, "
            f"got {len(normalization_list)} and {n_tensors}")

    noisy_t_list = []
    unnoisy_t_list = []
    for i, (t, noise, seed) in enumerate(
            zip(tensor_list, noise_strength_list, seed_list)):
        generator = torch.Generator().manual_seed(seed)
        # Sample on the CPU generator so a given seed gives the same values
        # on every device, then move the sample next to `t` (needed for CUDA).
        sample = torch.randn(t.shape, generator=generator).to(t.device)
        noisy_t = t + noise * sample
        if normalize:
            if normalization_list is not None:
                noisy_t_normalization = get_normalization(normalization_list[i])
            else:
                noisy_t_normalization = get_normalization(noisy_t)
            # The clean tensor is scaled with the noisy statistics so both
            # outputs live in the same normalized space.
            noisy_t = normalize_tensor(noisy_t, noisy_t_normalization)
            t = normalize_tensor(t, noisy_t_normalization)
        noisy_t_list.append(noisy_t)
        unnoisy_t_list.append(t)
    return noisy_t_list, unnoisy_t_list