import numpy as np
import torch
from get_multinomial_function import draw_polynomial

def get_data(std_x, std_y, dim=10, degree=5,
        seed = 1, n_train=10000,
        n_test=200, n_val=200, n_offset=50000,
        noise_seed=None):
    """
    Returns modulated multinomial data in dimension `dim`, perturbed
    by input and output noise with standard deviation `std_x` and `std_y`.

    The ground truth is ``pol(x) * exp(-sin(5 * sum(x**2)))``, where `pol`
    is obtained from `draw_polynomial`, standardised with the mean and
    standard deviation of its values on `n_offset` uniform points from
    [-1, 1)^dim. Noise-free inputs ("zeta") are drawn uniformly from the
    same cube; Gaussian noise is then added to inputs and outputs.

    Randomness: `seed` is passed to `draw_polynomial` and seeds the
    *global* NumPy generator (which stays seeded afterwards) for the
    offset points and the zeta. The Gaussian noise is controlled
    separately by `noise_seed`.

    :param std_x: Standard deviation of the Gaussian input noise.
    :param std_y: Standard deviation of the Gaussian output noise.
    :param dim: The input dimension, default to 10.
    :param degree: The degree of the multinomial, defaults to 5.
    :param seed: Seed for the polynomial and the noise-free inputs,
        defaults to 1.
    :param n_train: Number of training points, defaults to 10000.
    :param n_test: Number of test points, defaults to 200.
    :param n_val: Number of validation points, defaults to 200.
    :param n_offset: Number of points for scaling purposes, defaults to 50000.
    :param noise_seed: If None (default) the noise is drawn from the global
        torch generator, i.e. it is only reproducible if the caller seeds
        torch. If an int, a private `torch.Generator` seeded with this value
        is used, which makes the noise reproducible and leaves the global
        torch generator untouched.
    :raises ValueError: If the standard deviation used for scaling is not
        a positive finite number (e.g. `n_offset` < 2 or a constant function).
    :returns: train_data_pure, train_data,
        test_data_pure, test_data, val_data_pure, val_data, func.
        Each `*_data_pure` is the noise-free pair `(zeta, true_y)`, each
        `*_data` the noisy pair `(x, y)`; `func` is the standardised
        ground truth mapping a batch of inputs to function values.
    """
    # draw the polynomial to use
    pol = draw_polynomial(dim=dim, degree=degree, number_of_terms=dim*2, seed=seed)

    # modulated polynomial, will be standardised below
    def unscaled_func(x, freq=5):
        return pol(x) * torch.exp(-torch.sin(freq*torch.sum(x**2, dim=1)))

    def uniform_inputs(n_points):
        # Uniform float32 samples of shape (n_points, dim) taken from the
        # global NumPy generator.
        samples = np.random.uniform(low=-1.0, high=1.0, size=(n_points, dim))
        return torch.tensor(samples, dtype=torch.float32)

    # compute mean and standard deviation used to standardise the function
    np.random.seed(seed)
    offset_values = unscaled_func(uniform_inputs(n_offset))
    offset_mean = torch.mean(offset_values).item()
    offset_std = torch.std(offset_values).item()
    # A NaN, zero or infinite scale would silently turn every target into
    # NaN/inf, so fail loudly instead.
    if not 0.0 < offset_std < float("inf"):
        raise ValueError(
            "Cannot scale the function: the standard deviation of its values "
            f"on the {n_offset} offset points is {offset_std}.")

    # actual function that will be used.
    def func(x):
        return (unscaled_func(x)-offset_mean)/offset_std

    # Optional private generator so that the noise can be reproduced without
    # depending on (or advancing) the global torch generator.
    noise_generator = None
    if noise_seed is not None:
        noise_generator = torch.Generator()
        noise_generator.manual_seed(noise_seed)

    def standard_normal_like(reference):
        # Standard normal noise with the shape/dtype/device of `reference`.
        if noise_generator is None:
            return torch.randn_like(reference)
        return torch.randn(reference.shape, generator=noise_generator,
                dtype=reference.dtype, device=reference.device)

    # map noise-free inputs to a (noise-free, noisy) pair of datasets
    def gen_data(zeta):
        true_y = func(zeta)[...,None]  # append a trailing axis to the targets
        x = zeta + std_x * standard_normal_like(zeta)
        y = true_y + std_y * standard_normal_like(true_y)
        return (zeta, true_y), (x, y)

    # Generate zeta. Re-seeding here restarts the NumPy stream, so the drawn
    # inputs do not depend on `n_offset`. The order train -> test -> val
    # fixes which part of the stream each split receives.
    np.random.seed(seed)
    train_zeta = uniform_inputs(n_train)
    test_zeta = uniform_inputs(n_test)
    val_zeta = uniform_inputs(n_val)

    # Generate data (the noise is drawn in the same train -> test -> val order)
    train_data_pure, train_data = gen_data(train_zeta)
    test_data_pure, test_data = gen_data(test_zeta)
    val_data_pure, val_data = gen_data(val_zeta)
    return train_data_pure, train_data,\
            test_data_pure, test_data,\
            val_data_pure, val_data, func