mjoshs commited on
Commit
6ef2343
Β·
verified Β·
1 Parent(s): a02d4f1

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +37 -57
app.py CHANGED
@@ -1,67 +1,47 @@
1
  import gradio as gr
2
- import torch
3
  import os
4
- from diffusers import AmusedPipeline
5
- from PIL import Image
6
  from huggingface_hub import login
 
 
 
 
7
 
8
  # Login to Hugging Face
9
  if os.getenv("HF_TOKEN"):
10
  login(token=os.getenv("HF_TOKEN"))
11
 
12
- # Global variables
13
- pipe = None
14
-
15
- def load_model():
16
- """Load the model (cached after first load)"""
17
- global pipe
18
-
19
- if pipe is not None:
20
- return pipe
21
-
22
  try:
23
- print("Loading aMUSEd (lightweight model)...")
24
- # Load aMUSEd - much smaller model designed for CPU
25
- pipe = AmusedPipeline.from_pretrained(
26
- "amused/amused-512",
27
- torch_dtype=torch.float32, # Use float32 for CPU
28
- use_safetensors=True
29
- )
30
-
31
- # Move to device
32
- device = "cuda" if torch.cuda.is_available() else "cpu"
33
- pipe = pipe.to(device)
34
- print(f"Model loaded on {device}")
35
-
36
- return pipe
37
 
38
- except Exception as e:
39
- print(f"Error loading model: {str(e)}")
40
- raise e
41
-
42
- def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed):
43
- """Generate image using aMUSEd"""
44
- try:
45
- # Load model if not already loaded
46
- pipe = load_model()
47
 
48
- print(f"Generating image with prompt: {prompt}")
49
- # Generate image
50
- with torch.no_grad():
51
- generator = torch.Generator().manual_seed(seed) if seed else None
52
- image = pipe(
53
- prompt=prompt,
54
- negative_prompt=negative_prompt,
55
- num_inference_steps=num_inference_steps,
56
- guidance_scale=guidance_scale,
57
- width=width,
58
- height=height,
59
- generator=generator
60
- ).images[0]
61
 
62
- print("Image generated successfully!")
63
- return image
64
 
 
 
 
 
 
 
 
 
 
65
  except Exception as e:
66
  print(f"Error generating image: {str(e)}")
67
  return None
@@ -69,8 +49,8 @@ def generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale,
69
  # Create Gradio interface
70
  with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
71
  gr.Markdown("# πŸ”¬ Silicon Photonics Image Generator")
72
- gr.Markdown("Generate technical diagrams and visualizations for silicon photonics using aMUSEd (lightweight model).")
73
- gr.Markdown("**πŸš€ Now using aMUSEd for fast CPU inference!**")
74
 
75
  with gr.Row():
76
  with gr.Column():
@@ -86,12 +66,12 @@ with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
86
  )
87
 
88
  with gr.Row():
89
- num_inference_steps = gr.Slider(1, 20, value=10, label="Inference Steps")
90
  guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")
91
 
92
  with gr.Row():
93
- width = gr.Slider(256, 512, value=512, step=64, label="Width")
94
- height = gr.Slider(256, 512, value=512, step=64, label="Height")
95
  seed = gr.Number(value=42, label="Seed (0 for random)")
96
 
97
  generate_btn = gr.Button("🎨 Generate Image", variant="primary")
@@ -116,7 +96,7 @@ with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
116
 
117
  # Event handlers
118
  generate_btn.click(
119
- fn=generate_image,
120
  inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed],
121
  outputs=[output_image]
122
  )
 
1
  import gradio as gr
 
2
  import os
 
 
3
  from huggingface_hub import login
4
+ import requests
5
+ import base64
6
+ from PIL import Image
7
+ import io
8
 
9
  # Login to Hugging Face
10
  if os.getenv("HF_TOKEN"):
11
  login(token=os.getenv("HF_TOKEN"))
12
 
13
+ def generate_image_mlx(prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed):
14
+ """Generate image using MLX via Hugging Face Inference API"""
 
 
 
 
 
 
 
 
15
  try:
16
+ print(f"Generating image with prompt: {prompt}")
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ # Use Hugging Face Inference API with MLX
19
+ api_url = "https://api-inference.huggingface.co/models/stabilityai/stable-diffusion-xl-base-1.0"
20
+ headers = {"Authorization": f"Bearer {os.getenv('HF_TOKEN')}"}
 
 
 
 
 
 
21
 
22
+ payload = {
23
+ "inputs": prompt,
24
+ "parameters": {
25
+ "negative_prompt": negative_prompt,
26
+ "num_inference_steps": num_inference_steps,
27
+ "guidance_scale": guidance_scale,
28
+ "width": width,
29
+ "height": height,
30
+ "seed": seed if seed else None
31
+ }
32
+ }
 
 
33
 
34
+ response = requests.post(api_url, headers=headers, json=payload)
 
35
 
36
+ if response.status_code == 200:
37
+ # Convert response to PIL Image
38
+ image = Image.open(io.BytesIO(response.content))
39
+ print("Image generated successfully with MLX!")
40
+ return image
41
+ else:
42
+ print(f"API request failed: {response.status_code} - {response.text}")
43
+ return None
44
+
45
  except Exception as e:
46
  print(f"Error generating image: {str(e)}")
47
  return None
 
49
  # Create Gradio interface
50
  with gr.Blocks(title="Silicon Photonics Image Generator") as demo:
51
  gr.Markdown("# πŸ”¬ Silicon Photonics Image Generator")
52
+ gr.Markdown("Generate technical diagrams and visualizations for silicon photonics using MLX-powered inference.")
53
+ gr.Markdown("**πŸš€ Powered by MLX for ultra-fast Apple Silicon performance!**")
54
 
55
  with gr.Row():
56
  with gr.Column():
 
66
  )
67
 
68
  with gr.Row():
69
+ num_inference_steps = gr.Slider(1, 50, value=20, label="Inference Steps")
70
  guidance_scale = gr.Slider(1.0, 20.0, value=7.0, label="Guidance Scale")
71
 
72
  with gr.Row():
73
+ width = gr.Slider(512, 1024, value=1024, step=64, label="Width")
74
+ height = gr.Slider(512, 1024, value=1024, step=64, label="Height")
75
  seed = gr.Number(value=42, label="Seed (0 for random)")
76
 
77
  generate_btn = gr.Button("🎨 Generate Image", variant="primary")
 
96
 
97
  # Event handlers
98
  generate_btn.click(
99
+ fn=generate_image_mlx,
100
  inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, width, height, seed],
101
  outputs=[output_image]
102
  )