# Source: GitHub file view (author: sereich, commit efc318c,
# "Add phone model (beta), allow models to use different architectures"; 530 bytes).
import os

from yaml import safe_load

from src.models.aero import Aero
from src.models.seanet import Seanet


def get_model(
    model_name: str = "aero",
    experiment_file: str = "aero_441-441_512_256.yaml",
    config_dir: str = "conf/experiment",
) -> dict:
    """Build the generator model named *model_name* from an experiment config.

    The YAML file ``config_dir/experiment_file`` is parsed, and the top-level
    section whose key equals ``model_name`` is passed as keyword arguments to
    the matching model class.

    Args:
        model_name: Architecture to build; one of ``"aero"`` or ``"seanet"``.
        experiment_file: File name of the experiment config inside
            ``config_dir``.
        config_dir: Directory holding experiment configs. A relative path is
            resolved against the current working directory, as before.

    Returns:
        A dict ``{"generator": model}`` holding the constructed model.

    Raises:
        ValueError: If ``model_name`` is not a supported architecture.
        OSError: If the config file cannot be opened.
        KeyError: If the config is not a mapping with a ``model_name`` section.
    """
    # Validate the name before touching the filesystem. Previously an unknown
    # name fell through both branches and surfaced as an UnboundLocalError.
    if model_name not in ("aero", "seanet"):
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'aero' or 'seanet'"
        )

    config_path = os.path.join(config_dir, experiment_file)
    with open(config_path, encoding="utf-8") as f:
        config = safe_load(f)

    # safe_load returns None for an empty file; report that (and a missing
    # section) as a KeyError naming the file rather than an opaque TypeError.
    if not isinstance(config, dict) or model_name not in config:
        raise KeyError(f"{config_path} has no '{model_name}' section")

    # Each architecture reads the config section keyed by its own name.
    model_cls = Aero if model_name == "aero" else Seanet
    generator = model_cls(**config[model_name])
    return {"generator": generator}