JuyeopDang commited on
Commit
8a1772b
Β·
verified Β·
1 Parent(s): d734cdd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -4,6 +4,7 @@ import random
4
  import torch
5
 
6
  from helper.cond_encoder import CLIPEncoder
 
7
 
8
  from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
9
  from clip.models.ko_clip import KoCLIPWrapper
@@ -15,29 +16,35 @@ from diffusion_model.network.unet_wrapper import UnetWrapper
15
  # import spaces #[uncomment to use ZeroGPU]
16
 
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
18
 
19
  if torch.cuda.is_available():
20
  torch_dtype = torch.float16
21
  else:
22
  torch_dtype = torch.float32
23
 
24
- if __name__ == "__main__":
25
- from huggingface_hub import hf_hub_download
26
- CONFIG_PATH = 'configs/composite_config.yaml'
27
-
28
- repo_id = "JuyeopDang/KoFace-Diffusion"
29
- filename = "composite_epoch2472.pth" # 예: "pytorch_model.pt" λ˜λŠ” "model.pt"
30
- vae = VariationalAutoEncoder(CONFIG_PATH)
31
-
32
  try:
33
- # 파일 λ‹€μš΄λ‘œλ“œ
34
- # cache_dir을 μ§€μ •ν•˜λ©΄ λ‹€μš΄λ‘œλ“œλœ 파일이 μ €μž₯될 경둜λ₯Ό μ œμ–΄ν•  수 μžˆμŠ΅λ‹ˆλ‹€.
35
- # κΈ°λ³Έμ μœΌλ‘œλŠ” ~/.cache/huggingface/hub 에 μ €μž₯λ©λ‹ˆλ‹€.
36
  model_path = hf_hub_download(repo_id=repo_id, filename=filename)
37
- print(f"λͺ¨λΈ κ°€μ€‘μΉ˜ 파일이 λ‹€μŒ κ²½λ‘œμ— λ‹€μš΄λ‘œλ“œλ˜μ—ˆμŠ΅λ‹ˆλ‹€: {model_path}")
38
  except Exception as e:
39
  print(f"파일 λ‹€μš΄λ‘œλ“œ λ˜λŠ” λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
 
 
 
40
 
41
- state_dict = torch.load(model_path, map_location=device)
42
- vae.load_state_dict(state_dict['model_state_dict'])
43
- print(vae)
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import torch
5
 
6
  from helper.cond_encoder import CLIPEncoder
7
+ from helper.loader import Loader
8
 
9
  from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
10
  from clip.models.ko_clip import KoCLIPWrapper
 
16
  # import spaces #[uncomment to use ZeroGPU]
17
 
18
# Runtime configuration: pick the compute device once at import time and
# derive everything else (weight loader, HF repo, dtype) from it.
device = "cuda" if torch.cuda.is_available() else "cpu"
loader = Loader(device)
repo_id = "JuyeopDang/KoFace-Diffusion"

# Half precision only makes sense on GPU; CPU inference stays in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
27
def load_model_from_HF(model, repo_id, filename, is_ema=False):
    """Download a checkpoint from the Hugging Face Hub and load it into *model*.

    Args:
        model: the (already constructed) torch module to load weights into.
        repo_id: Hub repository id, e.g. "JuyeopDang/KoFace-Diffusion".
        filename: checkpoint filename inside the repo (expected to end in ".pth").
        is_ema: forwarded to ``loader.model_load``; selects EMA weights when True.

    Returns:
        The model returned by ``loader.model_load`` with the weights applied.

    Raises:
        Exception: re-raises whatever ``hf_hub_download`` raised after logging it.
    """
    try:
        # Downloads (or reuses the local cache of) the checkpoint file.
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        print(f"파일 λ‹€μš΄λ‘œλ“œ λ˜λŠ” λͺ¨λΈ λ‘œλ“œ 쀑 였λ₯˜ λ°œμƒ: {e}")
        # BUG FIX: the original fell through here with `model_path` unbound,
        # so the next line raised a NameError that masked the real error.
        raise
    # Strip the 4-char ".pth" suffix — presumably loader.model_load appends
    # its own extension; TODO(review) confirm against Loader's implementation.
    model_path = model_path[:-4]
    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
    return model
35
 
36
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    CONFIG_PATH = 'configs/composite_config.yaml'

    # Assemble the model graph first: VAE + CLIP-conditioned U-Net wrapped
    # into a latent diffusion model.
    vae = VariationalAutoEncoder(CONFIG_PATH)
    clip = KoCLIPWrapper()
    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
    dm = LatentDiffusionModel(network, sampler, vae)

    # Then pull the pretrained weights for each component from the Hub.
    # The CLIP and diffusion checkpoints are EMA weights; the VAE is not.
    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)

    print(dm)