dkatz2391 committed on
Commit
1fc85c4
·
verified ·
1 Parent(s): 287cce8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -208
app.py CHANGED
@@ -1,20 +1,18 @@
1
- import gradio as gr
2
- import spaces
3
- from gradio_litmodel3d import LitModel3D
4
  import os
5
  import shutil
6
  import random
7
  import uuid
8
  from datetime import datetime
9
- from diffusers import DiffusionPipeline
10
-
11
- os.environ['SPCONV_ALGO'] = 'native'
12
- from typing import *
13
  import torch
14
  import numpy as np
15
  import imageio
16
  from easydict import EasyDict as edict
17
  from PIL import Image
 
 
 
 
18
  from trellis.pipelines import TrellisImageTo3DPipeline
19
  from trellis.representations import Gaussian, MeshExtractResult
20
  from trellis.utils import render_utils, postprocessing_utils
@@ -22,24 +20,14 @@ from trellis.utils import render_utils, postprocessing_utils
22
  NUM_INFERENCE_STEPS = 8
23
 
24
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
25
- # Constants
26
  MAX_SEED = np.iinfo(np.int32).max
27
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
28
  os.makedirs(TMP_DIR, exist_ok=True)
29
 
30
- # Create permanent storage directory for Flux generated images
31
  SAVE_DIR = "saved_images"
32
  if not os.path.exists(SAVE_DIR):
33
  os.makedirs(SAVE_DIR, exist_ok=True)
34
 
35
- def start_session(req: gr.Request):
36
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
37
- os.makedirs(user_dir, exist_ok=True)
38
-
39
- def end_session(req: gr.Request):
40
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
41
- shutil.rmtree(user_dir)
42
-
43
  def preprocess_image(image: Image.Image) -> Image.Image:
44
  processed_image = trellis_pipeline.preprocess_image(image)
45
  return processed_image
@@ -85,224 +73,163 @@ def unpack_state(state: dict) -> Tuple[Gaussian, edict]:
85
  def get_seed(randomize_seed: bool, seed: int) -> int:
86
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
87
 
88
- @spaces.GPU
89
- def generate_flux_image(
90
- prompt: str,
91
- seed: int,
92
- randomize_seed: bool,
93
- width: int,
94
- height: int,
95
- guidance_scale: float,
96
- progress: gr.Progress = gr.Progress(track_tqdm=True),
97
- ) -> Image.Image:
98
- """Generate image using Flux pipeline"""
99
- if randomize_seed:
100
- seed = random.randint(0, MAX_SEED)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  generator = torch.Generator(device=device).manual_seed(seed)
102
- prompt = "wbgmsst, " + prompt + ", 3D isometric, white background"
103
  image = flux_pipeline(
104
  prompt=prompt,
105
- guidance_scale=guidance_scale,
106
  num_inference_steps=NUM_INFERENCE_STEPS,
107
- width=width,
108
- height=height,
109
  generator=generator,
110
  ).images[0]
111
-
112
- # Save the generated image
113
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
114
  unique_id = str(uuid.uuid4())[:8]
115
  filename = f"{timestamp}_{unique_id}.png"
116
  filepath = os.path.join(SAVE_DIR, filename)
117
  image.save(filepath)
118
-
119
- return image
120
-
121
- @spaces.GPU
122
- def image_to_3d(
123
- image: Image.Image,
124
- seed: int,
125
- ss_guidance_strength: float,
126
- ss_sampling_steps: int,
127
- slat_guidance_strength: float,
128
- slat_sampling_steps: int,
129
- req: gr.Request,
130
- ) -> Tuple[dict, str]:
131
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
132
  outputs = trellis_pipeline.run(
133
  image,
134
- seed=seed,
135
  formats=["gaussian", "mesh"],
136
  preprocess_image=False,
137
  sparse_structure_sampler_params={
138
- "steps": ss_sampling_steps,
139
- "cfg_strength": ss_guidance_strength,
140
  },
141
  slat_sampler_params={
142
- "steps": slat_sampling_steps,
143
- "cfg_strength": slat_guidance_strength,
144
  },
145
  )
