Spaces: Running on Zero

Update app.py
Browse files

app.py CHANGED
@@ -5,6 +5,7 @@ import os
 import time
 import torch
 from diffusers import FluxPipeline
+from transformers import pipeline
 
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {DEVICE}")
@@ -16,8 +17,103 @@ DEFAULT_NUM_INFERENCE_STEPS = 15
 DEFAULT_MAX_SEQUENCE_LENGTH = 512
 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
-# Cache for the
+# Cache for the pipelines
 CACHED_PIPE = None
+CACHED_LLM_PIPE = None
+
+def load_llm_pipeline():
+    """Load the LLM pipeline for prompt enhancement"""
+    global CACHED_LLM_PIPE
+    if CACHED_LLM_PIPE is not None:
+        return CACHED_LLM_PIPE
+
+    print("Loading LLM pipeline for prompt enhancement...")
+    try:
+        # Note: Using a smaller model that's actually available
+        # You can replace this with "openai/gpt-oss-120b" if you have access
+        llm_pipe = pipeline(
+            "text-generation",
+            model="microsoft/Phi-3-mini-4k-instruct",  # Alternative smaller model
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+        CACHED_LLM_PIPE = llm_pipe
+        print("LLM pipeline loaded successfully")
+        return llm_pipe
+    except Exception as e:
+        print(f"Error loading LLM pipeline: {e}")
+        # Fallback to a simpler model if the main one fails
+        try:
+            llm_pipe = pipeline(
+                "text-generation",
+                model="gpt2",  # Fallback to GPT-2
+                device_map="auto"
+            )
+            CACHED_LLM_PIPE = llm_pipe
+            print("Loaded fallback LLM pipeline (GPT-2)")
+            return llm_pipe
+        except Exception as e2:
+            print(f"Error loading fallback LLM pipeline: {e2}")
+            return None
+
+def enhance_prompt(prompt, progress=gr.Progress()):
+    """Enhance the prompt using LLM"""
+    if not prompt:
+        return prompt, "Please enter a prompt first."
+
+    progress(0.3, desc="Enhancing prompt with AI...")
+
+    try:
+        llm_pipe = load_llm_pipeline()
+        if llm_pipe is None:
+            return prompt, "LLM pipeline not available, using original prompt."
+
+        # Create enhancement prompt
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful assistant that enhances image generation prompts. Make prompts more detailed, artistic, and visually descriptive while keeping the core concept. Add details about lighting, style, colors, mood, and composition. Keep the enhanced prompt under 200 words."
+            },
+            {
+                "role": "user",
+                "content": f"Enhance this image generation prompt, making it more detailed and artistic: '{prompt}'"
+            }
+        ]
+
+        # Generate enhanced prompt
+        result = llm_pipe(
+            messages,
+            max_new_tokens=200,
+            temperature=0.7,
+            do_sample=True,
+            top_p=0.9
+        )
+
+        # Extract the enhanced prompt from the response
+        if isinstance(result, list) and len(result) > 0:
+            enhanced = result[0].get('generated_text', '')
+            # Extract only the assistant's response
+            if isinstance(enhanced, list):
+                for msg in enhanced:
+                    if msg.get('role') == 'assistant':
+                        enhanced = msg.get('content', prompt)
+                        break
+            elif isinstance(enhanced, str):
+                # Clean up the response if needed
+                enhanced = enhanced.strip()
+                if enhanced.startswith("Enhanced prompt:"):
+                    enhanced = enhanced.replace("Enhanced prompt:", "").strip()
+
+            if enhanced and enhanced != prompt:
+                return enhanced, "Prompt enhanced successfully!"
+            else:
+                return prompt, "Using original prompt."
+        else:
+            return prompt, "Enhancement failed, using original prompt."
+
+    except Exception as e:
+        print(f"Error during prompt enhancement: {e}")
+        return prompt, f"Enhancement error: {e}. Using original prompt."
 
 def load_bnb_4bit_pipeline():
     """Load the 4-bit quantized pipeline"""
@@ -45,12 +141,19 @@ def load_bnb_4bit_pipeline():
         raise
 
 @spaces.GPU(duration=240)
-def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
-    """Generate image using 4-bit quantized model"""
+def generate_image(prompt, use_enhancement=False, progress=gr.Progress(track_tqdm=True)):
+    """Generate image using 4-bit quantized model with optional prompt enhancement"""
     if not prompt:
-        return None, "Please enter a prompt."
+        return None, prompt, "Please enter a prompt."
 
-
+    enhanced_prompt = prompt
+    enhancement_status = ""
+
+    # Enhance prompt if requested
+    if use_enhancement:
+        enhanced_prompt, enhancement_status = enhance_prompt(prompt, progress)
+
+    progress(0.5, desc="Loading 4-bit quantized model...")
 
     try:
         # Load the 4-bit pipeline
@@ -58,7 +161,7 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
 
         # Set up generation parameters
         pipe_kwargs = {
-            "prompt":
+            "prompt": enhanced_prompt,
             "height": DEFAULT_HEIGHT,
             "width": DEFAULT_WIDTH,
             "guidance_scale": DEFAULT_GUIDANCE_SCALE,
@@ -70,7 +173,7 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
         seed = random.getrandbits(64)
         print(f"Using seed: {seed}")
 
-        progress(0.
+        progress(0.7, desc="Generating image...")
 
         # Generate image
         gen_start_time = time.time()
@@ -81,19 +184,29 @@ def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
         mem_reserved = torch.cuda.memory_reserved(0)/1024**3 if DEVICE == "cuda" else 0
         print(f"Memory reserved: {mem_reserved:.2f} GB")
 
-
+        status_msg = f"Generation complete! (Seed: {seed})"
+        if enhancement_status:
+            status_msg = f"{enhancement_status} | {status_msg}"
+
+        return image, enhanced_prompt, status_msg
 
     except Exception as e:
        print(f"Error during generation: {e}")
-        return None, f"Error: {e}"
+        return None, enhanced_prompt, f"Error: {e}"
+
+@spaces.GPU(duration=60)
+def enhance_only(prompt, progress=gr.Progress()):
+    """Only enhance the prompt without generating an image"""
+    enhanced_prompt, status = enhance_prompt(prompt, progress)
+    return enhanced_prompt, status
 
 # Create Gradio interface
-with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
+with gr.Blocks(title="FLUXllama Enhanced", theme=gr.themes.Soft()) as demo:
     gr.HTML(
         """
         <div style='text-align: center; margin-bottom: 20px;'>
-            <h1>FLUXllama</h1>
-            <p>FLUX.1-dev 4-bit Quantized Version</p>
+            <h1>FLUXllama Enhanced</h1>
+            <p>FLUX.1-dev 4-bit Quantized Version with AI Prompt Enhancement</p>
         </div>
         """
     )
@@ -112,14 +225,31 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
         """
     )
 
-    with gr.
+    with gr.Column():
         prompt_input = gr.Textbox(
             label="Enter your prompt",
             placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
-            lines=
-            scale=4
+            lines=3
         )
-
+
+        with gr.Row():
+            enhance_checkbox = gr.Checkbox(
+                label="🎨 Use AI Prompt Enhancement",
+                value=False,
+                info="Automatically enhance your prompt for better results"
+            )
+            enhance_only_button = gr.Button("✨ Enhance Only", variant="secondary", scale=1)
+
+        enhanced_prompt_display = gr.Textbox(
+            label="Enhanced Prompt (will appear after enhancement)",
+            lines=3,
+            interactive=False,
+            visible=True
+        )
+
+        with gr.Row():
+            generate_button = gr.Button("🚀 Generate Image", variant="primary", scale=2)
+            generate_enhanced_button = gr.Button("🎨 Enhance & Generate", variant="primary", scale=2)
 
     output_image = gr.Image(
         label="Generated Image (4-bit Quantized)",
@@ -135,16 +265,28 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
 
     # Connect components
     generate_button.click(
-        fn=generate_image,
+        fn=lambda p: generate_image(p, use_enhancement=False),
+        inputs=[prompt_input],
+        outputs=[output_image, enhanced_prompt_display, status_text]
+    )
+
+    generate_enhanced_button.click(
+        fn=lambda p: generate_image(p, use_enhancement=True),
+        inputs=[prompt_input],
+        outputs=[output_image, enhanced_prompt_display, status_text]
+    )
+
+    enhance_only_button.click(
+        fn=enhance_only,
         inputs=[prompt_input],
-        outputs=[
+        outputs=[enhanced_prompt_display, status_text]
     )
 
-    # Enter key to submit
+    # Enter key to submit (with enhancement checkbox consideration)
    prompt_input.submit(
         fn=generate_image,
-        inputs=[prompt_input],
-        outputs=[output_image, status_text]
+        inputs=[prompt_input, enhance_checkbox],
+        outputs=[output_image, enhanced_prompt_display, status_text]
     )
 
     # Example prompts
@@ -161,6 +303,26 @@ with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
         ],
         inputs=prompt_input
     )
+
+    gr.HTML(
+        """
+        <div style='text-align: center; margin-top: 20px; padding: 20px; background-color: #f0f0f0; border-radius: 10px;'>
+            <h3>✨ Prompt Enhancement Feature</h3>
+            <p>This app now includes AI-powered prompt enhancement! The enhancement feature will:</p>
+            <ul style='text-align: left; display: inline-block;'>
+                <li>Add artistic details and visual descriptions</li>
+                <li>Specify lighting, mood, and atmosphere</li>
+                <li>Include style and composition elements</li>
+                <li>Make your prompts more effective for image generation</li>
+            </ul>
+            <p><strong>How to use:</strong></p>
+            <p>1. Enter a simple prompt</p>
+            <p>2. Click "✨ Enhance Only" to preview the enhanced version</p>
+            <p>3. Click "🎨 Enhance & Generate" to enhance and generate in one step</p>
+            <p>4. Or check the enhancement checkbox and click Generate</p>
+        </div>
+        """
+    )
 
 if __name__ == "__main__":
     demo.launch(share=True)
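Reviewer note: the new enhancement path hinges on how transformers' text-generation pipeline handles chat-style input, which is what the `result[0]['generated_text']` parsing in `enhance_prompt` is accounting for. Below is a minimal standalone sketch of that flow, assuming the same Phi-3-mini model the commit selects and a recent transformers release with chat-template support in pipelines; the sample prompt is illustrative only, not part of the commit.

# Standalone sketch of the prompt-enhancement call added in app.py.
# Assumes GPU availability (or patience on CPU) and a transformers
# version whose text-generation pipeline accepts chat messages.
from transformers import pipeline

llm_pipe = pipeline(
    "text-generation",
    model="microsoft/Phi-3-mini-4k-instruct",
    device_map="auto",
)

messages = [
    {"role": "system", "content": "Enhance image generation prompts with details about lighting, style, colors, mood, and composition."},
    {"role": "user", "content": "Enhance this image generation prompt: 'a cat on a windowsill'"},
]

result = llm_pipe(messages, max_new_tokens=200, do_sample=True, temperature=0.7, top_p=0.9)

# With chat-style input the pipeline returns the whole conversation under
# 'generated_text'; the assistant's reply is the last message. Some setups
# return a plain string instead, which is why app.py checks both shapes.
generated = result[0]["generated_text"]
enhanced = generated[-1]["content"] if isinstance(generated, list) else generated
print(enhanced)

Caching the loaded pipeline in CACHED_LLM_PIPE, as the commit does, limits this load cost to the first enhancement request per worker.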