fantos commited on
Commit
95db5c0
·
verified ·
1 Parent(s): 0d158d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -333
app.py CHANGED
@@ -2,50 +2,20 @@ import spaces
2
  import argparse
3
  import os
4
  import time
5
- import gc
6
  from os import path
7
- import shutil
8
- from datetime import datetime
9
- import traceback
10
  from safetensors.torch import load_file
11
  from huggingface_hub import hf_hub_download
12
- import gradio as gr
13
- import torch
14
- from diffusers import FluxPipeline
15
- from diffusers.pipelines.stable_diffusion import safety_checker
16
- from PIL import Image
17
 
18
- # Setup and initialization code
19
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
20
-
21
  os.environ["TRANSFORMERS_CACHE"] = cache_path
22
  os.environ["HF_HUB_CACHE"] = cache_path
23
  os.environ["HF_HOME"] = cache_path
24
 
25
- # GPU ๋ฉ”๋ชจ๋ฆฌ ์„ค์ • ์ตœ์ ํ™”
26
- torch.backends.cuda.matmul.allow_tf32 = True
27
- torch.backends.cudnn.benchmark = True # ๋ฐ˜๋ณต์ ์ธ ํฌ๊ธฐ์˜ ์ž…๋ ฅ์— ๋Œ€ํ•ด ์„ฑ๋Šฅ ํ–ฅ์ƒ
28
 
29
- def filter_prompt(prompt):
30
- # ๋ถ€์ ์ ˆํ•œ ํ‚ค์›Œ๋“œ ๋ชฉ๋ก
31
- inappropriate_keywords = [
32
- # ์Œ๋ž€/์„ฑ์  ํ‚ค์›Œ๋“œ
33
- "nude", "naked", "nsfw", "porn", "sex", "explicit", "adult", "xxx",
34
- "erotic", "sensual", "seductive", "provocative", "intimate",
35
- # ํญ๋ ฅ์  ํ‚ค์›Œ๋“œ
36
- "violence", "gore", "blood", "death", "kill", "murder", "torture",
37
- # ๊ธฐํƒ€ ๋ถ€์ ์ ˆํ•œ ํ‚ค์›Œ๋“œ
38
- "drug", "suicide", "abuse", "hate", "discrimination"
39
- ]
40
-
41
- prompt_lower = prompt.lower()
42
-
43
- # ๋ถ€์ ์ ˆํ•œ ํ‚ค์›Œ๋“œ ์ฒดํฌ
44
- for keyword in inappropriate_keywords:
45
- if keyword in prompt_lower:
46
- return False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค."
47
-
48
- return True, prompt
49
 
50
  class timer:
51
  def __init__(self, method_name="timed process"):
@@ -57,320 +27,172 @@ class timer:
57
  end = time.time()
58
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
59
 
60
- # ๊ธ€๋กœ๋ฒŒ ๋ณ€์ˆ˜๋กœ ํŒŒ์ดํ”„๋ผ์ธ ์„ ์–ธ
61
- pipe = None
62
 
