-
Jörg Martin authored
ensemble_handling.py 2.57 KiB
import torch
from EIVTrainingRoutines.train_and_store import open_stored_training
def create_strings(template, iterator, before=(), after=()):
    """
    Returns a list of strings obtained by %-formatting `template` once for
    every element `i` of `iterator`, using the argument tuple
    ``(*before, i, *after)``.

    Example: ``create_strings('net_%s_%i.pkl', range(2), before=('a',))``
    returns ``['net_a_0.pkl', 'net_a_1.pkl']``.

    :param template: A printf-style (``%``) format string that consumes
        ``len(before) + 1 + len(after)`` arguments
    :param iterator: Iterable (list, generator, range, ...) whose elements
        fill the slot between `before` and `after`
    :param before: Iterable of format arguments placed before each element
    :param after: Iterable of format arguments placed after each element
    :returns: A list of formatted strings, one per element of `iterator`
    """
    # Materialise once: a generator passed as `before`/`after` would
    # otherwise be exhausted after formatting the first string.
    before, after = tuple(before), tuple(after)
    return [template % (*before, i, *after) for i in iterator]
class Ensemble:
    """
    Collection of networks of a common architecture, each loaded from one
    stored training file.

    For every string in `saved_files` a network
    ``architecture_class(**kwargs)`` is created, its stored state is loaded
    via `open_stored_training`, and it is put into eval mode. A list of all
    members is given by ``self.members``.

    :param saved_files: A list of strings (stored trainings to load)
    :param architecture_class: A class, inherited from nn.Module
    :param device: torch.device, passed on to `open_stored_training`
    :param **kwargs: Given to `architecture_class` upon creation of each
        member
    """

    def __init__(self, saved_files, architecture_class, device, **kwargs):
        self.saved_files = saved_files
        self.architecture_class = architecture_class
        self.device = device
        self.kwargs = kwargs
        self.members = []
        self.load_members()

    def load_members(self):
        """
        Loads one net per entry of ``self.saved_files``, sets it to eval mode
        and stores the nets (in file order) in ``self.members``.

        Previously loaded members are replaced, so calling this method again
        reloads the ensemble instead of appending duplicates.
        """
        members = []
        for filename in self.saved_files:
            net = self.architecture_class(**self.kwargs)
            # Presumably fills `net` in place from the stored file; its
            # return value is not used here.
            open_stored_training(saved_file=filename,
                                 net=net,
                                 device=self.device)
            net.eval()
            members.append(net)
        self.members = members

    def __len__(self):
        """Returns the number of ensemble members."""
        return len(self.members)

    def __iter__(self):
        """Iterates over the ensemble members in order."""
        return iter(self.members)

    def __getitem__(self, i):
        """Returns ``self.members[i]`` (integer index or slice)."""
        return self.members[i]

    def __call__(self, x):
        """
        Evaluates ``net(x)`` for each net in ``self.members`` and returns the
        raw results as a list, in member order.
        """
        return [net(x) for net in self.members]

    @staticmethod
    def _first_output(pred):
        """Returns ``pred[0]`` if a net returned a list/tuple, else `pred`."""
        return pred[0] if isinstance(pred, (list, tuple)) else pred

    def mean_and_std(self, x, detach=True):
        """
        Returns the mean and standard deviation of the members' predictions
        for `x`.

        If a member returns a list or tuple, only its first entry is used.
        The predictions are stacked along a new leading member dimension,
        which is reduced with ``torch.mean`` and ``torch.std``. ``torch.std``
        uses its default (Bessel-corrected) estimator, so a single-member
        ensemble yields NaN for the std.

        :param x: Input passed unchanged to every member
        :param detach: If True (default), predictions are computed without
            recording an autograd graph and are detached
        :returns: Tuple ``(mean, std)`` of tensors
        :raises RuntimeError: If the ensemble has no members
        """
        if not self.members:
            raise RuntimeError('Ensemble has no members to evaluate')
        # Only disable autograd when detaching; never re-enable it if the
        # caller is already running under torch.no_grad().
        with torch.set_grad_enabled(torch.is_grad_enabled() and not detach):
            preds = [self._first_output(net(x)) for net in self.members]
        if detach:
            # Grad-disabled mode does not detach tensors a net returns
            # unchanged (e.g. its input), so detach explicitly.
            preds = [pred.detach() for pred in preds]
        out = torch.stack(preds, dim=0)
        return torch.mean(out, dim=0), torch.std(out, dim=0)

    def mean(self, x, detach=True):
        """Returns the mean from ``self.mean_and_std(x, detach=detach)``."""
        return self.mean_and_std(x, detach=detach)[0]

    def std(self, x, detach=True):
        """Returns the std from ``self.mean_and_std(x, detach=detach)``."""
        return self.mean_and_std(x, detach=detach)[1]