from src.models.aero import Aero | |
from src.models.seanet import Seanet | |
from yaml import safe_load | |
def get_model(model_name="aero", experiment_file="aero_441-441_512_256.yaml"):
    """Build the model(s) named *model_name* from an experiment YAML config.

    The file ``conf/experiment/<experiment_file>`` (a path relative to the
    current working directory) is parsed with ``yaml.safe_load``. The mapping
    stored under the top-level key equal to *model_name* is passed as keyword
    arguments to the matching model class.

    Args:
        model_name: Which generator to build: ``"aero"`` or ``"seanet"``.
        experiment_file: File name of the experiment config inside
            ``conf/experiment/``.

    Returns:
        A dict ``{"generator": <model instance>}``.

    Raises:
        ValueError: If *model_name* is not a supported model.
        OSError: If the config file cannot be opened.
        KeyError: If the parsed config has no section named *model_name*.
    """
    # Resolve the class before touching the filesystem so an unknown name fails
    # fast with a clear message. Previously it fell through both branches and
    # raised UnboundLocalError on ``generator``.
    if model_name == "aero":
        model_cls = Aero
    elif model_name == "seanet":
        model_cls = Seanet
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'aero' or 'seanet'"
        )

    with open("conf/experiment/" + experiment_file, encoding="utf-8") as f:
        config = safe_load(f)

    # The YAML section holding the constructor kwargs is keyed by the model name.
    generator = model_cls(**config[model_name])
    return {"generator": generator}