ginipick committed on
Commit 4d79b21 · verified · 1 Parent(s): 50b7ccb

Update app.py

Files changed (1)
  1. app.py +183 -21
app.py CHANGED
@@ -5,6 +5,7 @@ import os
 import time
 import torch
 from diffusers import FluxPipeline
+from transformers import pipeline
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {DEVICE}")
@@ -16,8 +17,103 @@ DEFAULT_NUM_INFERENCE_STEPS = 15
 DEFAULT_MAX_SEQUENCE_LENGTH = 512
 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
-# Cache for the pipeline
+# Cache for the pipelines
 CACHED_PIPE = None
+CACHED_LLM_PIPE = None
+
+def load_llm_pipeline():
+    """Load the LLM pipeline for prompt enhancement"""
+    global CACHED_LLM_PIPE
+    if CACHED_LLM_PIPE is not None:
+        return CACHED_LLM_PIPE
+
+    print("Loading LLM pipeline for prompt enhancement...")
+    try:
+        # Note: Using a smaller model that's actually available
+        # You can replace this with "openai/gpt-oss-120b" if you have access
+        llm_pipe = pipeline(
+            "text-generation",
+            model="microsoft/Phi-3-mini-4k-instruct",  # Alternative smaller model
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+        CACHED_LLM_PIPE = llm_pipe
+        print("LLM pipeline loaded successfully")
+        return llm_pipe
+    except Exception as e:
+        print(f"Error loading LLM pipeline: {e}")
+        # Fallback to a simpler model if the main one fails
+        try:
+            llm_pipe = pipeline(
+                "text-generation",
+                model="gpt2",  # Fallback to GPT-2
+                device_map="auto"
+            )
+            CACHED_LLM_PIPE = llm_pipe
+            print("Loaded fallback LLM pipeline (GPT-2)")
+            return llm_pipe
+        except Exception as e2:
+            print(f"Error loading fallback LLM pipeline: {e2}")
+            return None
+
+def enhance_prompt(prompt, progress=gr.Progress()):
+    """Enhance the prompt using LLM"""
+    if not prompt:
+        return prompt, "Please enter a prompt first."
+
+    progress(0.3, desc="Enhancing prompt with AI...")
+
+    try:
+        llm_pipe = load_llm_pipeline()
+        if llm_pipe is None:
+            return prompt, "LLM pipeline not available, using original prompt."
+
+        # Create enhancement prompt
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that enhances image generation prompts. Make prompts more detailed, artistic, and visually descriptive while keeping the core concept. Add details about lighting, style, colors, mood, and composition. Keep the enhanced prompt under 200 words."
+            },
+            {
+                "role": "user",
+                "content": f"Enhance this image generation prompt, making it more detailed and artistic: '{prompt}'"
+            }
+        ]
+
+        # Generate enhanced prompt
+        result = llm_pipe(
+            messages,
+            max_new_tokens=200,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9
+        )
+
+        # Extract the enhanced prompt from the response
+        if isinstance(result, list) and len(result) > 0:
+            enhanced = result[0].get('generated_text', '')
+            # Extract only the assistant's response
+            if isinstance(enhanced, list):
+                for msg in enhanced:
+                    if msg.get('role') == 'assistant':
+                        enhanced = msg.get('content', prompt)
+                        break
+            elif isinstance(enhanced, str):
+                # Clean up the response if needed
+                enhanced = enhanced.strip()
+                if enhanced.startswith("Enhanced prompt:"):
+                    enhanced = enhanced.replace("Enhanced prompt:", "").strip()
+
+            if enhanced and enhanced != prompt:
+                return enhanced, "Prompt enhanced successfully!"
+            else:
+                return prompt, "Using original prompt."
+        else:
+            return prompt, "Enhancement failed, using original prompt."
+
+    except Exception as e:
+        print(f"Error during prompt enhancement: {e}")
+        return prompt, f"Enhancement error: {e}. Using original prompt."
 
 def load_bnb_4bit_pipeline():
     """Load the 4-bit quantized pipeline"""
@@ -45,12 +141,19 @@ def load_bnb_4bit_pipeline():
         raise
 
 @spaces.GPU(duration=240)
-def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
-    """Generate image using 4-bit quantized model"""
+def generate_image(prompt, use_enhancement=False, progress=gr.Progress(track_tqdm=True)):
+    """Generate image using 4-bit quantized model with optional prompt enhancement"""
     if not prompt:
-        return None, "Please enter a prompt."
+        return None, prompt, "Please enter a prompt."
 
-    progress(0.2, desc="Loading 4-bit quantized model...")
+    enhanced_prompt = prompt
+    enhancement_status = ""
+
+    # Enhance prompt if requested
+    if use_enhancement:
+        enhanced_prompt, enhancement_status = enhance_prompt(prompt, progress)
+
+    progress(0.5, desc="Loading 4-bit quantized model...")
 
     try:
         # Load the 4-bit pipeline
@@ -58,7 +161,7 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
 
         # Set up generation parameters
         pipe_kwargs = {
-            "prompt": prompt,
+            "prompt": enhanced_prompt,
             "height": DEFAULT_HEIGHT,
             "width": DEFAULT_WIDTH,
             "guidance_scale": DEFAULT_GUIDANCE_SCALE,
@@ -70,7 +173,7 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
         seed = random.getrandbits(64)
         print(f"Using seed: {seed}")
 
-        progress(0.5, desc="Generating image...")
+        progress(0.7, desc="Generating image...")
 
         # Generate image
         gen_start_time = time.time()
@@ -81,19 +184,29 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
         print(f"Memory reserved: {mem_reserved:.2f} GB")
 
-        return image, f"Generation complete! (Seed: {seed})"
+        status_msg = f"Generation complete! (Seed: {seed})"
+        if enhancement_status:
+            status_msg = f"{enhancement_status} | {status_msg}"
+
+        return image, enhanced_prompt, status_msg
 
     except Exception as e:
        print(f"Error during generation: {e}")
-        return None, f"Error: {e}"
+        return None, enhanced_prompt, f"Error: {e}"
+
+@spaces.GPU(duration=60)
+def enhance_only(prompt, progress=gr.Progress()):
+    """Only enhance the prompt without generating an image"""
+    enhanced_prompt, status = enhance_prompt(prompt, progress)
+    return enhanced_prompt, status
 
 # Create Gradio interface
-with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="FLUXllama Enhanced", theme=gr.themes.Soft()) as demo:
     gr.HTML(
         """
         <div style='text-align: center; margin-bottom: 20px;'>
-            <h1>FLUXllama</h1>
-            <p>FLUX.1-dev 4-bit Quantized Version</p>
+            <h1>FLUXllama Enhanced</h1>
+            <p>FLUX.1-dev 4-bit Quantized Version with AI Prompt Enhancement</p>
         </div>
         """
     )
@@ -112,14 +225,31 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
         """
     )
 
-    with gr.Row():
+    with gr.Column():
         prompt_input = gr.Textbox(
             label="Enter your prompt",
             placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
-            lines=2,
-            scale=4
+            lines=3
         )
-        generate_button = gr.Button("Generate", variant="primary", scale=1)
+
+        with gr.Row():
+            enhance_checkbox = gr.Checkbox(
+                label="🎨 Use AI Prompt Enhancement",
+                value=False,
+                info="Automatically enhance your prompt for better results"
+            )
+            enhance_only_button = gr.Button("✨ Enhance Only", variant="secondary", scale=1)
+
+        enhanced_prompt_display = gr.Textbox(
+            label="Enhanced Prompt (will appear after enhancement)",
+            lines=3,
+            interactive=False,
+            visible=True
+        )
+
+        with gr.Row():
+            generate_button = gr.Button("🚀 Generate Image", variant="primary", scale=2)
+            generate_enhanced_button = gr.Button("🎨 Enhance & Generate", variant="primary", scale=2)
 
     output_image = gr.Image(
         label="Generated Image (4-bit Quantized)",
@@ -135,16 +265,28 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
 
     # Connect components
     generate_button.click(
-        fn=generate_image,
+        fn=lambda p: generate_image(p, use_enhancement=False),
+        inputs=[prompt_input],
+        outputs=[output_image, enhanced_prompt_display, status_text]
+    )
+
+    generate_enhanced_button.click(
+        fn=lambda p: generate_image(p, use_enhancement=True),
+        inputs=[prompt_input],
+        outputs=[output_image, enhanced_prompt_display, status_text]
+    )
+
+    enhance_only_button.click(
+        fn=enhance_only,
         inputs=[prompt_input],
-        outputs=[output_image, status_text]
+        outputs=[enhanced_prompt_display, status_text]
     )
 
-    # Enter key to submit
+    # Enter key to submit (with enhancement checkbox consideration)
     prompt_input.submit(
         fn=generate_image,
-        inputs=[prompt_input],
-        outputs=[output_image, status_text]
+        inputs=[prompt_input, enhance_checkbox],
+        outputs=[output_image, enhanced_prompt_display, status_text]
     )
 
     # Example prompts
@@ -161,6 +303,26 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
         ],
         inputs=prompt_input
     )
+
+    gr.HTML(
+        """
+        <div style='text-align: center; margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;'>
+            <h3>✨ Prompt Enhancement Feature</h3>
+            <p>This app now includes AI-powered prompt enhancement! The enhancement feature will:</p>
+            <ul style='text-align: left; display: inline-block;'>
+                <li>Add artistic details and visual descriptions</li>
+                <li>Specify lighting, mood, and atmosphere</li>
+                <li>Include style and composition elements</li>
+                <li>Make your prompts more effective for image generation</li>
+            </ul>
+            <p><strong>How to use:</strong></p>
+            <p>1. Enter a simple prompt</p>
+            <p>2. Click "✨ Enhance Only" to preview the enhanced version</p>
+            <p>3. Click "🎨 Enhance & Generate" to enhance and generate in one step</p>
+            <p>4. Or check the enhancement checkbox and click Generate</p>
+        </div>
+        """
+    )
 
 if __name__ == "__main__":
     demo.launch(share=True)
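Editor's note on the isinstance branching in enhance_prompt() above: transformers' text-generation pipeline returns differently shaped "generated_text" values for chat-style and plain-string inputs, which is why the commit checks for both a list and a string. A minimal sketch of the two cases, reusing the model names from this commit (illustrative only, not part of app.py):

# Chat-capable model: a list of messages goes in, and "generated_text" comes
# back as the full conversation with the reply appended as an assistant turn.
# This is the list branch in enhance_prompt().
from transformers import pipeline

chat_pipe = pipeline("text-generation", model="microsoft/Phi-3-mini-4k-instruct")
out = chat_pipe([{"role": "user", "content": "Enhance: 'a cat'"}], max_new_tokens=64)
reply = out[0]["generated_text"][-1]["content"]

# Plain language model (the gpt2 fallback): a string goes in, and
# "generated_text" is a string that echoes the prompt, which is why the
# string branch does strip()/prefix cleanup.
text_pipe = pipeline("text-generation", model="gpt2")
out = text_pipe("Enhance this image prompt: a cat", max_new_tokens=64)
reply = out[0]["generated_text"]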
 
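Editor's note on caching: both pipelines are memoized in module-level globals (CACHED_PIPE, CACHED_LLM_PIPE) so that repeated Spaces GPU invocations reuse already-loaded weights instead of reloading them. The same pattern, generalized as a sketch (get_pipeline and _PIPE_CACHE are hypothetical names, not from app.py):

# Load-once, reuse-thereafter cache keyed by pipeline name (illustrative).
_PIPE_CACHE = {}

def get_pipeline(key, loader):
    """Run loader() the first time `key` is requested; reuse the result after."""
    if key not in _PIPE_CACHE:
        _PIPE_CACHE[key] = loader()
    return _PIPE_CACHE[key]

# Usage: llm = get_pipeline("llm", load_llm_pipeline)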
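Editor's note on the event wiring: Gradio passes the values of `inputs` to the handler positionally, so the Enter-key binding prompt_input.submit(fn=generate_image, inputs=[prompt_input, enhance_checkbox], ...) feeds the checkbox state into use_enhancement, while the two buttons pin the flag via lambdas. An explicit wrapper equivalent to the submit path (on_submit is a hypothetical name and assumes app.py's generate_image is in scope):

def on_submit(prompt, use_enhancement):
    # Positional mapping: prompt_input -> prompt, enhance_checkbox -> use_enhancement.
    # Returns the same (image, enhanced_prompt, status) triple that
    # outputs=[output_image, enhanced_prompt_display, status_text] expects.
    return generate_image(prompt, use_enhancement)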