import torch
from EIVTrainingRoutines.train_and_store import open_stored_training

def create_strings(template, iterator, before=(), after=()):
    """
    Builds one string per element of `iterator` by %-formatting `template`
    with the entries of `before`, the element itself and the entries of
    `after` (in this order).
    :param template: A string containing %-style placeholders
    :param iterator: Iterator (list, generator, range, ...)
    :param before: A list, its entries fill the leading placeholders
    :param after: A list, its entries fill the trailing placeholders
    :return: A list of strings, one for each element of `iterator`
    """
    return [template % (*before, item, *after) for item in iterator]


class Ensemble():
    """
    A collection of networks that share one architecture.
    For each string in the list saved_files a network is created via
    architecture_class(**kwargs) and passed, together with the file name
    and device, to `open_stored_training`.
    A list of all members is given by self.members; single members can
    also be accessed via indexing, i.e. ensemble[i].
    :param saved_files: A list of strings
    :param architecture_class: A class, inherited from nn.Module
    :param device: torch.device
    :param **kwargs: will be given to architecture_class upon creation of
    each member
    """
    def __init__(self, saved_files, architecture_class, device, **kwargs):
        self.saved_files = saved_files
        self.architecture_class = architecture_class
        self.device = device
        self.kwargs = kwargs
        self.members = []
        # Populate self.members right away
        self.load_members()

    def load_members(self):
        """
        Loads nets from saved_files, sets them to eval mode and stores them
        in self.members. Any previous content of self.members is replaced,
        so calling this method repeatedly does not duplicate members.
        """
        members = []
        for filename in self.saved_files:
            net = self.architecture_class(**self.kwargs)
            # presumably restores the stored state into `net` in place;
            # the return value is not used here
            open_stored_training(saved_file=filename,
                                 net=net,
                                 device=self.device)
            net.eval()
            members.append(net)
        # Assign only after all files were processed successfully
        self.members = members

    def __getitem__(self, i):
        """
        Returns self.members[i].
        """
        return self.members[i]

    def __call__(self, x):
        """
        Evaluates net(x) for each net in self.members and returns the result
        as a list.
        """
        return [net(x) for net in self.members]

    def mean_and_std(self, x, detach=True):
        """
        Evaluates net(x) for each net in self.members and returns mean and
        standard deviation of the predictions, taken across the members.
        If a net returns a list or a tuple (or an instance of a subclass,
        such as a namedtuple) only its first entry is used.
        Note: torch.std is called with its default settings, which should
        apply Bessel's correction; for a single member the std is then nan.
        :param x: Input that is given to each member
        :param detach: Boolean. If True (default) the predictions are
        detached before mean and std are computed.
        :return: A tuple (mean, std) of torch.tensors
        """
        predictions = []
        for net in self.members:
            pred = net(x)
            # isinstance also covers subclasses of list/tuple
            if isinstance(pred, (list, tuple)):
                pred = pred[0]
            if detach:
                pred = pred.detach()
            predictions.append(pred)
        # New leading dimension enumerates the members
        stacked = torch.stack(predictions, dim=0)
        return torch.mean(stacked, dim=0), torch.std(stacked, dim=0)

    def mean(self, x, detach=True):
        """
        Returns the first entry of self.mean_and_std(x, detach=detach).
        """
        return self.mean_and_std(x, detach=detach)[0]

    def std(self, x, detach=True):
        """
        Returns the second entry of self.mean_and_std(x, detach=detach).
        """
        return self.mean_and_std(x, detach=detach)[1]