146
  video = render_utils.render_video(outputs['gaussian'][0], num_frames=120)['color']
147
  video_geo = render_utils.render_video(outputs['mesh'][0], num_frames=120)['normal']
148
  video = [np.concatenate([video[i], video_geo[i]], axis=1) for i in range(len(video))]
 
149
  video_path = os.path.join(user_dir, 'sample.mp4')
150
  imageio.mimsave(video_path, video, fps=15)
151
  state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
152
  torch.cuda.empty_cache()
153
- return state, video_path
154
-
155
- @spaces.GPU(duration=90)
156
- def extract_glb(
157
- state: dict,
158
- mesh_simplify: float,
159
- texture_size: int,
160
- req: gr.Request,
161
- ) -> Tuple[str, str]:
162
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
163
- gs, mesh = unpack_state(state)
164
- glb = postprocessing_utils.to_glb(gs, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
165
- glb_path = os.path.join(user_dir, 'sample.glb')
166
- glb.export(glb_path)
167
- torch.cuda.empty_cache()
168
- return glb_path, glb_path
169
-
170
- @spaces.GPU
171
- def extract_gaussian(state: dict, req: gr.Request) -> Tuple[str, str]:
172
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
173
- gs, _ = unpack_state(state)
174
- gaussian_path = os.path.join(user_dir, 'sample.ply')
175
- gs.save_ply(gaussian_path)
176
- torch.cuda.empty_cache()
177
- return gaussian_path, gaussian_path
178
-
179
- # Gradio Interface
180
- with gr.Blocks() as demo:
181
- gr.Markdown("""
182
- ## Game Asset Generation to 3D with FLUX and TRELLIS
183
- * Enter a prompt to generate a game asset image, then convert it to 3D
184
- * If you find the generated 3D asset satisfactory, click "Extract GLB" to extract the GLB file and download it.
185
- * [TRELLIS Model](https://huggingface.co/JeffreyXiang/TRELLIS-image-large) [Trellis Github](https://github.com/microsoft/TRELLIS) [Flux-Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
186
- * [Flux Game Assets LoRA](https://huggingface.co/gokaygokay/Flux-Game-Assets-LoRA-v2) [Hyper FLUX 8Steps LoRA](https://huggingface.co/ByteDance/Hyper-SD) [safetensors to GGUF for Flux](https://github.com/ruSauron/to-gguf-bat) [Thanks to John6666](https://huggingface.co/John6666)
187
- """)
188
-
189
- with gr.Row():
190
- with gr.Column():
191
- # Flux image generation inputs
192
- prompt = gr.Text(label="Prompt", placeholder="Enter your game asset description")
193
-
194
- with gr.Accordion("Generation Settings", open=False):
195
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=42, step=1)
196
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
197
- with gr.Row():
198
- width = gr.Slider(512, 1024, label="Width", value=1024, step=16)
199
- height = gr.Slider(512, 1024, label="Height", value=1024, step=16)
200
- with gr.Row():
201
- guidance_scale = gr.Slider(0.0, 10.0, label="Guidance Scale", value=3.5, step=0.1)
202
- # num_inference_steps = gr.Slider(1, 50, label="Steps", value=8, step=1)
203
-
204
- with gr.Accordion("3D Generation Settings", open=False):
205
- gr.Markdown("Stage 1: Sparse Structure Generation")
206
- with gr.Row():
207
- ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
208
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
209
- gr.Markdown("Stage 2: Structured Latent Generation")
210
- with gr.Row():
211
- slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=3.0, step=0.1)
212
- slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
213
-
214
- generate_btn = gr.Button("Generate")
215
-
216
- with gr.Accordion("GLB Extraction Settings", open=False):
217
- mesh_simplify = gr.Slider(0.9, 0.98, label="Simplify", value=0.95, step=0.01)
218
- texture_size = gr.Slider(512, 2048, label="Texture Size", value=1024, step=512)
219
-
220
- with gr.Row():
221
- extract_glb_btn = gr.Button("Extract GLB", interactive=False)
222
- extract_gs_btn = gr.Button("Extract Gaussian", interactive=False)
223
-
224
- with gr.Column():
225
- generated_image = gr.Image(label="Generated Asset", type="pil")
226
-
227
- with gr.Column():
228
-
229
- video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True)
230
- model_output = LitModel3D(label="Extracted GLB/Gaussian", exposure=8.0, height=400)
231
-
232
- with gr.Row():
233
- download_glb = gr.DownloadButton(label="Download GLB", interactive=False)
234
- download_gs = gr.DownloadButton(label="Download Gaussian", interactive=False)
235
-
236
- output_buf = gr.State()
237
 
