"""Collection of functions to manipulate tensors."""

from typing import List, Optional, Sequence, Tuple

import torch


def get_normalization(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the mean and standard deviation of `t` along dimension 0.

    Both statistics are computed with ``keepdim=True`` so they broadcast
    against `t` (and any tensor sharing its trailing dimensions).
    ``torch.std`` applies Bessel's correction by default, so a tensor with
    a single entry along dimension 0 yields a NaN standard deviation.

    :param t: Input tensor.
    :returns: A ``(mean, std)`` tuple of tensors.
    """
    t_mean = torch.mean(t, dim=0, keepdim=True)
    t_std = torch.std(t, dim=0, keepdim=True)
    return (t_mean, t_std)


def normalize_tensor(t: torch.Tensor, mean_std) -> torch.Tensor:
    """Normalize `t` by the mean `mean_std[0]` and standard deviation `mean_std[1]`.

    No epsilon is added, so a zero entry in the standard deviation produces
    ``inf``/``nan`` values in the result.

    :param t: Tensor to normalize.
    :param mean_std: Indexable pair ``(mean, std)``, e.g. the output of
        :func:`get_normalization`.
    :returns: ``(t - mean) / std``.
    """
    return (t - mean_std[0]) / mean_std[1]


def add_noise(tensor_list: Sequence[torch.Tensor],
              noise_strength_list: Sequence[float],
              seed_list: Sequence[int],
              normalize: bool = True,
              normalization_list: Optional[Sequence[torch.Tensor]] = None,
              ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
    """Add seeded Gaussian noise to tensors and optionally normalize them.

    For every ``t`` in `tensor_list`, with matching noise strength ``s`` and
    seed ``k``, computes ``t + s * N(0, 1)``. The noise is drawn from a fresh
    ``torch.Generator`` seeded with ``k``, so results are reproducible and
    independent of the global RNG state.

    If `normalize` is True (default), both the noisy tensor and the original
    tensor are normalized with the *same* statistics (see
    :func:`get_normalization`). By default these statistics are taken from
    the noisy tensor. If `normalization_list` is given, they are taken from
    ``normalization_list[i]`` instead. If `normalize` is False, no
    normalization is performed and the second returned list contains the
    tensors of `tensor_list` unchanged.

    :param tensor_list: A sequence of torch tensors.
    :param noise_strength_list: A sequence of non-negative floats (standard
        deviations of the added noise), one per tensor.
    :param seed_list: A sequence of integer seeds, one per tensor.
    :param normalize: Whether to normalize the outputs. Defaults to True.
    :param normalization_list: Either None (default) or a sequence of
        tensors, one per tensor, whose statistics are used for normalization.
        It is only consulted when `normalize` is True.
    :returns: ``(noisy_tensor_list, unnoisy_tensor_list)``. Both lists are
        normalized when `normalize` is True.
    :raises ValueError: If the input sequences do not all have the same
        length.
    """
    n_tensors = len(tensor_list)
    # zip() would silently truncate mismatched inputs; fail loudly instead.
    if len(noise_strength_list) != n_tensors or len(seed_list) != n_tensors:
        raise ValueError(
            "tensor_list, noise_strength_list and seed_list must have the "
            "same length")
    if normalization_list is not None and len(normalization_list) != n_tensors:
        raise ValueError(
            "normalization_list must have the same length as tensor_list")

    noisy_t_list = []
    unnoisy_t_list = []
    for i, (t, noise, seed) in enumerate(
            zip(tensor_list, noise_strength_list, seed_list)):
        # The seeded generator lives on the CPU, so sample there and then
        # move the noise to t's device. This is a no-op for CPU tensors and
        # makes CUDA inputs work.
        generator = torch.Generator().manual_seed(seed)
        gaussian = torch.randn(t.shape, generator=generator).to(t.device)
        noisy_t = t + noise * gaussian

        if normalize:
            if normalization_list is not None:
                reference = normalization_list[i]
            else:
                reference = noisy_t
            stats = get_normalization(reference)
            # Apply the same statistics to both tensors so they stay comparable.
            noisy_t = normalize_tensor(noisy_t, stats)
            t = normalize_tensor(t, stats)

        noisy_t_list.append(noisy_t)
        unnoisy_t_list.append(t)
    return noisy_t_list, unnoisy_t_list