from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load


def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    """Build the generator selected by ``model_name`` from a YAML experiment file.

    The file ``"conf/experiment/" + experiment_file`` is opened (the path is
    relative, so it resolves against the current working directory), parsed
    with ``yaml.safe_load``, and the mapping stored under the top-level key
    equal to ``model_name`` is passed as keyword arguments to the matching
    generator class.

    Args:
        model_name: Which generator to build; ``'aero'`` or ``'seanet'``.
        experiment_file: File name of the experiment config inside
            ``conf/experiment/``.

    Returns:
        A dict with a single key ``'generator'`` holding the constructed
        ``Aero`` or ``Seanet`` instance.

    Raises:
        ValueError: If ``model_name`` is not a supported model. (Previously
            this case surfaced as an ``UnboundLocalError``.)
        OSError: If the experiment file cannot be opened.
        KeyError: If the parsed config has no section named ``model_name``.
    """
    # Config section name -> class to instantiate with that section's values.
    generator_classes = {'aero': Aero, 'seanet': Seanet}

    # Reject unknown names up front, before touching the filesystem, so the
    # caller gets an actionable message instead of an unbound-variable error.
    if model_name not in generator_classes:
        supported = ", ".join(repr(name) for name in generator_classes)
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected one of: {supported}"
        )

    with open("conf/experiment/" + experiment_file) as f:
        config = safe_load(f)

    generator = generator_classes[model_name](**config[model_name])
    return {'generator': generator}