238
- # Event handlers
239
- demo.load(start_session)
240
- demo.unload(end_session)
 
241
 
242
- generate_btn.click(
243
- generate_flux_image,
244
- inputs=[prompt, seed, randomize_seed, width, height, guidance_scale],
245
- outputs=[generated_image],
246
- ).then(
247
- image_to_3d,
248
- inputs=[generated_image, seed, ss_guidance_strength, ss_sampling_steps, slat_guidance_strength, slat_sampling_steps],
249
- outputs=[output_buf, video_output],
250
- ).then(
251
- lambda: tuple([gr.Button(interactive=True), gr.Button(interactive=True)]),
252
- outputs=[extract_glb_btn, extract_gs_btn],
253
- )
254
 
255
- extract_glb_btn.click(
256
- extract_glb,
257
- inputs=[output_buf, mesh_simplify, texture_size],
258
- outputs=[model_output, download_glb],
259
- ).then(
260
- lambda: gr.Button(interactive=True),
261
- outputs=[download_glb],
262
- )
263
-
264
- extract_gs_btn.click(
265
- extract_gaussian,
266
- inputs=[output_buf],
267
- outputs=[model_output, download_gs],
268
- ).then(
269
- lambda: gr.Button(interactive=True),
270
- outputs=[download_gs],
271
- )
272
-
273
- model_output.clear(
274
- lambda: gr.Button(interactive=False),
275
- outputs=[download_glb],
276
- )
277
-
278
- # Initialize both pipelines
279
- if __name__ == "__main__":
280
- from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig, GGUFQuantizationConfig
281
- from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
282
-
283
- # Initialize Flux pipeline
284
- device = "cuda" if torch.cuda.is_available() else "cpu"
285
- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
286
-
287
- dtype = torch.bfloat16
288
- file_url = "https://huggingface.co/gokaygokay/flux-game/blob/main/hyperflux_00001_.q8_0.gguf"
289
- file_url = file_url.replace("/resolve/main/", "/blob/main/").replace("?download=true", "")
290
- single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
291
- quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
292
- text_encoder_2 = T5EncoderModel.from_pretrained(single_file_base_model, subfolder="text_encoder_2", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config_tf, token=huggingface_token)
293
- if ".gguf" in file_url:
294
- transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", quantization_config=GGUFQuantizationConfig(compute_dtype=dtype), torch_dtype=dtype, config=single_file_base_model)
295
- else:
296
- quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16, token=huggingface_token)
297
- transformer = FluxTransformer2DModel.from_single_file(file_url, subfolder="transformer", torch_dtype=dtype, config=single_file_base_model, quantization_config=quantization_config, token=huggingface_token)
298
- flux_pipeline = FluxPipeline.from_pretrained(single_file_base_model, transformer=transformer, text_encoder_2=text_encoder_2, torch_dtype=dtype, token=huggingface_token)
299
- flux_pipeline.to("cuda")
300
- # Initialize Trellis pipeline
301
- trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
302
- trellis_pipeline.cuda()
303
- try:
304
- trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
305
- except:
306
- pass
307
-
308
- demo.launch(share=True)
 
 
 
 
1
  import os