63
- # ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜ (์ง€์—ฐ ๋กœ๋”ฉ)
64
- def initialize_model():
65
- global pipe
66
-
67
- # ์ด๋ฏธ ๋กœ๋“œ๋œ ๊ฒฝ์šฐ ๋‹ค์‹œ ๋กœ๋“œํ•˜์ง€ ์•Š์Œ
68
- if pipe is not None:
69
- return True
70
-
71
- try:
72
- if not path.exists(cache_path):
73
- os.makedirs(cache_path, exist_ok=True)
74
-
75
- # ๋ฉ”๋ชจ๋ฆฌ ํ™•๋ณด๋ฅผ ์œ„ํ•œ ๊ฐ€๋น„์ง€ ์ปฌ๋ ‰์…˜ ์‹คํ–‰
76
- gc.collect()
77
- torch.cuda.empty_cache()
78
-
79
- with timer("๋ชจ๋ธ ๋กœ๋”ฉ"):
80
- pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
81
- lora_path = hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors")
82
- pipe.load_lora_weights(lora_path)
83
- pipe.fuse_lora(lora_scale=0.125)
84
-
85
- # ์ฃผ์˜: ์—ฌ๊ธฐ์„œ device๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ์ง€์ • (๋ชจ๋“  ์ปดํฌ๋„ŒํŠธ์— ์ ์šฉ)
86
- pipe.to(device="cuda", dtype=torch.bfloat16)
87
-
88
- # ์•ˆ์ „ ๊ฒ€์‚ฌ๊ธฐ ์ถ”๊ฐ€ ๋ฐ ์˜ฌ๋ฐ”๋ฅธ ์žฅ์น˜๋กœ ์ด๋™
89
- pipe.safety_checker = safety_checker.StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
90
- if hasattr(pipe, 'safety_checker') and pipe.safety_checker is not None:
91
- pipe.safety_checker.to("cuda")
92
-
93
- print("๋ชจ๋ธ ๋กœ๋”ฉ ์™„๋ฃŒ")
94
- return True
95
- except Exception as e:
96
- print(f"๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}")
97
- traceback.print_exc()
98
- return False
99
 
 
100
  css = """
101
- footer {display: none !important}
102
- .gradio-container {
103
- max-width: 1200px;
104
- margin: auto;
105
- }
106
- .contain {
107
- background: rgba(255, 255, 255, 0.05);
108
- border-radius: 12px;
109
- padding: 20px;
110
- }
111
- .generate-btn {
112
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
113
- border: none !important;
114
- color: white !important;
115
- }
116
- .generate-btn:hover {
117
- transform: translateY(-2px);
118
- box-shadow: 0 5px 15px rgba(0,0,0,0.2);
119
- }
120
- .title {
121
- text-align: center;
122
- font-size: 2.5em;
123
- font-weight: bold;
124
- margin-bottom: 1em;
125
- background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%);
126
- -webkit-background-clip: text;
127
- -webkit-text-fill-color: transparent;
128
- }
129
- .output-image {
130
- width: 100% !important;
131
- max-width: 100% !important;
132
- }
133
- .contain > div {
134
- width: 100% !important;
135
- max-width: 100% !important;
136
- }
137
- .fixed-width {
138
- width: 100% !important;
139
- max-width: 100% !important;
140
- }
141
- .loading-indicator {
142
- text-align: center;
143
- padding: 20px;
144
- font-weight: bold;
145
- color: #4B79A1;
146
- }
147
- .error-message {
148
- background-color: rgba(255, 0, 0, 0.1);
149
- color: red;
150
- padding: 10px;
151
- border-radius: 8px;
152
- margin-top: 10px;
153
- text-align: center;
154
- }
 
 
155
  """
156
 
157
- # Create Gradio interface
158
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
159
- gr.HTML('<div class="title">AI Image Generator</div>')
160
- gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
161
-
162
- gr.HTML("""
163
- <div style="color: red; margin-bottom: 1em; text-align: center; padding: 10px; background: rgba(255,0,0,0.1); border-radius: 8px;">
164
- โš ๏ธ Explicit or inappropriate content cannot be generated.
165
- </div>
166
- """)
 
167
 
168
- # ์ƒํƒœ ํ‘œ์‹œ ๋ณ€์ˆ˜ (HTML ๋Œ€์‹  Textbox ์‚ฌ์šฉ)
169
- error_message = gr.Textbox(
170
- value="",
171
- label="Error",
172
- visible=False,
173
- elem_classes=["error-message"]
174
- )
175
-
176
- loading_status = gr.Textbox(
177
- value="",
178
- label="Status",
179
- visible=False,
180
- elem_classes=["loading-indicator"]
181
- )
182
-
183
- with gr.Row():
184
- with gr.Column(scale=3):
185
- prompt = gr.Textbox(
186
- label="Image Description",
187
- placeholder="Describe the image you want to create...",
188
- lines=3
189
- )
190
-
191
- with gr.Accordion("Advanced Settings", open=False):
192
- with gr.Row():
193
- height = gr.Slider(
194
- label="Height",
195
- minimum=256,
196
- maximum=1152,
197
- step=64,
198
- value=1024
199
- )
200
- width = gr.Slider(
201
- label="Width",
202
- minimum=256,
203
- maximum=1152,
204
- step=64,
205
- value=1024
206
  )
