ovi054 committed
Commit 334bae7 · verified · 1 Parent(s): 106bc18

Update app.py

Files changed (1):
  1. app.py +48 -1
app.py CHANGED
@@ -5,6 +5,12 @@ import spaces
 import torch
 from diffusers import QwenImagePipeline
 
+import os
+import requests
+import tempfile
+import shutil
+from urllib.parse import urlparse
+
 dtype = torch.bfloat16
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
@@ -19,6 +25,46 @@ MAX_IMAGE_SIZE = 2048
 
 # pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
+
+def load_lora_auto(pipe, lora_input):
+    lora_input = lora_input.strip()
+    if not lora_input:
+        return
+
+    # If it's just an ID like "author/model"
+    if "/" in lora_input and not lora_input.startswith("http"):
+        pipe.load_lora_weights(lora_input)
+        return
+
+    if lora_input.startswith("http"):
+        url = lora_input
+
+        # Repo page (no blob/resolve)
+        if "huggingface.co" in url and "/blob/" not in url and "/resolve/" not in url:
+            repo_id = urlparse(url).path.strip("/")
+            pipe.load_lora_weights(repo_id)
+            return
+
+        # Blob link → convert to resolve link
+        if "/blob/" in url:
+            url = url.replace("/blob/", "/resolve/")
+
+        # Download direct file
+        tmp_dir = tempfile.mkdtemp()
+        local_path = os.path.join(tmp_dir, os.path.basename(urlparse(url).path))
+
+        try:
+            print(f"Downloading LoRA from {url}...")
+            resp = requests.get(url, stream=True)
+            resp.raise_for_status()
+            with open(local_path, "wb") as f:
+                for chunk in resp.iter_content(chunk_size=8192):
+                    f.write(chunk)
+            print(f"Saved LoRA to {local_path}")
+            pipe.load_lora_weights(local_path)
+        finally:
+            shutil.rmtree(tmp_dir, ignore_errors=True)
+
 @spaces.GPU()
 def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=4, num_inference_steps=28, lora_id=None, lora_scale=0.95, progress=gr.Progress(track_tqdm=True)):
     if randomize_seed:
@@ -28,7 +74,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
 
     if lora_id and lora_id.strip() != "":
         pipe.unload_lora_weights()
-        pipe.load_lora_weights(lora_id.strip())
+        load_lora_auto(pipe, lora_id)
 
     try:
         image = pipe(
@@ -41,6 +87,7 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, guidan
             true_cfg_scale=guidance_scale,
             guidance_scale=1.0 # Use a fixed default for distilled guidance
         ).images[0]
+        print("Image Generation Completed for: ", prompt, lora_id)
        return image, seed
     finally:
         # Unload LoRA weights if they were loaded
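
For reference, the branching added in load_lora_auto (bare repo ID, huggingface.co repo page, /blob/ link, direct file URL) can be exercised on its own. The sketch below is a minimal illustration of that normalization logic, not code from the commit; the helper name resolve_lora_source and the example repo/URL values are made up for the demonstration.

from urllib.parse import urlparse

def resolve_lora_source(lora_input):
    # Mirrors the branching in load_lora_auto (illustrative helper only).
    lora_input = lora_input.strip()
    if not lora_input:
        return None
    # Bare repo ID such as "author/model" is passed straight to load_lora_weights
    if "/" in lora_input and not lora_input.startswith("http"):
        return ("repo", lora_input)
    # A huggingface.co repo page (no /blob/ or /resolve/) is reduced to its repo ID
    if "huggingface.co" in lora_input and "/blob/" not in lora_input and "/resolve/" not in lora_input:
        return ("repo", urlparse(lora_input).path.strip("/"))
    # A /blob/ page link is rewritten to the downloadable /resolve/ form;
    # anything else is treated as a direct file URL to download to a temp dir
    return ("download", lora_input.replace("/blob/", "/resolve/"))

if __name__ == "__main__":
    for s in (
        "author/some-lora",
        "https://huggingface.co/author/some-lora",
        "https://huggingface.co/author/some-lora/blob/main/lora.safetensors",
    ):
        print(s, "->", resolve_lora_source(s))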