# Scraped file-view metadata (not code):
# sereich's picture
# Initial commit of Radio Upscaling UI (minus models)
# f113387
# raw
# history blame
# 493 Bytes
from src.models.aero import Aero
from src.models.seanet import Seanet
from yaml import safe_load
def get_model(model_name="aero"):
    """Build the generator model named *model_name* from its YAML experiment config.

    The config file for the chosen model is parsed with ``safe_load`` and the
    mapping stored under the top-level key equal to *model_name* is passed as
    keyword arguments to the model's constructor.

    Args:
        model_name: Which architecture to build: ``"aero"`` or ``"seanet"``.

    Returns:
        A dict ``{"generator": model}`` holding the constructed model.

    Raises:
        ValueError: If *model_name* is not one of the supported models.
    """
    # Select the model class and its experiment config file.
    if model_name == "aero":
        model_cls = Aero
        config_path = "conf/experiment/aero_441-441_512_256.yaml"
    elif model_name == "seanet":
        model_cls = Seanet
        config_path = "conf/experiment/seanet_441-441.yaml"
    else:
        # Fail fast with a clear message instead of falling through to an
        # UnboundLocalError when the result dict is built.
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'aero' or 'seanet'"
        )

    # The paths are relative, so they resolve against the current working
    # directory of the process, not against this module's location.
    with open(config_path, encoding="utf-8") as f:
        config = safe_load(f)

    # In each config, the constructor kwargs live under a key equal to the
    # model name ("aero" / "seanet").
    generator = model_cls(**config[model_name])
    return {"generator": generator}