import numpy as np
import torch

def partition(n, max_number_of_terms = np.inf):
    """Randomly split ``n`` into a list of positive integers that sum to ``n``.

    Pieces are drawn one at a time with ``np.random.randint`` from the
    amount that is still left. Once only one slot remains under the
    ``max_number_of_terms`` cap, whatever is left becomes the final piece.
    For ``n <= 0`` the returned list is empty.
    """
    assert max_number_of_terms > 0
    parts = []
    remaining = n
    while remaining > 0:
        # Only one slot left: the rest has to go in as a single piece.
        if len(parts) + 1 >= max_number_of_terms:
            parts.append(remaining)
            break
        piece = np.random.randint(1, remaining + 1)
        parts.append(piece)
        remaining -= piece
    return parts


class Polynomial():
    """Multivariate polynomial stored as exponent vectors plus coefficients.

    ``terms[i]`` holds the exponent of every variable in the i-th monomial
    and ``coefficients[i]`` is the factor that monomial is multiplied by.
    """

    def __init__(self, terms, coefficients):
        self.coefficients = coefficients
        self.terms = terms

    def degree(self):
        """Return the largest exponent sum found among the stored terms."""
        return max(sum(exponents).item() for exponents in self.terms)

    def __call__(self, x):
        """Evaluate the polynomial on every row of the 2-D input ``x``.

        ``x`` must have one column per entry of the first exponent vector.
        """
        shape = x.shape
        assert len(shape) == 2 and shape[1] == self.terms[0].shape[0]
        total = 0
        for weight, exponents in zip(self.coefficients, self.terms):
            # Raise each column to its exponent, then multiply across columns.
            monomial = torch.prod(x ** exponents[None, :], dim=1)
            total += weight * monomial
        return total
        

def draw_polynomial_term(dim, degree):
    """Draw a random exponent vector of length ``dim`` summing to ``degree``.

    ``degree`` is randomly partitioned into at most ``dim`` positive powers,
    and each power is assigned to a distinct, randomly chosen variable; all
    other variables keep exponent zero.
    """
    exponents = partition(degree, max_number_of_terms=dim)
    # Sampling without replacement keeps the chosen variable indices distinct.
    chosen = np.random.choice(dim, size=len(exponents), replace=False)
    term = torch.zeros((dim,))
    term[torch.tensor(chosen)] = torch.tensor(exponents, dtype=torch.float32)
    return term
     

def normal_coefficient_sampler():
    """Return a one-element tensor drawn from ``torch.randn``, never zero.

    Draws are repeated until the sampled value differs from ``0.0``.
    """
    while True:
        sample = torch.randn(1)
        if sample.item() != 0.0:
            return sample


def draw_polynomial(dim, degree, number_of_terms=None,
        coefficient_sampler = normal_coefficient_sampler,
        include_bias = True, seed=None):
    """Draw a random polynomial in ``dim`` variables.

    Args:
        dim: number of variables; must be positive.
        degree: exponent sum of the first drawn term; must be positive.
            Every further term gets an exponent sum drawn uniformly from
            ``1..degree``.
        number_of_terms: number of non-constant terms to draw. Defaults to
            ``degree`` when ``None``; must be positive.
        coefficient_sampler: zero-argument callable returning the
            coefficient for one term. It is called once per term,
            including the constant term.
        include_bias: when true, a constant term (all-zero exponents) is
            appended after the drawn terms.
        seed: when not ``None``, the global torch and NumPy random
            generators are reseeded with this value before drawing.

    Returns:
        A ``Polynomial`` built from the drawn terms and coefficients.

    Raises:
        AssertionError: if ``dim``, ``degree`` or ``number_of_terms`` is
            not positive.
    """
    if number_of_terms is None:
        number_of_terms = degree
    # Validate before touching the global RNG state so that a rejected
    # call has no side effects.
    assert number_of_terms > 0 and degree > 0 and dim > 0
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    coefficients, terms = [], []
    # Add one term of exactly the requested degree to ensure that
    # the polynomial has the right degree.
    coefficients.append(coefficient_sampler())
    terms.append(draw_polynomial_term(dim, degree))
    for _ in range(number_of_terms - 1):
        deg = np.random.randint(degree) + 1
        # Use the caller-supplied sampler; it used to be ignored in favour
        # of normal_coefficient_sampler.
        coefficients.append(coefficient_sampler())
        terms.append(draw_polynomial_term(dim, deg))
    if include_bias:
        coefficients.append(coefficient_sampler())
        terms.append(torch.zeros(dim))
    return Polynomial(terms, coefficients)