207
-
208
- with gr.Row():
209
- steps = gr.Slider(
210
- label="Inference Steps",
211
- minimum=6,
212
- maximum=25,
213
- step=1,
214
- value=8
215
- )
216
- scales = gr.Slider(
217
- label="Guidance Scale",
218
- minimum=0.0,
219
- maximum=5.0,
220
- step=0.1,
221
- value=3.5
222
- )
223
-
224
- def get_random_seed():
225
- return int(torch.randint(0, 1000000, (1,)).item())
226
-
227
- seed = gr.Number(
228
- label="Seed (random by default, set for reproducibility)",
229
- value=get_random_seed(),
230
- precision=0
231
- )
232
-
233
- randomize_seed = gr.Button("๐ŸŽฒ Randomize Seed", elem_classes=["generate-btn"])
234
-
235
- generate_btn = gr.Button(
236
- "โœจ Generate Image",
237
- elem_classes=["generate-btn"]
238
- )
239
-
240
- gr.HTML("""
241
- <div style="margin-top: 1em; padding: 1em; border-radius: 8px; background: rgba(255, 255, 255, 0.05);">
242
- <h4 style="margin: 0 0 0.5em 0;">Example Prompts:</h4>
243
- <div style="background: rgba(75, 121, 161, 0.1); padding: 1em; border-radius: 8px; margin-bottom: 1em;">
244
- <p style="font-weight: bold; margin: 0 0 0.5em 0;">๐ŸŒ… Cinematic Landscape</p>
245
- <p style="margin: 0; font-style: italic;">"A breathtaking mountain vista at golden hour, dramatic sunbeams piercing through clouds, snow-capped peaks reflecting warm light, ultra-high detail photography, artistically composed, award-winning landscape photo, shot on Hasselblad"</p>
246
- </div>
247
- <div style="background: rgba(75, 121, 161, 0.1); padding: 1em; border-radius: 8px; margin-bottom: 1em;">
248
- <p style="font-weight: bold; margin: 0 0 0.5em 0;">๐Ÿ–ผ๏ธ Fantasy Portrait</p>
249
- <p style="margin: 0; font-style: italic;">"Ethereal portrait of an elven queen with flowing silver hair, adorned with luminescent crystals, intricate crown of twisted gold and moonstone, soft ethereal lighting, detailed facial features, fantasy art style, highly detailed, painted by Artgerm and Charlie Bowater"</p>
250
- </div>
251
- <div style="background: rgba(75, 121, 161, 0.1); padding: 1em; border-radius: 8px; margin-bottom: 1em;">
252
- <p style="font-weight: bold; margin: 0 0 0.5em 0;">๐ŸŒƒ Cyberpunk Scene</p>
253
- <p style="margin: 0; font-style: italic;">"Neon-lit cyberpunk street market in rain, holographic advertisements reflecting in puddles, street vendors with glowing cyber-augmentations, dense urban environment, atmospheric fog, cinematic lighting, inspired by Blade Runner 2049"</p>
254
- </div>
255
- <div style="background: rgba(75, 121, 161, 0.1); padding: 1em; border-radius: 8px; margin-bottom: 1em;">
256
- <p style="font-weight: bold; margin: 0 0 0.5em 0;">๐ŸŽจ Abstract Art</p>
257
- <p style="margin: 0; font-style: italic;">"Vibrant abstract composition of flowing liquid colors, dynamic swirls of iridescent purples and teals, golden geometric patterns emerging from chaos, luxury art style, ultra-detailed, painted in oil on canvas, inspired by James Jean and Gustav Klimt"</p>
258
- </div>
259
- <div style="background: rgba(75, 121, 161, 0.1); padding: 1em; border-radius: 8px; margin-bottom: 1em;">
260
- <p style="font-weight: bold; margin: 0 0 0.5em 0;">๐ŸŒฟ Macro Nature</p>
261
- <p style="margin: 0; font-style: italic;">"Extreme macro photography of a dewdrop on a butterfly wing, rainbow light refraction, crystalline clarity, intricate wing scales visible, natural bokeh background, professional studio lighting, shot with Canon MP-E 65mm lens"</p>
262
- </div>
263
- </div>
264
- """)
265
 
