Spaces: Sleeping
Update app.py
app.py
CHANGED
@@ -4,6 +4,7 @@ import random
 import torch
 
 from helper.cond_encoder import CLIPEncoder
+from helper.loader import Loader
 
 from auto_encoder.models.variational_auto_encoder import VariationalAutoEncoder
 from clip.models.ko_clip import KoCLIPWrapper
@@ -15,29 +16,35 @@ from diffusion_model.network.unet_wrapper import UnetWrapper
 # import spaces #[uncomment to use ZeroGPU]
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
+loader = Loader(device)
+repo_id = "JuyeopDang/KoFace-Diffusion"
 
 if torch.cuda.is_available():
     torch_dtype = torch.float16
 else:
     torch_dtype = torch.float32
 
-
-from huggingface_hub import hf_hub_download
-CONFIG_PATH = 'configs/composite_config.yaml'
-
-repo_id = "JuyeopDang/KoFace-Diffusion"
-filename = "composite_epoch2472.pth" # e.g. "pytorch_model.pt" or "model.pt"
-vae = VariationalAutoEncoder(CONFIG_PATH)
-
+def load_model_from_HF(model, repo_id, filename, is_ema=False):
     try:
-    # Download the file
-    # Setting cache_dir lets you control where the downloaded file is saved.
-    # By default it is saved in ~/.cache/huggingface/hub.
         model_path = hf_hub_download(repo_id=repo_id, filename=filename)
-    print(f"The model weights file was downloaded to: {model_path}")
     except Exception as e:
         print(f"Error during file download or model load: {e}")
+    model_path = model_path[:-4]
+    model = loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)
+    return model
 
-
-
-
+if __name__ == "__main__":
+    from huggingface_hub import hf_hub_download
+    CONFIG_PATH = 'configs/composite_config.yaml'
+
+    vae = VariationalAutoEncoder(CONFIG_PATH)
+    clip = KoCLIPWrapper()
+    cond_encoder = CLIPEncoder(clip, CONFIG_PATH)
+    network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)
+    dm = LatentDiffusionModel(network, sampler, vae)
+
+    vae = load_model_from_HF(vae, repo_id, "composite_epoch2472.pth", False)
+    clip = load_model_from_HF(clip, repo_id, "asian-composite-fine-tuned-koclip.pth", True)
+    dm = load_model_from_HF(dm, repo_id, "asian-composite-clip-ldm.pth", True)
+
+    print(dm)
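Review note: in the new load_model_from_HF, the except branch only prints and falls through, so a failed hf_hub_download leaves model_path unbound and the next line raises NameError. The [:-4] slice strips the ".pth" suffix, presumably because Loader.model_load appends the extension itself. A minimal hardened sketch of the same helper (the early return and os.path.splitext are suggestions, not part of the commit; the loader.model_load signature is taken from the diff):

import os
from huggingface_hub import hf_hub_download

def load_model_from_HF(model, repo_id, filename, is_ema=False):
    # Download the checkpoint (cached in ~/.cache/huggingface/hub by default).
    try:
        model_path = hf_hub_download(repo_id=repo_id, filename=filename)
    except Exception as e:
        # Return early: falling through would use an unbound model_path.
        print(f"Error during file download or model load: {e}")
        return None
    # Strip the ".pth" extension; splitext is less brittle than [:-4].
    model_path, _ = os.path.splitext(model_path)
    return loader.model_load(model_path, model, is_ema=is_ema, print_dict=False)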
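The comments this commit removed still apply to the new helper: hf_hub_download caches files in ~/.cache/huggingface/hub by default, and passing cache_dir controls where they are saved. A small usage sketch (the ./checkpoints path is illustrative, not from the commit):

from huggingface_hub import hf_hub_download

# Pin the download location instead of relying on the default HF cache.
model_path = hf_hub_download(
    repo_id="JuyeopDang/KoFace-Diffusion",
    filename="composite_epoch2472.pth",
    cache_dir="./checkpoints",  # illustrative local directory
)
print(f"The model weights file was downloaded to: {model_path}")

Note also that the new __main__ block references Unet, sampler, and LatentDiffusionModel, which must come from imports in unchanged lines outside these hunks (the second hunk header shows one of them: UnetWrapper).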