Ephemeral182 commited on
Commit
bbb89b4
·
verified ·
1 Parent(s): 359c460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -35
app.py CHANGED
@@ -7,10 +7,23 @@ import spaces
7
  import torch
8
  from diffusers import FluxPipeline, FluxTransformer2DModel
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
10
 
11
  # ------------------------------------------------------------------
12
- # 1. Global Configuration
13
  # ------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
 
 
 
14
  DEFAULT_PIPELINE_PATH = "black-forest-labs/FLUX.1-dev"
15
  DEFAULT_QWEN_MODEL_PATH = "Qwen/Qwen3-8B"
16
  DEFAULT_CUSTOM_WEIGHTS_PATH = "PosterCraft/PosterCraft-v1_RL"
@@ -46,23 +59,23 @@ def download_model_weights(target_dir, repo_id, subdir=None):
46
  os.makedirs(tmp_dir, exist_ok=True)
47
 
48
  try:
 
 
 
 
 
 
 
 
 
 
 
49
  if subdir:
50
- snapshot_download(
51
- repo_id=repo_id,
52
- repo_type="model",
53
- local_dir=tmp_dir,
54
- allow_patterns=os.path.join(subdir, "**"),
55
- local_dir_use_symlinks=False,
56
- )
57
- src_dir = os.path.join(tmp_dir, subdir)
58
- else:
59
- snapshot_download(
60
- repo_id=repo_id,
61
- repo_type="model",
62
- local_dir=tmp_dir,
63
- local_dir_use_symlinks=False,
64
- )
65
- src_dir = tmp_dir
66
 
67
  if os.path.exists(src_dir):
68
  shutil.copytree(src_dir, target_dir)
@@ -86,14 +99,20 @@ def ensure_models_downloaded():
86
  # Download custom weights
87
  custom_weights_local = "local_weights/PosterCraft-v1_RL"
88
  if not os.path.exists(custom_weights_local):
89
- logging.info("Downloading custom Transformer weights...")
90
- download_model_weights(custom_weights_local, DEFAULT_CUSTOM_WEIGHTS_PATH)
 
 
 
91
 
92
  # Download Qwen model
93
  qwen_local = "local_weights/Qwen3-8B"
94
  if not os.path.exists(qwen_local):
95
- logging.info("Downloading Qwen model...")
96
- download_model_weights(qwen_local, DEFAULT_QWEN_MODEL_PATH)
 
 
 
97
 
98
  logging.info("Model download check completed")
99
 
@@ -105,12 +124,17 @@ ensure_models_downloaded()
105
  # ------------------------------------------------------------------
106
  def create_qwen_agent(model_path):
107
  """Create Qwen agent inside GPU context"""
108
- tokenizer = AutoTokenizer.from_pretrained(model_path)
109
- model = AutoModelForCausalLM.from_pretrained(
110
- model_path,
111
- torch_dtype=torch.bfloat16,
112
- device_map="auto"
113
- )
 
 
 
 
 
114
  return tokenizer, model
115
 
116
  def recap_prompt(tokenizer, model, text):
@@ -181,7 +205,7 @@ Elaborate on each core requirement to create a rich description.
181
  # ------------------------------------------------------------------
182
  # 5. ZeroGPU Inference Function
183
  # ------------------------------------------------------------------