266
- with gr.Column(scale=4, elem_classes=["fixed-width"]):
267
- output = gr.Image(
268
- label="Generated Image",
269
- elem_id="output-image",
270
- elem_classes=["output-image", "fixed-width"]
271
- )
272
-
273
  @spaces.GPU
274
  def process_image(height, width, steps, scales, prompt, seed):
275
  global pipe
276
-
277
- # ์ž…๋ ฅ๊ฐ’ ๊ฒ€์ฆ
278
- if not prompt or prompt.strip() == "":
279
- return None, "", False, "์ด๋ฏธ์ง€ ์„ค๋ช…์„ ์ž…๋ ฅํ•ด์ฃผ์„ธ์š”.", True
280
-
281
- # ํ”„๋กฌํ”„ํŠธ ํ•„ํ„ฐ๋ง
282
- is_safe, filtered_prompt = filter_prompt(prompt)
283
- if not is_safe:
284
- return None, "", False, "๋ถ€์ ์ ˆํ•œ ๋‚ด์šฉ์ด ํฌํ•จ๋œ ํ”„๋กฌํ”„ํŠธ์ž…๋‹ˆ๋‹ค.", True
285
-
286
- try:
287
- # ๋ฉ”๋ชจ๋ฆฌ ํ™•๋ณด๋ฅผ ์œ„ํ•œ ๊ฐ€๋น„์ง€ ์ฝœ๋ ‰์…˜
288
- gc.collect()
289
- torch.cuda.empty_cache()
290
-
291
- # ์‹œ๋“œ ๊ฐ’ ํ™•์ธ ๋ฐ ๋ณด์ •
292
- if seed is None or not isinstance(seed, (int, float)):
293
- seed = get_random_seed()
294
- else:
295
- seed = int(seed) # ํƒ€์ž… ๋ณ€ํ™˜ ์•ˆ์ „ํ•˜๊ฒŒ ์ฒ˜๋ฆฌ
296
-
297
- # ๋†’์ด์™€ ๋„ˆ๋น„๋ฅผ 64์˜ ๋ฐฐ์ˆ˜๋กœ ์กฐ์ • (FLUX ๋ชจ๋ธ ์š”๊ตฌ์‚ฌํ•ญ)
298
- height = (int(height) // 64) * 64
299
- width = (int(width) // 64) * 64
300
-
301
- # ์•ˆ์ „์žฅ์น˜ - ์ตœ๋Œ€๊ฐ’ ์ œํ•œ
302
- steps = min(int(steps), 25)
303
- scales = max(min(float(scales), 5.0), 0.0)
304
-
305
- # ์ด๋ฏธ์ง€ ์ƒ์„ฑ
306
- with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
307
- # ์ค‘์š”: generator ์„ค์ • ์‹œ device๋ฅผ ๋ช…์‹œ์ ์œผ๋กœ ์ง€์ •
308
- generator = torch.Generator(device="cuda").manual_seed(seed)
309
-
310
- # ๋ชจ๋“  ํ…์„œ๊ฐ€ ๊ฐ™์€ ๋””๋ฐ”์ด์Šค์— ์žˆ๋Š”์ง€ ํ™•์ธ
311
- for name, module in pipe.components.items():
312
- if hasattr(module, 'device') and module.device.type != "cuda":
313
- module.to("cuda")
314
-
315
- # ์ด๋ฏธ์ง€ ์ƒ์„ฑ - ๋ชจ๋“  ๋งค๊ฐœ๋ณ€์ˆ˜์— device๋ฅผ ๋ช…์‹œ์  ์ง€์ •
316
- generated_image = pipe(
317
- prompt=[filtered_prompt],
318
- generator=generator,
319
- num_inference_steps=steps,
320
- guidance_scale=scales,
321
- height=height,
322
- width=width,
323
- max_sequence_length=256,
324
- device="cuda" # ๋ช…์‹œ์  device ์ง€์ •
325
- ).images[0]
326
-
327
- return generated_image, "", False, "", False
328
-
329
- except Exception as e:
330
- error_msg = f"์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {str(e)}"
331
- print(error_msg)
332
- traceback.print_exc()
333
-
334
- # ์˜ค๋ฅ˜ ํ›„ ๋ฉ”๋ชจ๋ฆฌ ์ •๋ฆฌ
335
- gc.collect()
336
- torch.cuda.empty_cache()
337
-
338
- return None, "", False, error_msg, True
339
-
340
- def update_seed():
341
- return get_random_seed()
342
-
343
- # ์ด๋ฏธ์ง€ ์ƒ์„ฑ ์ค€๋น„ ํ•จ์ˆ˜
344
- def prepare_generation(height, width, steps, scales, prompt, seed):
345
- global pipe
346
-
347
- # ๋ชจ๋ธ์ด ์•„์ง ๋กœ๋“œ๋˜์ง€ ์•Š์•˜๋‹ค๋ฉด ๋กœ๋“œ
348
- if pipe is None:
349
- # ๋กœ๋”ฉ ์ƒํƒœ ํ‘œ์‹œ
350
- loading_message = "๋ชจ๋ธ์„ ๋กœ๋”ฉ ์ค‘์ž…๋‹ˆ๋‹ค... ์ฒ˜์Œ ์‹คํ–‰ ์‹œ ์‹œ๊ฐ„์ด ์†Œ์š”๋  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค."
351
-
352
- is_loaded = initialize_model()
353
- if not is_loaded:
354
- return None, "", False, "๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค. ํŽ˜์ด์ง€๋ฅผ ์ƒˆ๋กœ๊ณ ์นจํ•˜๊ณ  ๋‹ค์‹œ ์‹œ๋„ํ•ด ์ฃผ์„ธ์š”.", True
355
-
356
- # ์ƒ์„ฑ ์ƒํƒœ ํ‘œ์‹œ
357
- loading_message = "์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑ ์ค‘์ž…๋‹ˆ๋‹ค..."
358
-
359
- # ์ƒ์„ฑ ํ”„๋กœ์„ธ์Šค ์‹œ์ž‘
360
- return process_image(height, width, steps, scales, prompt, seed)
361
 
362
- # ๋ฒ„ํŠผ ํด๋ฆญ ์ด๋ฒคํŠธ ์—ฐ๊ฒฐ
363
  generate_btn.click(
364
- fn=prepare_generation,
365
  inputs=[height, width, steps, scales, prompt, seed],
366
- outputs=[output, loading_status, loading_status, error_message, error_message]
367
- )
368
-
369
- randomize_seed.click(
370
- fn=update_seed,
371
- outputs=[seed]
372
  )
373
 
374
  if __name__ == "__main__":
375
-
376
- demo.queue(max_size=10).launch()
 
2
  import argparse
3
  import os
4
  import time
 
5
  from os import path
 
 
 
6
  from safetensors.torch import load_file
7
  from huggingface_hub import hf_hub_download
 
 
 
 
 
8
 
 
9
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
 
10
  os.environ["TRANSFORMERS_CACHE"] = cache_path
11
  os.environ["HF_HUB_CACHE"] = cache_path
12
  os.environ["HF_HOME"] = cache_path
13
 
14
+ import gradio as gr
15
+ import torch
16
+ from diffusers import FluxPipeline
17
 
18
+ torch.backends.cuda.matmul.allow_tf32 = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  class timer:
21
  def __init__(self, method_name="timed process"):
 
27
  end = time.time()
28
  print(f"{self.method} took {str(round(end - self.start, 2))}s")
29
 
30
# --- One-time model setup (runs at module import time) ---
# os.makedirs(..., exist_ok=True) already tolerates an existing directory,
# so the original `if not path.exists(cache_path)` pre-check was a redundant
# LBYL guard and has been dropped (behavior identical).
os.makedirs(cache_path, exist_ok=True)

# Load the FLUX.1-dev base pipeline in bfloat16, attach ByteDance's Hyper-SD
# 8-step LoRA (enables few-step inference), fuse it at reduced strength,
# then move the whole pipeline to the GPU.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights(hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"))
pipe.fuse_lora(lora_scale=0.125)
pipe.to(device="cuda", dtype=torch.bfloat16)
37
+
38
# Preset example prompts surfaced as one-click buttons in the UI
# (button labels are truncated to the first 40 characters elsewhere).
example_prompts = [
    "A cyberpunk cityscape at night with neon lights reflecting in puddles, towering skyscrapers and flying cars",
    "An ethereal fairy with translucent iridescent wings standing in an enchanted forest with glowing mushrooms and floating light particles",
    "A majestic dragon soaring through stormy clouds above jagged mountain peaks as lightning strikes in the background",
    "A futuristic space station orbiting a vibrant nebula with multiple colorful ringed planets visible through a massive observation window",
    "An underwater scene of an ancient lost city with ornate temples and statues covered in bioluminescent coral and swimming sea creatures"
]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
# Custom CSS for the neon-themed UI; passed to gr.Blocks(css=...) below.
# Class names (.neon-*) are attached to components via elem_classes.
css = """
.neon-container {
background: linear-gradient(to right, #000428, #004e92);
border-radius: 16px;
box-shadow: 0 0 15px #00f3ff, 0 0 25px #00f3ff;
}
.neon-title {
text-shadow: 0 0 5px #fff, 0 0 10px #fff, 0 0 15px #0073e6, 0 0 20px #0073e6, 0 0 25px #0073e6;
color: #ffffff;
font-weight: bold !important;
}
.neon-text {
color: #00ff9d;
text-shadow: 0 0 5px #00ff9d;
}
.neon-button {
box-shadow: 0 0 5px #ff00dd, 0 0 10px #ff00dd !important;
background: linear-gradient(90deg, #ff00dd, #8b00ff) !important;
border: none !important;
color: white !important;
font-weight: bold !important;
}
.neon-button:hover {
box-shadow: 0 0 10px #ff00dd, 0 0 20px #ff00dd !important;
}
.neon-input {
border: 1px solid #00f3ff !important;
box-shadow: 0 0 5px #00f3ff !important;
}
.neon-slider > div {
background: linear-gradient(90deg, #00ff9d, #00f3ff) !important;
}
.neon-slider > div > div {
background: #ff00dd !important;
box-shadow: 0 0 5px #ff00dd !important;
}
.neon-card {
background-color: rgba(0, 0, 0, 0.7) !important;
border: 1px solid #00f3ff !important;
box-shadow: 0 0 10px #00f3ff !important;
padding: 16px !important;
border-radius: 8px !important;
}
.neon-example {
background: rgba(0, 0, 0, 0.5) !important;
border: 1px solid #00ff9d !important;
border-radius: 8px !important;
padding: 8px !important;
color: #00ff9d !important;
box-shadow: 0 0 5px #00ff9d !important;
margin: 4px !important;
cursor: pointer !important;
}
.neon-example:hover {
box-shadow: 0 0 10px #00ff9d, 0 0 15px #00ff9d !important;
background: rgba(0, 255, 157, 0.2) !important;
}
"""
106
 
 
107
# --- Gradio UI definition ---
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
    # BUG FIX: the original opened a second `gr.Blocks(...)` inside this one.
    # Nesting Blocks is not supported by Gradio and renders incorrectly;
    # a gr.Column layout container carries the elem_classes instead.
    with gr.Column(elem_classes=["neon-container"]):
        gr.Markdown(
            """
            <div style="text-align: center; max-width: 650px; margin: 0 auto;">
            <h1 style="font-size: 3rem; font-weight: 700; margin-bottom: 1rem; display: contents;" class="neon-title">FLUX: Fast & Furious</h1>
            <p style="font-size: 1.2rem; margin-bottom: 1.5rem;" class="neon-text">AutoML team from ByteDance</p>
            </div>
            """
        )

    with gr.Row():
        with gr.Column(scale=3, elem_classes=["neon-card"]):
            with gr.Group():
                # Main prompt input.
                prompt = gr.Textbox(
                    label="Your Image Description",
                    placeholder="E.g., A serene landscape with mountains and a lake at sunset",
                    lines=3,
                    elem_classes=["neon-input"]
                )

            # Examples section: one truncated-label button per example prompt.
            gr.Markdown('<p class="neon-text">Click on any example to use it:</p>')
            with gr.Row():
                example_boxes = [gr.Button(ex[:40] + "...", elem_classes=["neon-example"]) for ex in example_prompts]

            # Clicking an example copies its full text into the prompt box.
            # The default argument binds each prompt at lambda-creation time,
            # sidestepping Python's late-binding-closure pitfall.
            for i, example_btn in enumerate(example_boxes):
                example_btn.click(
                    fn=lambda x=example_prompts[i]: x,
                    outputs=prompt
                )

            with gr.Accordion("Advanced Settings", open=False):
                with gr.Group():
                    with gr.Row():
                        height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024,
                                           elem_classes=["neon-slider"])
                        width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024,
                                          elem_classes=["neon-slider"])

                    with gr.Row():
                        steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8,
                                          elem_classes=["neon-slider"])
                        scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5,
                                           elem_classes=["neon-slider"])

                    # Fixed default seed so runs are reproducible out of the box.
                    seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0,
                                     elem_classes=["neon-input"])

            generate_btn = gr.Button("Generate Image", variant="primary", scale=1, elem_classes=["neon-button"])

        with gr.Column(scale=4, elem_classes=["neon-card"]):
            # Result display.
            output = gr.Image(label="Your Generated Image")

    gr.Markdown(
        """
        <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px;" class="neon-card">
        <h2 style="font-size: 1.5rem; margin-bottom: 1rem;" class="neon-text">How to Use</h2>
        <ol style="padding-left: 1.5rem; color: #00f3ff;">
        <li>Enter a detailed description of the image you want to create.</li>
        <li>Or click one of our exciting example prompts above!</li>
        <li>Adjust advanced settings if desired (tap to expand).</li>
        <li>Tap "Generate Image" and wait for your creation!</li>
        </ol>
        <p style="margin-top: 1rem; font-style: italic; color: #ff00dd;">Tip: Be specific in your description for best results!</p>
        </div>
        """
    )
 
 
 
 
 
 
 
 
 
