import torch


def _check_number_of_draws(number_of_draws):
    """
    Raise a ValueError if `number_of_draws` is an integer smaller than 1.

    Non-integer values are not checked here and are handed to torch as is.
    """
    if isinstance(number_of_draws, int) and number_of_draws < 1:
        raise ValueError(
            f"number_of_draws must be an integer >= 1, got {number_of_draws}"
        )


def repeat_tensors(*args, number_of_draws=1):
    """
    Repeat every batch entry of each tensor `number_of_draws` times.

    Each slice along the first (batch) dimension is repeated consecutively,
    so rows ``[a, b]`` with ``number_of_draws=2`` become ``[a, a, b, b]``.
    This is useful in EIV modelling, where several draws are taken per
    batch element.

    :param *args: torch.Tensor elements with at least one dimension
    :param number_of_draws: An integer >= 1, the number of repetitions,
        defaults to 1
    :returns: A list with one repeated tensor per element of `args`, in the
        same order; the first dimension of each is multiplied by
        `number_of_draws`
    :raises ValueError: If `number_of_draws` is an integer smaller than 1
    """
    _check_number_of_draws(number_of_draws)
    return [arg.repeat_interleave(number_of_draws, dim=0) for arg in args]


def reshape_to_chunks(*args, number_of_draws=1):
    """
    Group consecutive entries of each tensor into chunks of `number_of_draws`.

    Each tensor is split along its first dimension into chunks of size
    `number_of_draws`, and the chunks are stacked along a new leading
    dimension. Applied to the output of `repeat_tensors`, this yields tensors
    whose first dimension indexes the batch and whose second dimension
    indexes the draw.

    :param *args: torch.Tensor elements; the size of the first dimension of
        each must be divisible by `number_of_draws`
    :param number_of_draws: An integer >= 1, the chunk size, defaults to 1
    :returns: A list with one tensor per element of `args`. For a non-empty
        first dimension its shape is
        ``(arg.shape[0] // number_of_draws, number_of_draws, *arg.shape[1:])``.
    :raises ValueError: If `number_of_draws` is an integer smaller than 1, or
        if a tensor's first dimension is not divisible by it
    """
    _check_number_of_draws(number_of_draws)
    # Only an integer chunk size can be checked for divisibility up front;
    # other forms accepted by torch.split are passed through unchanged.
    check_divisible = isinstance(number_of_draws, int)
    reshaped_args = []
    for arg in args:
        # torch.split leaves a smaller last chunk when the size is not a
        # multiple of number_of_draws, and torch.stack would then fail with a
        # generic size error. Report the real cause instead.
        if check_divisible and arg.shape[0] % number_of_draws != 0:
            raise ValueError(
                "The first dimension of each tensor must be divisible by "
                f"number_of_draws={number_of_draws}, got size {arg.shape[0]}"
            )
        chunks = torch.split(arg, number_of_draws, dim=0)
        reshaped_args.append(torch.stack(chunks, dim=0))
    return reshaped_args