import torch

from EIVTrainingRoutines.train_and_store import open_stored_training


def create_strings(template, iterator, before=(), after=()):
    """
    Returns a list of strings, one for each element `i` of `iterator`,
    created via %-formatting `template` with the tuple
    `(*before, i, *after)`.

    Example: `create_strings("%s_seed_%i.pkl", range(2), before=("net",))`
    returns `["net_seed_0.pkl", "net_seed_1.pkl"]`.

    :param template: A %-style format string with one placeholder for each
        element of `before`, one for the iterated element and one for each
        element of `after` (in this order)
    :param iterator: Iterator (list, generator, range, ...)
    :param before: A list (or tuple) of values inserted before the iterated
        element
    :param after: A list (or tuple) of values inserted after the iterated
        element
    """
    return [template % (*before, i, *after) for i in iterator]


class Ensemble:
    """
    Takes the strings from the list saved_files and uses them to load
    networks of architecture_class with **kwargs into a collection.
    A list of all members is given by self.members.

    Note: the members are not loaded upon construction, call `load_members`
    to fill self.members.

    :param saved_files: A list of strings
    :param architecture_class: A class, inherited from nn.Module
    :param device: torch.device
    :param **kwargs: will be given to architecture_class upon creation of
        each member
    """

    def __init__(self, saved_files, architecture_class, device, **kwargs):
        self.saved_files = saved_files
        self.architecture_class = architecture_class
        self.device = device
        self.kwargs = kwargs
        # Stays empty until `load_members` is called explicitly; loading is
        # deliberately not triggered from the constructor.
        self.members = []

    def load_members(self):
        """
        Creates one net of architecture_class for each file in saved_files,
        hands it together with the filename and the device to
        `open_stored_training`, sets it to eval mode and appends it to
        self.members.

        Note: the nets are appended to the existing self.members, so calling
        this method more than once adds every net again.
        """
        for filename in self.saved_files:
            net = self.architecture_class(**self.kwargs)
            open_stored_training(saved_file=filename, net=net,
                                 device=self.device)
            net.eval()
            self.members.append(net)

    def __getitem__(self, i):
        """
        Returns self.members[i].
        """
        return self.members[i]

    def __call__(self, x):
        """
        Evaluates net(x) for each net in self.members and returns the
        result as a list.
        """
        return [net(x) for net in self.members]

    def mean_and_std(self, x, detach=True):
        """
        Evaluates net(x) for each net in self.members and returns the mean
        and the standard deviation (via torch.mean and torch.std with their
        default settings) of the predictions across the members.
        If a net returns a list or a tuple, only its first entry is used.

        :param x: The input handed to each member
        :param detach: If True (default) each prediction is detached before
            the statistics are computed
        :return: A tuple (mean, std)
        :raises RuntimeError: If self.members is empty
        """
        if not self.members:
            # torch.stack would fail on an empty list with a RuntimeError
            # anyway; fail early with a message that names the actual cause.
            raise RuntimeError(
                "Ensemble has no members, call load_members() first."
            )
        predictions = []
        for net in self.members:
            pred = net(x)
            # isinstance also covers subclasses of list/tuple (for instance
            # namedtuples) that an exact type comparison would miss.
            if isinstance(pred, (list, tuple)):
                pred = pred[0]
            if detach:
                pred = pred.detach()
            predictions.append(pred)
        # New leading dimension enumerates the members.
        stacked = torch.stack(predictions, dim=0)
        return torch.mean(stacked, dim=0), torch.std(stacked, dim=0)

    def mean(self, x, detach=True):
        """
        Returns the first entry (the mean) of `mean_and_std(x, detach)`.
        """
        return self.mean_and_std(x, detach=detach)[0]

    def std(self, x, detach=True):
        """
        Returns the second entry (the standard deviation) of
        `mean_and_std(x, detach)`.
        """
        return self.mean_and_std(x, detach=detach)[1]