2
  import shutil
3
  import random
4
  import uuid
5
  from datetime import datetime
6
+ from typing import Tuple
 
 
 
7
  import torch
8
  import numpy as np
9
  import imageio
10
  from easydict import EasyDict as edict
11
  from PIL import Image
12
+ from fastapi import FastAPI
13
+ from fastapi.responses import FileResponse
14
+ from pydantic import BaseModel
15
+ from diffusers import DiffusionPipeline
16
  from trellis.pipelines import TrellisImageTo3DPipeline
17
  from trellis.representations import Gaussian, MeshExtractResult
18
  from trellis.utils import render_utils, postprocessing_utils
 
20
  NUM_INFERENCE_STEPS = 8
21
 
22
  huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
25
  os.makedirs(TMP_DIR, exist_ok=True)
26
 
 
27
  SAVE_DIR = "saved_images"
28
  if not os.path.exists(SAVE_DIR):
29
  os.makedirs(SAVE_DIR, exist_ok=True)
30
 
 
 
 
 
 
 
 
 
31
  def preprocess_image(image: Image.Image) -> Image.Image:
32
  processed_image = trellis_pipeline.preprocess_image(image)
33
  return processed_image
 
73
  def get_seed(randomize_seed: bool, seed: int) -> int:
74
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
75
 
76
# Initialize both pipelines at startup
import logging

from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig, GGUFQuantizationConfig
from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF

# Use the GPU with bfloat16 when one is visible; otherwise run on CPU in float32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
file_url = "https://huggingface.co/gokaygokay/flux-game/blob/main/hyperflux_00001_.q8_0.gguf"
# No-op for the URL above; only matters if file_url is swapped for a
# "/resolve/main/...?download=true" style link.
file_url = file_url.replace("/resolve/main/", "/blob/main/").replace("?download=true", "")
single_file_base_model = "dkatz2391/Flux1Dev"

if device == "cuda":
    # GPU path: 8-bit text encoder, quantized transformer.
    quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        single_file_base_model,
        subfolder="text_encoder_2",
        torch_dtype=dtype,
        config=single_file_base_model,
        quantization_config=quantization_config_tf
    )
    if ".gguf" in file_url:
        transformer = FluxTransformer2DModel.from_single_file(
            file_url,
            subfolder="transformer",
            quantization_config=GGUFQuantizationConfig(compute_dtype=dtype),
            torch_dtype=dtype,
            config=single_file_base_model
        )
    else:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16
        )
        transformer = FluxTransformer2DModel.from_single_file(
            file_url,
            subfolder="transformer",
            torch_dtype=dtype,
            config=single_file_base_model,
            quantization_config=quantization_config
        )
else:
    # CPU fallback: no quantization
    text_encoder_2 = T5EncoderModel.from_pretrained(
        single_file_base_model,
        subfolder="text_encoder_2",
        torch_dtype=dtype,
        config=single_file_base_model
    )
    transformer = FluxTransformer2DModel.from_single_file(
        file_url,
        subfolder="transformer",
        torch_dtype=dtype,
        config=single_file_base_model
    )

flux_pipeline = FluxPipeline.from_pretrained(
    single_file_base_model,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype
)
flux_pipeline.to(device)

trellis_pipeline = TrellisImageTo3DPipeline.from_pretrained("JeffreyXiang/TRELLIS-image-large")
# Only move TRELLIS to the GPU when one exists; an unconditional .cuda()
# would abort startup on the CPU fallback path selected above.
if device == "cuda":
    trellis_pipeline.cuda()
try:
    # Best-effort warm-up call on a blank 512x512 image.
    trellis_pipeline.preprocess_image(Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8)))
except Exception:
    # Warm-up is optional: keep starting, but leave a trace instead of
    # swallowing the failure silently (and let KeyboardInterrupt through).
    logging.getLogger(__name__).warning("TRELLIS warm-up preprocessing failed", exc_info=True)

