import os
import torch
import numpy as np
import csv
from scipy import signal

# function to model
a=0.25
func = lambda x: (1 - (x/a)**2) * np.exp(-0.5*(x/a)**2)
# standard deviation of input and out output noise
n_train = 300

def normalize(*args, tuple_for_normalization):
    """
    Normalize each (x, y) pair in ``args`` by the statistics of
    ``tuple_for_normalization``.

    :param args: Any number of ``(x, y)`` tuples of array-likes.
    :param tuple_for_normalization: An ``(x, y)`` tuple whose per-component
        mean and standard deviation are used to normalize every pair in
        ``args`` (typically the noisy training data).
    :returns: A tuple ``(normalized_list, (mean, std))`` where
        ``normalized_list`` holds one normalized ``(x, y)`` tuple per input
        pair, and ``mean`` / ``std`` are the two-element lists
        ``[mean_x, mean_y]`` / ``[std_x, std_y]`` actually used.
    """
    assert len(tuple_for_normalization) == 2
    # Component-wise statistics of the reference tuple. (Avoid the name `a`
    # here: it would shadow the module-level width constant.)
    mean = [np.mean(component) for component in tuple_for_normalization]
    std = [np.std(component) for component in tuple_for_normalization]
    return_list = []
    for xy in args:
        # Validate *before* unpacking so a bad-length tuple trips the
        # assertion rather than a ValueError from the unpack.
        assert len(xy) == 2
        x, y = xy
        normalized_x = (x - mean[0]) / std[0]
        normalized_y = (y - mean[1]) / std[1]
        return_list.append((normalized_x, normalized_y))
    return return_list, (mean, std)

def get_data(std_x, std_y, seed = 1, n_train=n_train,
        n_test=200, n_val=200, post_normalize=False):
    """
    Return Mexican-hat regression data in 1 dimension, perturbed by
    Gaussian input and output noise with standard deviations `std_x`
    and `std_y`.

    :param std_x: Standard deviation of the input (x) noise; a positive float.
    :param std_y: Standard deviation of the output (y) noise; a positive float.
    :param seed: Seed for NumPy's global RNG, defaults to 1.
    :param n_train: Number of training points, defaults to the module-level
        ``n_train`` (300).
    :param n_test: Number of test points, defaults to 200.
    :param n_val: Number of validation points, defaults to 200.
    :param post_normalize: If True, all returned data splits are normalized
        by the mean and std of the *noisy* training data before conversion
        to tensors.
    :returns: ``train_data_pure, train_data, test_data_pure, test_data,
        val_data_pure, val_data, (normalized_func, mean, std)``. Each
        ``*_data`` entry is an ``(x, y)`` tuple of float32 torch column
        tensors of shape ``(n, 1)``; the ``*_pure`` variants are the
        noise-free latent inputs and true function values.
        ``normalized_func`` maps raw x to the normalized
        ``(x, func(x))`` pair; ``mean``/``std`` are the normalization
        statistics (identity values ``[0, 0]``/``[1, 1]`` when
        ``post_normalize`` is False).
    """
    # Fix a seed. Uses NumPy's legacy global RNG, so reproducibility depends
    # on the exact order of the np.random calls below.
    np.random.seed(seed)
    def gen_data(zeta, std_x=std_x, std_y=std_y, func=func):
        # Evaluate the true function on latent inputs `zeta`, then perturb
        # both inputs and outputs with independent Gaussian noise.
        print(' -- Generating Mexican hat data with std_x = %.2f and std_y = %.2f --' % (std_x, std_y,))
        true_y = func(zeta)
        x = zeta + std_x * np.random.normal(size=zeta.shape)
        y = true_y + std_y * np.random.normal(size=zeta.shape) 
        return (zeta, true_y), (x,y)
    # Generate zeta: latent (noise-free) inputs, uniform on [-1, 1]
    train_zeta = np.random.uniform(low=-1.0, high=1.0, size=n_train)
    test_zeta = np.random.uniform(low=-1.0, high=1.0, size=n_test)
    val_zeta = np.random.uniform(low=-1.0, high=1.0, size=n_val)
    # Generate data
    train_data_pure, train_data = gen_data(train_zeta)
    test_data_pure, test_data = gen_data(test_zeta)
    val_data_pure, val_data = gen_data(val_zeta)
    if post_normalize:
        # Statistics are taken from the noisy training split and applied
        # to all six splits.
        data_list, (mean, std) = normalize(train_data_pure, train_data,
                            test_data_pure, test_data,
                            val_data_pure, val_data,
                            tuple_for_normalization=train_data)
    else:
        data_list = [train_data_pure, train_data,
                test_data_pure, test_data,
                val_data_pure, val_data]
        # Identity statistics so normalized_func below is a no-op transform.
        mean=[0, 0]; std= [1, 1]
    # Convert every split to float32 torch column tensors of shape (n, 1).
    train_data_pure, train_data,\
            test_data_pure,test_data,\
            val_data_pure,val_data  = \
            [(torch.tensor(a[0], dtype=torch.float32)[:,None],
                torch.tensor(a[1], dtype=torch.float32)[:,None])
                    for a in data_list]
    def normalized_func(x):
        # True function expressed in the normalized coordinates, so it can
        # be plotted/compared against the (possibly normalized) data.
        normalized_x = (x-mean[0])/std[0]
        y = func(x)
        normalized_y = (y-mean[1])/std[1]
        return normalized_x, normalized_y
    return train_data_pure, train_data,\
            test_data_pure,test_data,\
            val_data_pure,val_data,(normalized_func, mean, std)