184
- @spaces.GPU(duration=300) # 增加到5分钟,给模型加载更多时间
185
  def generate_image_interface(
186
  original_prompt, enable_recap, height, width,
187
  num_inference_steps, guidance_scale, seed_input,
@@ -198,11 +222,14 @@ def generate_image_interface(
198
 
199
  progress(0.1, desc="Loading FLUX pipeline...")
200
 
201
- # Load FLUX pipeline
202
- pipeline = FluxPipeline.from_pretrained(
203
- DEFAULT_PIPELINE_PATH,
204
- torch_dtype=torch.bfloat16
205
- )
 
 
 
206
 
207
  progress(0.2, desc="Loading custom transformer...")
208
 
@@ -210,9 +237,12 @@ def generate_image_interface(
210
  custom_weights_local = "local_weights/PosterCraft-v1_RL"
211
  if os.path.exists(custom_weights_local):
212
  try:
 
 
 
 
213
  transformer = FluxTransformer2DModel.from_pretrained(
214
- custom_weights_local,
215
- torch_dtype=torch.bfloat16
216
  )
217
  pipeline.transformer = transformer
218
  logging.info("Custom Transformer loaded successfully")
@@ -274,6 +304,11 @@ def generate_image_interface(
274
  with gr.Blocks(theme=gr.themes.Soft(), title="PosterCraft") as demo:
275
  gr.Markdown("# PosterCraft-v1.0")
276
  gr.Markdown(f"Base Pipeline: **{DEFAULT_PIPELINE_PATH}**")
 
 
 
 
 
277
  gr.Markdown("⚠️ **First use requires model download, please wait about 10-15 minutes**")
278
 
279
  with gr.Row():
 
7
  import torch
8
  from diffusers import FluxPipeline, FluxTransformer2DModel
9
  from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ from huggingface_hub import login
11
 
12
  # ------------------------------------------------------------------
13
+ # 1. Authentication and Global Configuration
14
  # ------------------------------------------------------------------
15
+ # Authenticate with HF token
16
+ hf_token = os.getenv("HF_TOKEN")
17
+ if hf_token:
18
+ try:
19
+ login(token=hf_token, add_to_git_credential=True)
20
+ logging.info("Successfully authenticated with Hugging Face")
21
+ except Exception as e:
22
+ logging.error(f"HF authentication failed: {e}")
23
+ raise Exception("Authentication failed. Please check your HF_TOKEN.")
24
+ else:
25
+ logging.warning("No HF_TOKEN found in environment variables")
26
+
27
  DEFAULT_PIPELINE_PATH = "black-forest-labs/FLUX.1-dev"
28
  DEFAULT_QWEN_MODEL_PATH = "Qwen/Qwen3-8B"
29
  DEFAULT_CUSTOM_WEIGHTS_PATH = "PosterCraft/PosterCraft-v1_RL"
 
59
  os.makedirs(tmp_dir, exist_ok=True)
60
 
61
  try:
62
+ download_kwargs = {
63
+ "repo_id": repo_id,
64
+ "repo_type": "model",
65
+ "local_dir": tmp_dir,
66
+ "local_dir_use_symlinks": False,
67
+ }
68
+
69
+ # Add token if available
70
+ if hf_token:
71
+ download_kwargs["token"] = hf_token
72
+
73
  if subdir:
74
+ download_kwargs["allow_patterns"] = os.path.join(subdir, "**")
75
+
76
+ snapshot_download(**download_kwargs)
77
+
78
+ src_dir = os.path.join(tmp_dir, subdir) if subdir else tmp_dir
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  if os.path.exists(src_dir):
81
  shutil.copytree(src_dir, target_dir)
 
99
  # Download custom weights
100
  custom_weights_local = "local_weights/PosterCraft-v1_RL"
101
  if not os.path.exists(custom_weights_local):
102
+ try:
103
+ logging.info("Downloading custom Transformer weights...")
104
+ download_model_weights(custom_weights_local, DEFAULT_CUSTOM_WEIGHTS_PATH)
105
+ except Exception as e:
106
+ logging.warning(f"Failed to download custom weights: {e}")
107
 
108
  # Download Qwen model
109
  qwen_local = "local_weights/Qwen3-8B"
110
  if not os.path.exists(qwen_local):
111
+ try:
112
+ logging.info("Downloading Qwen model...")
113
+ download_model_weights(qwen_local, DEFAULT_QWEN_MODEL_PATH)
114
+ except Exception as e:
115
+ logging.warning(f"Failed to download Qwen model: {e}")
116
 
117
  logging.info("Model download check completed")
118
 
 
124
  # ------------------------------------------------------------------
125
  def create_qwen_agent(model_path):
126
  """Create Qwen agent inside GPU context"""
127
+ load_kwargs = {
128
+ "torch_dtype": torch.bfloat16,
129
+ "device_map": "auto"
130
+ }
131
+
132
+ # Add token if available
133
+ if hf_token:
134
+ load_kwargs["token"] = hf_token
135
+
136
+ tokenizer = AutoTokenizer.from_pretrained(model_path, **load_kwargs)
137
+ model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)
138
  return tokenizer, model
139
 
140
  def recap_prompt(tokenizer, model, text):
 
205
  # ------------------------------------------------------------------
206
  # 5. ZeroGPU Inference Function
207
  # ------------------------------------------------------------------
208
+ @spaces.GPU(duration=300)
209
  def generate_image_interface(
210
  original_prompt, enable_recap, height, width,
211
  num_inference_steps, guidance_scale, seed_input,
 
222
 
223
  progress(0.1, desc="Loading FLUX pipeline...")
224
 
225
+ # Load FLUX pipeline with explicit token
226
+ load_kwargs = {
227
+ "torch_dtype": torch.bfloat16
228
+ }
229
+ if hf_token:
230
+ load_kwargs["token"] = hf_token
231
+
232
+ pipeline = FluxPipeline.from_pretrained(DEFAULT_PIPELINE_PATH, **load_kwargs)
233
 
234
  progress(0.2, desc="Loading custom transformer...")
235
 
 
237
  custom_weights_local = "local_weights/PosterCraft-v1_RL"
238
  if os.path.exists(custom_weights_local):
239
  try:
240
+ transformer_kwargs = {"torch_dtype": torch.bfloat16}
241
+ if hf_token:
242
+ transformer_kwargs["token"] = hf_token
243
+
244
  transformer = FluxTransformer2DModel.from_pretrained(
245
+ custom_weights_local, **transformer_kwargs
 
246
  )
247
  pipeline.transformer = transformer
248
  logging.info("Custom Transformer loaded successfully")
 
304
  with gr.Blocks(theme=gr.themes.Soft(), title="PosterCraft") as demo:
305
  gr.Markdown("# PosterCraft-v1.0")
306
  gr.Markdown(f"Base Pipeline: **{DEFAULT_PIPELINE_PATH}**")
307
+
308
+ # Show authentication status
309
+ auth_status = "🟢 Authenticated" if hf_token else "🔴 Not Authenticated"
310
+ gr.Markdown(f"Authentication Status: {auth_status}")
311
+
312
  gr.Markdown("⚠️ **First use requires model download, please wait about 10-15 minutes**")
313
 
314
  with gr.Row():