# FastAPI app
app = FastAPI()
149
+
150
class TextToImageRequest(BaseModel):
    """JSON body accepted by the POST /text-to-image endpoint."""

    # Free-text asset description; the endpoint adds its own prefix and suffix.
    prompt: str
    # Seed used as-is when randomize_seed is false.
    seed: int = 42
    # When true, the endpoint draws a fresh seed instead of using `seed`.
    randomize_seed: bool = True
    # Output image size, forwarded to the Flux pipeline.
    width: int = 1024
    height: int = 1024
    # Guidance scale forwarded to the Flux pipeline.
    guidance_scale: float = 3.5
157
+
158
@app.post("/text-to-image")
def text_to_image_api(req: TextToImageRequest):
    """Generate an image with the Flux pipeline and save it under SAVE_DIR.

    Args:
        req: Prompt plus generation settings (seed, size, guidance scale).

    Returns:
        A dict with:
          * ``image_path`` - path of the saved PNG (unchanged from before);
          * ``filename``   - bare file name, usable with GET /image/{filename};
          * ``seed``       - the seed actually used, so a randomized run can
            be reproduced.
    """
    # int() keeps the value JSON-serializable even if the helper hands back
    # a NumPy integer.
    seed = int(get_seed(req.randomize_seed, req.seed))
    generator = torch.Generator(device=device).manual_seed(seed)
    # The fixed prefix/suffix is always wrapped around the caller's prompt
    # (presumably a LoRA trigger word and style hint - confirm with the model card).
    prompt = "wbgmsst, " + req.prompt + ", 3D isometric, white background"
    image = flux_pipeline(
        prompt=prompt,
        guidance_scale=req.guidance_scale,
        num_inference_steps=NUM_INFERENCE_STEPS,
        width=req.width,
        height=req.height,
        generator=generator,
    ).images[0]

    # Timestamp plus a short random suffix keeps names unique across requests.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    unique_id = str(uuid.uuid4())[:8]
    filename = f"{timestamp}_{unique_id}.png"
    # Re-create the directory in case it was removed after startup.
    os.makedirs(SAVE_DIR, exist_ok=True)
    filepath = os.path.join(SAVE_DIR, filename)
    image.save(filepath)
    return {"image_path": filepath, "filename": filename, "seed": seed}
178
+
179
class ImageTo3DRequest(BaseModel):
    """JSON body accepted by the POST /image-to-3d endpoint."""

    # Path of the input image, as returned by /text-to-image.
    image_path: str
    # Seed forwarded to the TRELLIS pipeline run.
    seed: int = 42
    # Sparse-structure sampler settings ("cfg_strength" and "steps").
    ss_guidance_strength: float = 7.5
    ss_sampling_steps: int = 12
    # Structured-latent sampler settings ("cfg_strength" and "steps").
    slat_guidance_strength: float = 3.0
    slat_sampling_steps: int = 12
