Ephemeral182 committed on
Commit 712801c · verified · 1 Parent(s): 044ef1c

Update app.py

Files changed (1)
  1. app.py +155 -303

app.py CHANGED
@@ -1,3 +1,4 @@
  import os
  import random
  import logging
@@ -7,25 +8,10 @@ import spaces
  import torch
  from huggingface_hub import login, whoami
 
- # ------------------------------------------------------------------
- # 1. Authentication and Global Configuration
- # ------------------------------------------------------------------
- # Authenticate with HF token
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
- auth_status = "🔴 Not Authenticated"
-
- if HF_TOKEN:
-     try:
-         login(token=HF_TOKEN, add_to_git_credential=True)
-         user_info = whoami(HF_TOKEN)
-         auth_status = f"✅ Authenticated as {user_info['name']}"
-         logging.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")
-     except Exception as e:
-         logging.error(f"HF authentication failed: {e}")
-         auth_status = f"🔴 Authentication Error: {str(e)}"
- else:
-     logging.warning("No HF_TOKEN found in environment variables")
-     auth_status = "🔴 No HF_TOKEN found"
 
  DEFAULT_PIPELINE_PATH = "black-forest-labs/FLUX.1-dev"
  DEFAULT_QWEN_MODEL_PATH = "Qwen/Qwen3-8B"
@@ -34,126 +20,69 @@ DEFAULT_CUSTOM_WEIGHTS_PATH = "PosterCraft/PosterCraft-v1_RL"
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 2048
 
- logging.basicConfig(
-     level=logging.INFO,
-     format="%(asctime)s - %(levelname)s - %(message)s",
- )
-
- # ------------------------------------------------------------------
- # 2. Model Download Function (CPU only)
- # ------------------------------------------------------------------
- def download_model_weights(target_dir, repo_id, subdir=None):
-     """Download model weights to specified directory (CPU operation)"""
-     from huggingface_hub import snapshot_download
-     import shutil
-
-     if os.path.exists(target_dir):
-         logging.info(f"Directory {target_dir} already exists, skipping download")
-         return
-
-     tmp_dir = "hf_temp_download"
-     os.makedirs(tmp_dir, exist_ok=True)
 
      try:
-         download_kwargs = {
-             "repo_id": repo_id,
-             "repo_type": "model",
-             "local_dir": tmp_dir,
-             "local_dir_use_symlinks": False,
-         }
-
-         if HF_TOKEN:
-             download_kwargs["token"] = HF_TOKEN
-
-         if subdir:
-             download_kwargs["allow_patterns"] = os.path.join(subdir, "**")
-
-         snapshot_download(**download_kwargs)
-
-         src_dir = os.path.join(tmp_dir, subdir) if subdir else tmp_dir
-
-         if os.path.exists(src_dir):
-             shutil.copytree(src_dir, target_dir)
-             logging.info(f"Successfully downloaded {repo_id} to {target_dir}")
-         else:
-             logging.warning(f"Source directory {src_dir} does not exist")
-
      except Exception as e:
-         logging.error(f"Download failed: {e}")
-     finally:
-         if os.path.exists(tmp_dir):
-             shutil.rmtree(tmp_dir)
 
- # ------------------------------------------------------------------
- # 3. Pre-download models (CPU operation)
- # ------------------------------------------------------------------
- def ensure_models_downloaded():
-     """Pre-download all models to avoid GPU timeout"""
-     logging.info("Checking and downloading models if needed...")
-
-     # Download custom weights
-     custom_weights_local = "local_weights/PosterCraft-v1_RL"
-     if not os.path.exists(custom_weights_local):
-         try:
-             logging.info("Downloading custom Transformer weights...")
-             download_model_weights(custom_weights_local, DEFAULT_CUSTOM_WEIGHTS_PATH)
-         except Exception as e:
-             logging.warning(f"Failed to download custom weights: {e}")
-
-     # Download Qwen model
-     qwen_local = "local_weights/Qwen3-8B"
-     if not os.path.exists(qwen_local):
-         try:
-             logging.info("Downloading Qwen model...")
-             download_model_weights(qwen_local, DEFAULT_QWEN_MODEL_PATH)
-         except Exception as e:
-             logging.warning(f"Failed to download Qwen model: {e}")
 
- logging.info("Model download check completed")
113
-
114
- # Pre-download models at startup (CPU)
115
- ensure_models_downloaded()
116
-
117
- # ------------------------------------------------------------------
118
- # 4. Qwen Recap Agent (ๅŸบไบŽไฝ ็š„ๅŽŸๅง‹้€ป่พ‘)
119
- # ------------------------------------------------------------------
120
- class QwenRecapAgent:
121
- def __init__(self, model_path, max_retries=3, retry_delay=2, device_map="auto"):
122
- self.max_retries = max_retries
123
- self.retry_delay = retry_delay
124
- self.device = device_map
125
-
126
- # ๅผบๅˆถไฝฟ็”จ Fast Tokenizer
127
- try:
128
- self.tokenizer = AutoTokenizer.from_pretrained(
129
- model_path,
130
- token=HF_TOKEN,
131
- use_fast=True, # ๅผบๅˆถไฝฟ็”จ fast tokenizer
132
- trust_remote_code=True
133
- )
134
- logging.info("Successfully loaded fast tokenizer")
135
- except Exception as e:
136
- logging.warning(f"Fast tokenizer failed, falling back to slow: {e}")
137
- self.tokenizer = AutoTokenizer.from_pretrained(
138
- model_path,
139
- token=HF_TOKEN,
140
- use_fast=False,
141
- trust_remote_code=True
142
- )
143
 
144
- model_kwargs = {
145
- "torch_dtype": torch.bfloat16,
146
- "device_map": device_map if device_map == "auto" else None,
147
- "trust_remote_code": True
148
- }
149
- if HF_TOKEN:
150
- model_kwargs["token"] = HF_TOKEN
151
-
152
- self.model = AutoModelForCausalLM.from_pretrained(model_path, **model_kwargs)
153
- if device_map != "auto":
154
- self.model.to(device_map)
155
-
156
- self.prompt_template = """You are an expert poster prompt designer. Your task is to rewrite a user's short poster prompt into a detailed and vivid long-format prompt. Follow these steps carefully:
157
 
158
  **Step 1: Analyze the Core Requirements**
159
  Identify the key elements in the user's prompt. Do not miss any details.
@@ -191,172 +120,95 @@ Elaborate on each core requirement to create a rich description.
191
  ---
192
  **User Prompt:**
193
  {brief_description}"""
194
-
195
- def recap_prompt(self, original_prompt):
196
- full_prompt = self.prompt_template.format(brief_description=original_prompt)
197
- messages = [{"role": "user", "content": full_prompt}]
198
- try:
199
- text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
200
- model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
201
-
202
- with torch.no_grad():
203
- generated_ids = self.model.generate(**model_inputs, max_new_tokens=4096, temperature=0.6)
204
-
205
- output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
206
- full_response = self.tokenizer.decode(output_ids, skip_special_tokens=True)
207
- final_answer = self._extract_final_answer(full_response)
208
-
209
- if final_answer:
210
- return final_answer.strip()
211
-
212
- logging.info("Qwen returned an empty answer. Using original prompt.")
213
- return original_prompt
214
- except Exception as e:
215
- logging.error(f"Qwen recap failed: {e}. Using original prompt.")
216
- return original_prompt
217
 
218
- def _extract_final_answer(self, full_response):
219
- if "</think>" in full_response:
220
- return full_response.split("</think>")[-1].strip()
221
- if "<think>" not in full_response:
222
- return full_response.strip()
223
- return None
224
-
225
- # ------------------------------------------------------------------
226
- # 5. Poster Generator Class (ๅŸบไบŽไฝ ็š„ๅŽŸๅง‹้€ป่พ‘๏ผŒไฝ†ๅŠ ไธŠ็ผ“ๅญ˜)
227
- # ------------------------------------------------------------------
228
- class PosterGenerator:
229
- def __init__(self, pipeline_path, qwen_model_path, custom_weights_path, device):
230
- self.device = device
231
- self.pipeline_path = pipeline_path
232
- self.qwen_model_path = qwen_model_path
233
- self.custom_weights_path = custom_weights_path
234
 
235
- # ็ผ“ๅญ˜ๅ˜้‡
236
- self.qwen_agent = None
237
- self.pipeline = None
238
 
239
- def _load_qwen_agent(self):
240
- if self.qwen_agent is None:
241
- if not self.qwen_model_path:
242
- return None
243
-
244
- # ๆฃ€ๆŸฅๆœฌๅœฐ่ทฏๅพ„
245
- qwen_local = "local_weights/Qwen3-8B"
246
- model_path = qwen_local if os.path.exists(qwen_local) else self.qwen_model_path
247
-
248
- logging.info(f"Loading Qwen agent from {model_path}")
249
- self.qwen_agent = QwenRecapAgent(model_path=model_path, device_map=str(self.device))
250
- return self.qwen_agent
251
-
252
- def _load_flux_pipeline(self):
253
- if self.pipeline is None:
254
- logging.info("Loading FLUX pipeline...")
255
- self.pipeline = FluxPipeline.from_pretrained(
256
- self.pipeline_path,
257
- torch_dtype=torch.bfloat16,
258
- token=HF_TOKEN
259
- )
260
-
261
- # ๅŠ ่ฝฝ่‡ชๅฎšไน‰ๆƒ้‡
262
- custom_weights_local = "local_weights/PosterCraft-v1_RL"
263
- if os.path.exists(custom_weights_local):
264
- logging.info(f"Loading custom Transformer from directory: {custom_weights_local}")
265
- transformer = FluxTransformer2DModel.from_pretrained(
266
- custom_weights_local,
267
- torch_dtype=torch.bfloat16,
268
- token=HF_TOKEN
269
- )
270
- self.pipeline.transformer = transformer
271
- elif self.custom_weights_path and os.path.exists(self.custom_weights_path):
272
- logging.info(f"Loading custom Transformer from directory: {self.custom_weights_path}")
273
- transformer = FluxTransformer2DModel.from_pretrained(
274
- self.custom_weights_path,
275
- torch_dtype=torch.bfloat16,
276
- token=HF_TOKEN
277
- )
278
- self.pipeline.transformer = transformer
279
-
280
- self.pipeline.to(self.device)
281
- return self.pipeline
282
-
283
- def generate(self, prompt, enable_recap, **kwargs):
284
- final_prompt = prompt
285
- if enable_recap:
286
- qwen_agent = self._load_qwen_agent()
287
- if not qwen_agent:
288
- raise gr.Error("Recap is enabled, but the recap model is not available. Check model path.")
289
- final_prompt = qwen_agent.recap_prompt(prompt)
290
-
291
- pipeline = self._load_flux_pipeline()
292
- generator = torch.Generator(device=self.device).manual_seed(kwargs['seed'])
293
 
294
- with torch.inference_mode():
295
- image = pipeline(
296
- prompt=final_prompt,
297
- generator=generator,
298
- num_inference_steps=kwargs['num_inference_steps'],
299
- guidance_scale=kwargs['guidance_scale'],
300
- width=kwargs['width'],
301
- height=kwargs['height']
302
- ).images[0]
303
 
304
- return image, final_prompt
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
- # ------------------------------------------------------------------
- # 6. Main Generation Function (GPU) - keeps your original logic
- # ------------------------------------------------------------------
  @spaces.GPU(duration=300)
- def generate_image_interface(
-     original_prompt, enable_recap, height, width,
-     num_inference_steps, guidance_scale, seed_input,
      progress=gr.Progress(track_tqdm=True),
  ):
-     """Generate image using FLUX pipeline"""
      if not original_prompt or not original_prompt.strip():
-         return None, "❌ Prompt cannot be empty!", ""
 
      try:
          if not HF_TOKEN:
-             return None, "❌ Error: HF_TOKEN not found. Please configure authentication.", ""
 
-         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-         # Global generator instance
-         if not hasattr(generate_image_interface, 'generator'):
-             generate_image_interface.generator = PosterGenerator(
-                 pipeline_path=DEFAULT_PIPELINE_PATH,
-                 qwen_model_path=DEFAULT_QWEN_MODEL_PATH,
-                 custom_weights_path=DEFAULT_CUSTOM_WEIGHTS_PATH,
-                 device=device
-             )
 
          actual_seed = int(seed_input) if seed_input and seed_input != -1 else random.randint(1, 2**32 - 1)
 
-         progress(0.1, desc="Starting generation...")
 
-         image, final_prompt = generate_image_interface.generator.generate(
-             prompt=original_prompt,
-             enable_recap=enable_recap,
-             height=int(height),
-             width=int(width),
-             num_inference_steps=int(num_inference_steps),
-             guidance_scale=float(guidance_scale),
-             seed=actual_seed
-         )
 
-         status_log = f"✅ Generation complete! Seed: {actual_seed}"
          return image, status_log, final_prompt
 
      except Exception as e:
-         logging.error(f"Generation failed: {e}")
-         return None, f"❌ Generation failed: {str(e)}", ""
 
- # ------------------------------------------------------------------
- # 7. Gradio Interface (keeps your original style)
- # ------------------------------------------------------------------
  def create_interface():
-     """Create Gradio interface"""
 
      with gr.Blocks(
          title="PosterCraft-v1.0",
@@ -371,50 +223,50 @@ def create_interface():
          <div class="main-container">
              <h1 style="text-align: center; margin-bottom: 20px;">🎨 PosterCraft-v1.0</h1>
              <p style="text-align: center; color: #666; margin-bottom: 30px;">
-                 Professional poster generation with FLUX.1-dev and custom fine-tuned weights
              </p>
          </div>
      """)
 
      with gr.Row():
-         gr.Markdown(f"**Base Pipeline:** `{DEFAULT_PIPELINE_PATH}`")
-         gr.Markdown(f"**Authentication Status:** {auth_status}")
 
      gr.HTML("""
          <div class="status-box">
-             <p><strong>⚠️ First generation requires model loading (5-10 minutes). Subsequent generations are much faster!</strong></p>
          </div>
      """)
 
      with gr.Row():
          with gr.Column(scale=1):
-             gr.Markdown("### 1. Configuration")
              prompt_input = gr.Textbox(
-                 label="Poster Prompt",
                  lines=3,
-                 placeholder="Enter your poster description...",
-                 value="A vintage travel poster for Paris, featuring the Eiffel Tower at sunset with warm golden lighting"
              )
              enable_recap_checkbox = gr.Checkbox(
-                 label="Enable Prompt Enhancement (Qwen3-8B)",
                  value=True,
-                 info="Use AI to enhance and expand your prompt"
              )
 
              with gr.Row():
-                 width_input = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, value=768, step=32)
-                 height_input = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, value=1024, step=32)
 
-             num_inference_steps_input = gr.Slider(label="Inference Steps", minimum=1, maximum=100, value=20, step=1)
-             guidance_scale_input = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=20.0, value=3.5, step=0.1)
-             seed_number_input = gr.Number(label="Seed (-1 for random)", value=-1, minimum=-1, step=1)
-             generate_button = gr.Button("🎨 Generate Poster", variant="primary", size="lg")
 
          with gr.Column(scale=1):
-             gr.Markdown("### 2. Results")
-             image_output = gr.Image(label="Generated Poster", type="pil", height=600)
-             status_output = gr.Textbox(label="Generation Status", lines=2, interactive=False)
-             recapped_prompt_output = gr.Textbox(label="Enhanced Prompt", lines=5, interactive=False, info="The final prompt used for generation")
 
      inputs_list = [
          prompt_input, enable_recap_checkbox, height_input, width_input,
@@ -422,28 +274,28 @@ def create_interface():
      ]
      outputs_list = [image_output, status_output, recapped_prompt_output]
 
-     generate_button.click(fn=generate_image_interface, inputs=inputs_list, outputs=outputs_list)
 
-     # Examples
      gr.Examples(
          examples=[
-             ["A retro sci-fi movie poster with neon colors and flying cars"],
-             ["An elegant art deco poster for a luxury hotel"],
-             ["A minimalist concert poster with bold typography"],
-             ["A vintage advertisement for organic coffee"],
          ],
          inputs=[prompt_input]
      )
 
      return demo
 
- # ------------------------------------------------------------------
- # 8. Launch Application
- # ------------------------------------------------------------------
  if __name__ == "__main__":
      demo = create_interface()
      demo.launch(
          server_name="0.0.0.0",
          server_port=7860,
          show_api=False
-     )
 
+ # -*- coding: utf-8 -*-
  import os
  import random
  import logging
  import torch
  from huggingface_hub import login, whoami
 
+ # ─────────────────────────────────────────────
+ # Global configuration
+ # ─────────────────────────────────────────────
  HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
  DEFAULT_PIPELINE_PATH = "black-forest-labs/FLUX.1-dev"
  DEFAULT_QWEN_MODEL_PATH = "Qwen/Qwen3-8B"
  MAX_SEED = np.iinfo(np.int32).max
  MAX_IMAGE_SIZE = 2048
 
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
 
+ # Authentication status
+ auth_status = "🔴 Not Authenticated"
+ if HF_TOKEN:
      try:
+         login(token=HF_TOKEN, add_to_git_credential=True)
+         user_info = whoami(HF_TOKEN)
+         auth_status = f"✅ Authenticated as {user_info['name']}"
+         logging.info(f"Successfully authenticated with Hugging Face as {user_info['name']}")
      except Exception as e:
+         logging.error(f"HF authentication failed: {e}")
+         auth_status = f"🔴 Authentication Error: {str(e)}"
 
+ # ─────────────────────────────────────────────
+ # Load the large models into GPU memory during the GPU subprocess's import phase.
+ # Runs only in the GPU process; the CPU main process skips it.
+ # ─────────────────────────────────────────────
+ if spaces.GPU.is_gpu():
+     from diffusers import FluxPipeline, FluxTransformer2DModel
+     from transformers import AutoModelForCausalLM, AutoTokenizer
 
+     print("⏱️ [GPU init] Loading FLUX pipeline...")
+     FLUX_PIPELINE = FluxPipeline.from_pretrained(
+         DEFAULT_PIPELINE_PATH,
+         torch_dtype=torch.bfloat16,
+         token=HF_TOKEN
+     ).to("cuda")
+
+     print("⏱️ [GPU init] Loading PosterCraft transformer...")
+     POSTERCRAFT_TRANSFORMER = FluxTransformer2DModel.from_pretrained(
+         DEFAULT_CUSTOM_WEIGHTS_PATH,
+         torch_dtype=torch.bfloat16,
+         token=HF_TOKEN
+     ).to("cuda")
+     FLUX_PIPELINE.transformer = POSTERCRAFT_TRANSFORMER
+
+     print("⏱️ [GPU init] Loading Qwen model...")
+     QWEN_TOKENIZER = AutoTokenizer.from_pretrained(
+         DEFAULT_QWEN_MODEL_PATH,
+         token=HF_TOKEN,
+         trust_remote_code=True,
+         use_fast=True
+     )
+     QWEN_MODEL = AutoModelForCausalLM.from_pretrained(
+         DEFAULT_QWEN_MODEL_PATH,
+         torch_dtype=torch.bfloat16,
+         device_map="auto",
+         token=HF_TOKEN,
+         trust_remote_code=True,
+     )
 
+     print("✅ [GPU init] All models loaded successfully!")
+
+ # ─────────────────────────────────────────────
+ # Qwen prompt-enhancement function
+ # ─────────────────────────────────────────────
+ def enhance_prompt_with_qwen(original_prompt):
+     """Enhance the prompt with the preloaded Qwen model"""
+     if not spaces.GPU.is_gpu():
+         return original_prompt
 
+     prompt_template = """You are an expert poster prompt designer. Your task is to rewrite a user's short poster prompt into a detailed and vivid long-format prompt. Follow these steps carefully:
 
  **Step 1: Analyze the Core Requirements**
  Identify the key elements in the user's prompt. Do not miss any details.
  ---
  **User Prompt:**
  {brief_description}"""
 
+     try:
+         full_prompt = prompt_template.format(brief_description=original_prompt)
+         messages = [{"role": "user", "content": full_prompt}]
 
+         text = QWEN_TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=False)
+         model_inputs = QWEN_TOKENIZER([text], return_tensors="pt").to(QWEN_MODEL.device)
 
+         with torch.no_grad():
+             generated_ids = QWEN_MODEL.generate(**model_inputs, max_new_tokens=4096, temperature=0.6)
 
+         output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+         full_response = QWEN_TOKENIZER.decode(output_ids, skip_special_tokens=True)
 
+         # Extract the final answer
+         if "</think>" in full_response:
+             final_answer = full_response.split("</think>")[-1].strip()
+         elif "<think>" not in full_response:
+             final_answer = full_response.strip()
+         else:
+             final_answer = original_prompt
+
+         return final_answer if final_answer else original_prompt
+
+     except Exception as e:
+         logging.error(f"Qwen enhancement failed: {e}")
+         return original_prompt
 
+ # ─────────────────────────────────────────────
+ # Main generation function (uses the preloaded models)
+ # ─────────────────────────────────────────────
  @spaces.GPU(duration=300)
+ def generate_poster(
+     original_prompt,
+     enable_recap,
+     height,
+     width,
+     num_inference_steps,
+     guidance_scale,
+     seed_input,
      progress=gr.Progress(track_tqdm=True),
  ):
+     """Generate a poster with the preloaded models"""
      if not original_prompt or not original_prompt.strip():
+         return None, "❌ Prompt cannot be empty!", ""
 
      try:
          if not HF_TOKEN:
+             return None, "❌ Error: HF_TOKEN not found. Please configure authentication.", ""
 
+         progress(0.1, desc="Starting generation...")
+
+         # Determine the final prompt
+         final_prompt = original_prompt
+         if enable_recap:
+             progress(0.2, desc="Enhancing prompt...")
+             final_prompt = enhance_prompt_with_qwen(original_prompt)
 
+         # Determine the seed
          actual_seed = int(seed_input) if seed_input and seed_input != -1 else random.randint(1, 2**32 - 1)
 
+         progress(0.3, desc="Generating image...")
 
+         # Generate the image with the preloaded FLUX pipeline
+         generator = torch.Generator("cuda").manual_seed(actual_seed)
+
+         with torch.inference_mode():
+             image = FLUX_PIPELINE(
+                 prompt=final_prompt,
+                 generator=generator,
+                 num_inference_steps=int(num_inference_steps),
+                 guidance_scale=float(guidance_scale),
+                 width=int(width),
+                 height=int(height)
+             ).images[0]
 
+         progress(1.0, desc="Done!")
+         status_log = f"✅ Generation complete! Seed: {actual_seed}"
          return image, status_log, final_prompt
 
      except Exception as e:
+         logging.error(f"Generation failed: {e}")
+         return None, f"❌ Generation failed: {str(e)}", ""
 
208
+ # Gradio ็•Œ้ข
209
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
210
  def create_interface():
211
+ """ๅˆ›ๅปบ Gradio ็•Œ้ข"""
212
 
213
  with gr.Blocks(
214
  title="PosterCraft-v1.0",
 
223
  <div class="main-container">
224
  <h1 style="text-align: center; margin-bottom: 20px;">๐ŸŽจ PosterCraft-v1.0</h1>
225
  <p style="text-align: center; color: #666; margin-bottom: 30px;">
226
+ ไธ“ไธšๆตทๆŠฅ็”Ÿๆˆๅทฅๅ…ท๏ผŒๅŸบไบŽ FLUX.1-dev ๅ’Œๅฎšๅˆถๅพฎ่ฐƒๆƒ้‡
227
  </p>
228
  </div>
229
  """)
230
 
231
  with gr.Row():
232
+ gr.Markdown(f"**ๅŸบ็ก€ๆจกๅž‹๏ผš** `{DEFAULT_PIPELINE_PATH}`")
233
+ gr.Markdown(f"**่ฎค่ฏ็Šถๆ€๏ผš** {auth_status}")
234
 
235
  gr.HTML("""
236
  <div class="status-box">
237
+ <p><strong>โšก ้ฆ–ๆฌก็”Ÿๆˆ้œ€่ฆๅŠ ่ฝฝๆจกๅž‹๏ผˆ5-10ๅˆ†้’Ÿ๏ผ‰๏ผŒๅŽ็ปญ็”Ÿๆˆไผš้žๅธธๅฟซ๏ผ</strong></p>
238
  </div>
239
  """)
240
 
241
  with gr.Row():
242
  with gr.Column(scale=1):
243
+ gr.Markdown("### 1. ้…็ฝฎ")
244
  prompt_input = gr.Textbox(
245
+ label="ๆตทๆŠฅๆ็คบ่ฏ",
246
  lines=3,
247
+ placeholder="่พ“ๅ…ฅๆ‚จ็š„ๆตทๆŠฅๆ่ฟฐ...",
248
+ value="ๅคๅค็ง‘ๅนป็”ตๅฝฑๆตทๆŠฅ๏ผŒ้œ“่™น่‰ฒๅฝฉๅ’Œ้ฃž่กŒๆฑฝ่ฝฆ"
249
  )
250
  enable_recap_checkbox = gr.Checkbox(
251
+ label="ๅฏ็”จๆ็คบ่ฏๅขžๅผบ (Qwen3-8B)",
252
  value=True,
253
+ info="ไฝฟ็”จ AI ๅขžๅผบๅ’Œๆ‰ฉๅฑ•ๆ‚จ็š„ๆ็คบ่ฏ"
254
  )
255
 
256
  with gr.Row():
257
+ width_input = gr.Slider(label="ๅฎฝๅบฆ", minimum=256, maximum=MAX_IMAGE_SIZE, value=768, step=32)
258
+ height_input = gr.Slider(label="้ซ˜ๅบฆ", minimum=256, maximum=MAX_IMAGE_SIZE, value=1024, step=32)
259
 
260
+ num_inference_steps_input = gr.Slider(label="ๆŽจ็†ๆญฅๆ•ฐ", minimum=1, maximum=100, value=20, step=1)
261
+ guidance_scale_input = gr.Slider(label="ๅผ•ๅฏผๅผบๅบฆ", minimum=0.0, maximum=20.0, value=3.5, step=0.1)
262
+ seed_number_input = gr.Number(label="็งๅญ (-1 ้šๆœบ)", value=-1, minimum=-1, step=1)
263
+ generate_button = gr.Button("๐ŸŽจ ็”ŸๆˆๆตทๆŠฅ", variant="primary", size="lg")
264
 
265
  with gr.Column(scale=1):
266
+ gr.Markdown("### 2. ็ป“ๆžœ")
267
+ image_output = gr.Image(label="็”Ÿๆˆ็š„ๆตทๆŠฅ", type="pil", height=600)
268
+ status_output = gr.Textbox(label="็”Ÿๆˆ็Šถๆ€", lines=2, interactive=False)
269
+ recapped_prompt_output = gr.Textbox(label="ๅขžๅผบๅŽ็š„ๆ็คบ่ฏ", lines=5, interactive=False, info="็”จไบŽ็”Ÿๆˆ็š„ๆœ€็ปˆๆ็คบ่ฏ")
270
 
271
  inputs_list = [
272
  prompt_input, enable_recap_checkbox, height_input, width_input,
 
274
  ]
275
  outputs_list = [image_output, status_output, recapped_prompt_output]
276
 
277
+ generate_button.click(fn=generate_poster, inputs=inputs_list, outputs=outputs_list)
278
 
279
+ # ็คบไพ‹
280
  gr.Examples(
281
  examples=[
282
+ ["ๅคๅค็ง‘ๅนป็”ตๅฝฑๆตทๆŠฅ๏ผŒ้œ“่™น่‰ฒๅฝฉๅ’Œ้ฃž่กŒๆฑฝ่ฝฆ"],
283
+ ["ไผ˜้›…็š„่ฃ…้ฅฐ่‰บๆœฏ้ฃŽๆ ผ่ฑชๅŽ้…’ๅบ—ๆตทๆŠฅ"],
284
+ ["็ฎ€็บฆ้ŸณไนไผšๆตทๆŠฅ๏ผŒ็ฒ—ไฝ“ๅญ—ไฝ“"],
285
+ ["ๆœ‰ๆœบๅ’–ๅ•ก็š„ๅคๅคๅนฟๅ‘Š"],
286
  ],
287
  inputs=[prompt_input]
288
  )
289
 
290
  return demo
291
 
292
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
293
+ # ๅฏๅŠจๅบ”็”จ
294
+ # โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€
295
  if __name__ == "__main__":
296
  demo = create_interface()
297
  demo.launch(
298
  server_name="0.0.0.0",
299
  server_port=7860,
300
  show_api=False
301
+ )
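In short, this commit drops the lazy-loading PosterGenerator class in favor of ZeroGPU-style preloading: model weights are read into VRAM once, at import time inside the GPU worker, behind a spaces.GPU.is_gpu() guard, and the @spaces.GPU-decorated function then uses the resulting module-level globals. A minimal sketch of that pattern, assuming a ZeroGPU Space where the spaces package exposes the GPU.is_gpu() guard used in this commit (the pipeline repo below is just the commit's base model, and generate() is a hypothetical stand-in for generate_poster):

import spaces
import torch

if spaces.GPU.is_gpu():
    # Only the GPU worker process executes this block, so the weights are
    # loaded into GPU memory once at import time rather than on the first request.
    from diffusers import FluxPipeline
    PIPELINE = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

@spaces.GPU(duration=300)
def generate(prompt):
    # The decorator routes this call to the GPU worker, where PIPELINE
    # already exists; the CPU main process never loads the weights.
    return PIPELINE(prompt, num_inference_steps=20).images[0]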