# get_multinomial_function.py (2.35 KiB)
# Authored by Jörg Martin
import numpy as np
import torch
def partition(n, max_number_of_terms = np.inf):
    """Randomly split ``n`` into positive integer parts.

    Parts are drawn one at a time: each new part is drawn with
    ``np.random.randint(1, remaining + 1)``, i.e. uniformly from
    ``1..remaining``. Once ``max_number_of_terms - 1`` parts have been drawn,
    whatever is left becomes the final part. At most ``max_number_of_terms``
    parts are therefore returned. Uses NumPy's global RNG.

    Args:
        n: the integer to split.
        max_number_of_terms: upper bound on the number of parts (must be > 0).

    Returns:
        A list of parts summing to ``n``; an empty list when ``n <= 0``.
    """
    assert max_number_of_terms > 0
    parts = []
    remaining = n
    while remaining > 0:
        if len(parts) < max_number_of_terms - 1:
            part = np.random.randint(1, remaining + 1)
        else:
            # Term budget exhausted: the remainder becomes the last part.
            part = remaining
        parts.append(part)
        remaining -= part
    return parts
class Polynomial:
    """A multivariate polynomial stored as a list of monomials.

    ``terms[i]`` is a 1-D tensor of exponents with one entry per input
    variable, and ``coefficients[i]`` is the matching coefficient (for
    example a 1-element tensor as produced by ``normal_coefficient_sampler``).
    Evaluating at a point ``x`` gives
    ``sum_i coefficients[i] * prod_j x[j] ** terms[i][j]``.
    """

    def __init__(self, terms, coefficients):
        self.terms = terms
        self.coefficients = coefficients

    def degree(self):
        """Return the total degree: the largest exponent sum over all terms."""
        return max(sum(exponents).item() for exponents in self.terms)

    def __call__(self, x):
        """Evaluate the polynomial on every row of ``x``.

        Args:
            x: 2-D tensor whose second dimension equals the length of the
                exponent vectors (one column per variable).

        Returns:
            Tensor with one value per row of ``x``. This assumes the
            coefficients are scalars or 1-element tensors.

        Raises:
            AssertionError: if ``x`` is not 2-D or has the wrong number of
                columns.
        """
        assert len(x.shape) == 2 and x.shape[1] == self.terms[0].shape[0]
        total = 0
        for exponents, coef in zip(self.terms, self.coefficients):
            # Broadcast the exponent row against every row of x, then
            # multiply the per-variable powers together.
            monomial = torch.prod(x ** exponents[None, :], dim=1)
            total += coef * monomial
        return total
def draw_polynomial_term(dim, degree):
    """Draw a random monomial of total degree ``degree`` in ``dim`` variables.

    ``degree`` is split into at most ``dim`` positive parts by ``partition``.
    Each part becomes the exponent of a distinct, randomly chosen variable.
    Uses NumPy's global RNG.

    Returns:
        A tensor of shape ``(dim,)`` (torch's default float dtype) holding
        the exponent of each variable; variables that do not appear get 0.
    """
    exponents = torch.tensor(
        partition(degree, max_number_of_terms=dim), dtype=torch.float32)
    # Sample without replacement so that no variable receives two exponents.
    chosen_variables = torch.tensor(
        np.random.choice(dim, size=len(exponents), replace=False))
    term = torch.zeros((dim,))
    term[chosen_variables] = exponents
    return term
def normal_coefficient_sampler():
    """Return a nonzero coefficient drawn from a standard normal distribution.

    The value is a 1-element tensor produced by ``torch.randn(1)``. An exact
    zero is rejected and redrawn, so the returned value is never 0.0. Uses
    torch's global RNG.
    """
    while True:
        draw = torch.randn(1)
        if draw.item() != 0.0:
            return draw
def draw_polynomial(dim, degree, number_of_terms=None,
                    coefficient_sampler=normal_coefficient_sampler,
                    include_bias=True, seed=None):
    """Draw a random multivariate polynomial in ``dim`` variables.

    The first term always has total degree exactly ``degree``. Each of the
    remaining ``number_of_terms - 1`` terms has a total degree drawn
    uniformly from ``1..degree``. Terms are not deduplicated, so the same
    monomial may appear more than once.

    Args:
        dim: number of input variables (must be > 0).
        degree: total degree of the leading term and an upper bound for the
            others (must be > 0).
        number_of_terms: number of non-constant terms; defaults to ``degree``.
        coefficient_sampler: zero-argument callable that returns one
            coefficient per term. The leading term's degree is only
            guaranteed if it never returns zero; ``normal_coefficient_sampler``
            never does.
        include_bias: if true, append a constant term (all-zero exponents).
        seed: if given, reseeds torch's and NumPy's global RNGs first.

    Returns:
        A ``Polynomial`` built from the drawn terms and coefficients.

    Raises:
        AssertionError: if ``number_of_terms``, ``degree`` or ``dim`` is not
            positive.
    """
    if number_of_terms is None:
        number_of_terms = degree
    # Validate before reseeding, so a rejected call leaves global RNG state alone.
    assert number_of_terms > 0 and degree > 0 and dim > 0
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
    # The leading term has total degree exactly `degree`, so the polynomial
    # attains that degree (given a nonzero coefficient).
    coefficients = [coefficient_sampler()]
    terms = [draw_polynomial_term(dim, degree)]
    for _ in range(number_of_terms - 1):
        term_degree = np.random.randint(degree) + 1
        coefficients.append(coefficient_sampler())
        terms.append(draw_polynomial_term(dim, term_degree))
    if include_bias:
        coefficients.append(coefficient_sampler())
        terms.append(torch.zeros(dim))
    return Polynomial(terms, coefficients)