186
+
187
@app.post("/image-to-3d")
def image_to_3d_api(req: ImageTo3DRequest):
    """Convert a saved image to a 3D asset with TRELLIS.

    Renders a side-by-side preview video (Gaussian colour | mesh normals),
    exports a GLB, and returns the packed state.

    Args:
        req: Image path plus sampler settings for both TRELLIS stages.

    Returns:
        A dict with ``state``, ``video_path`` and ``glb_path``.

    Raises:
        HTTPException: 404 when the requested image is not a file in SAVE_DIR.
    """
    # Local import so this block does not depend on the file-level import list.
    from fastapi import HTTPException

    # image_path is client-supplied: only honour files that live in SAVE_DIR.
    # Paths returned by /text-to-image resolve to the same file as before.
    image_file = os.path.join(SAVE_DIR, os.path.basename(req.image_path))
    if not os.path.isfile(image_file):
        raise HTTPException(status_code=404, detail="Image not found")
    # Copy out of the context manager so the file handle is closed promptly.
    with Image.open(image_file) as source:
        image = source.copy()

    outputs = trellis_pipeline.run(
        image,
        seed=req.seed,
        formats=["gaussian", "mesh"],
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": req.ss_sampling_steps,
            "cfg_strength": req.ss_guidance_strength,
        },
        slat_sampler_params={
            "steps": req.slat_sampling_steps,
            "cfg_strength": req.slat_guidance_strength,
        },
    )
    gaussian = outputs['gaussian'][0]
    mesh = outputs['mesh'][0]

    # Preview video: colour render and normal render joined along axis 1.
    video = render_utils.render_video(gaussian, num_frames=120)['color']
    video_geo = render_utils.render_video(mesh, num_frames=120)['normal']
    video = [np.concatenate([color, normal], axis=1) for color, normal in zip(video, video_geo)]

    # NOTE(review): every request writes the same file names in TMP_DIR, so
    # concurrent requests overwrite each other's outputs. Names are kept so
    # existing clients that fetch them keep working.
    user_dir = TMP_DIR
    os.makedirs(user_dir, exist_ok=True)
    video_path = os.path.join(user_dir, 'sample.mp4')
    imageio.mimsave(video_path, video, fps=15)

    # Actually write the GLB that the response advertises. The simplify and
    # texture-size values are the defaults the previous UI used.
    glb_path = os.path.join(user_dir, 'output.glb')
    glb = postprocessing_utils.to_glb(gaussian, mesh, simplify=0.95, texture_size=1024, verbose=False)
    glb.export(glb_path)

    # NOTE(review): pack_state is defined outside this view; confirm that its
    # result is JSON-serializable before relying on it in the response.
    state = pack_state(gaussian, mesh)
    torch.cuda.empty_cache()
    return {
        "state": state,
        "video_path": video_path,
        "glb_path": glb_path
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
@app.get("/image/{filename}")
def get_image(filename: str):
    """Serve a generated PNG from SAVE_DIR.

    Args:
        filename: Bare file name of an image inside SAVE_DIR.

    Returns:
        A FileResponse with media type ``image/png``.

    Raises:
        HTTPException: 404 when `filename` is not a bare name or the file
            does not exist.
    """
    # Local import so this block does not depend on the file-level import list.
    from fastapi import HTTPException

    # Refuse anything that could point outside SAVE_DIR.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = os.path.join(SAVE_DIR, filename)
    # Answer a missing file with 404 rather than failing while streaming.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="image/png")
226
 
227
@app.get("/mp4/{filename}")
def get_mp4(filename: str):
    """Serve a rendered preview video from TMP_DIR.

    Args:
        filename: Bare file name of a video inside TMP_DIR.

    Returns:
        A FileResponse with media type ``video/mp4``.

    Raises:
        HTTPException: 404 when `filename` is not a bare name or the file
            does not exist.
    """
    # Local import so this block does not depend on the file-level import list.
    from fastapi import HTTPException

    # Refuse anything that could point outside TMP_DIR.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = os.path.join(TMP_DIR, filename)
    # Answer a missing file with 404 rather than failing while streaming.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="video/mp4")
 
 
 
 
 
 
 
 
231
 
232
@app.get("/glb/{filename}")
def get_glb(filename: str):
    """Serve an exported GLB model from TMP_DIR.

    Args:
        filename: Bare file name of a GLB inside TMP_DIR.

    Returns:
        A FileResponse with media type ``model/gltf-binary``.

    Raises:
        HTTPException: 404 when `filename` is not a bare name or the file
            does not exist.
    """
    # Local import so this block does not depend on the file-level import list.
    from fastapi import HTTPException

    # Refuse anything that could point outside TMP_DIR.
    if filename in ("", ".", "..") or os.path.basename(filename) != filename:
        raise HTTPException(status_code=404, detail="File not found")
    file_path = os.path.join(TMP_DIR, filename)
    # Answer a missing file with 404 rather than failing while streaming.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path, media_type="model/gltf-binary")