176
 
 
 
 
 
 
 
 
177
    @spaces.GPU
    def process_image(height, width, steps, scales, prompt, seed):
        """Run one text-to-image generation on the module-level FLUX pipeline.

        Args:
            height, width: output resolution in pixels (coerced to int).
            steps: number of denoising steps (coerced to int).
            scales: classifier-free guidance scale (coerced to float).
            prompt: text description of the desired image.
            seed: RNG seed for reproducible sampling (coerced to int).

        Returns:
            The first generated image from the pipeline's `.images` list
            (presumably a PIL image, diffusers' default output — confirm
            `output_type` if this ever changes).
        """
        # `pipe` is created at import time and never rebound here, so this
        # declaration is informational only.
        global pipe
        # inference_mode + bf16 autocast minimize memory; `timer` logs duration.
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            return pipe(
                prompt=[prompt],
                # NOTE(review): Generator is created on the default (CPU)
                # device while the pipeline runs on CUDA — diffusers accepts
                # this, but confirm it gives the intended reproducibility.
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
 
191
  generate_btn.click(
192
+ process_image,
193
  inputs=[height, width, steps, scales, prompt, seed],
194
+ outputs=output
 
 
 
 
 
195
  )
196
 
197
if __name__ == "__main__":
    # Launch the Gradio app (blocking) when run as a script.
    demo.launch()