import pytorch_lightning as ptl from inpaint.saicinpainting.training.modules import make_generator class BaseInpaintingTrainingModule(ptl.LightningModule): def __init__(self, config, use_ddp, *args, predict_only=False, visualize_each_iters=100, average_generator=False, generator_avg_beta=0.999, average_generator_start_step=30000, average_generator_period=10, store_discr_outputs_for_vis=False, **kwargs): super().__init__(*args, **kwargs) self.config = config self.generator = make_generator(config, **self.config.generator) self.use_ddp = use_ddp self.visualize_each_iters = visualize_each_iters