"""Gradio demo for the KoFace-Diffusion latent diffusion model.

Downloads pretrained VAE, KoCLIP and latent-diffusion checkpoints from the
Hugging Face Hub repository ``JuyeopDang/KoFace-Diffusion`` and serves a
Gradio interface with a text box and a 0-10 slider.  Their values are passed
to ``LatentDiffusionModel.sample`` as ``y`` and ``gamma``, and the first
sampled image is displayed.
"""
import random

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image as im

from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
from clip.models.ko_clip import KoCLIPWrapper
from diffusion_model.models.latent_diffusion_model import LatentDiffusionModel
from diffusion_model.network.unet import Unet
from diffusion_model.network.unet_wrapper import UnetWrapper
from diffusion_model.sampler.ddim import DDIM
from helper.cond_encoder import CLIPEncoder
from helper.loader import Loader

# import spaces #[uncomment to use ZeroGPU]

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)

repo_id = "JuyeopDang/KoFace-Diffusion"
CONFIG_PATH = 'configs/composite_config.yaml'

# NOTE(review): torch_dtype is not referenced anywhere below in this script;
# confirm whether it was meant to be passed to the models or can be dropped.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32


def load_model_from_HF(model, repo_id, filename, is_ema=False):
    """Download a checkpoint from the Hugging Face Hub and load it into ``model``.

    Args:
        model: Model instance handed to ``loader.model_load``.
        repo_id: Hugging Face Hub repository id to download from.
        filename: Checkpoint file name inside the repository.
        is_ema: Forwarded unchanged to ``loader.model_load``
            (presumably selects EMA weights -- confirm in ``helper.loader``).

    Returns:
        The object returned by ``loader.model_load``.

    Raises:
        Exception: Any error raised while downloading or loading is printed
            and then re-raised.  Previously a failed download was swallowed
            and surfaced later as a confusing ``UnboundLocalError``.
    """
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
        # Strip the last 4 characters (".pth" for every checkpoint used in this
        # script); presumably loader.model_load re-appends the extension -- TODO confirm.
        model_path = model_path[:-4]
        return loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    except Exception as e:
        # Message: "Error occurred while downloading the file or loading the model"
        print(f"파일 다운로드 또는 모델 로드 중 오류 발생: {e}")
        raise


if __name__ == "__main__":
    vae = VariationalAutoEncoder(CONFIG_PATH)
    sampler = DDIM(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # NOTE(review): the wrappers above are built from the *unloaded* models;
    # this assumes loader.model_load loads weights in place -- confirm.
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)

    def generate_image(y, gamma):
        """Sample images for condition ``y`` and return the first as a PIL image.

        Args:
            y: Text from the Gradio text box, forwarded to ``dm.sample``.
            gamma: Slider value (0-10) forwarded to ``dm.sample``
                (presumably a guidance scale -- confirm in LatentDiffusionModel).

        Returns:
            A ``PIL.Image.Image`` built from the first sampled image.
        """
        # Two samples are drawn but only the first is shown; the reason for
        # drawing two is not visible here -- NOTE(review): confirm.
        images = dm.sample(2, y=y, gamma=gamma)
        # Move axis 1 to the end (presumably channels-first -> channels-last,
        # the layout PIL expects).
        images = images.permute(0, 2, 3, 1)
        if isinstance(images, torch.Tensor):
            images = images.detach().cpu().numpy()
        # Affine map [-1, 1] -> [0, 1] (assumes the sampler outputs values in
        # [-1, 1]); clip guards against overshoot before the uint8 cast.
        images = np.clip(images / 2 + 0.5, 0, 1)
        return im.fromarray((images[0] * 255).astype(np.uint8))

    demo = gr.Interface(
        generate_image,
        inputs=["textbox", gr.Slider(0, 10)],
        outputs=["image"],
    )
    demo.launch()