import torch

def repeat_tensors(*args, number_of_draws=1):
    """
    Repeat every entry along the first (batch) dimension of each tensor.

    Each slice along dimension 0 is duplicated `number_of_draws` times in
    place, so the copies of one batch element end up adjacent to each other
    (e.g. rows ``[a, b]`` with two draws become ``[a, a, b, b]``).
    This is useful in EIV modelling when several draws per input are needed.
    :param *args: The torch.tensor elements to repeat
    :param number_of_draws: An integer >= 1, the number of copies per slice
    :return: A list with one repeated tensor per element of `args`,
    in the same order
    """
    return [
        tensor.repeat_interleave(number_of_draws, dim=0) for tensor in args
    ]


def reshape_to_chunks(*args, number_of_draws=1):
    """
    Group consecutive slices of each tensor into stacked chunks.

    Every element of *args is cut along dimension 0 into pieces of
    `number_of_draws` slices, and the pieces are stacked along a new
    leading dimension. Applied to the output of repeat_tensors, the first
    dimension of the result indexes the batch and the second one
    indexes the draws.
    Stacking needs equally sized pieces, so the length of dimension 0
    is expected to be a multiple of `number_of_draws`.
    :param *args: The torch.tensor elements to reshape
    :param number_of_draws: An integer, the chunk size, defaults to 1
    :return: A list with one stacked tensor per element of `args`,
    in the same order
    """
    return [
        torch.stack(tensor.split(number_of_draws, dim=0), dim=0)
        for tensor in args
    ]