# gradio_tabs/img_edit.py
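"""Gradio "Image Editing" tab for LIA-X: load a portrait, move the head / mouth /
eye sliders, and the generator re-renders the edited image via gen.edit_img."""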
import tempfile
import time
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import imageio
import spaces
from einops import rearrange
# Labels for the editable facial attributes (head pose, mouth, eyes)
labels_k = [
'yaw1',
'yaw2',
'pitch',
'roll1',
'roll2',
'neck',
'pout',
'open->close',
'"O" Mouth',
'smile',
'close->open',
'eyebrows',
'eyeballs1',
'eyeballs2',
]
labels_v = [
37, 39, 28, 15, 33, 31,
6, 25, 16, 19,
13, 24, 17, 26
]
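# Each value in labels_v is assumed to be the index of the corresponding motion
# direction in the generator's latent space, as consumed by gen.edit_img below.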
@torch.compiler.allow_in_graph
def load_image(img, size):
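    """Load an image from disk, resize it to (size, size), and return a CHW float
    array in [0, 1] together with the original width and height."""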
img = Image.open(img).convert('RGB')
w, h = img.size
img = img.resize((size, size))
img = np.asarray(img)
# Make a writable copy to avoid torch.compile issues
img = np.copy(img)
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW: 3 x size x size
return img / 255.0, w, h
@torch.compiler.allow_in_graph
def img_preprocessing(img_path, size):
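    """Load an image and return a 1 x 3 x size x size tensor normalized to [-1, 1],
    plus the original width and height."""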
img, w, h = load_image(img_path, size) # [0, 1]
img = torch.from_numpy(img).unsqueeze(0).float() # [0, 1]
imgs_norm = (img - 0.5) * 2.0 # [-1, 1]
return imgs_norm, w, h
# Cache resize transforms so each size is constructed once and reused
resize_transform_cache = {}
def get_resize_transform(size):
"""Get cached resize transform - creates once, reuses many times"""
if size not in resize_transform_cache:
# Only create the transform if it doesn't exist in cache
resize_transform_cache[size] = torchvision.transforms.Resize(
size,
interpolation=torchvision.transforms.InterpolationMode.BILINEAR,
antialias=True
)
return resize_transform_cache[size]
def resize(img, size):
"""Use cached resize transform"""
transform = get_resize_transform((size, size))
return transform(img)
def resize_back(img, w, h):
"""Use cached resize transform for back operation"""
transform = get_resize_transform((h, w))
return transform(img)
def img_denorm(img):
img = img.clamp(-1, 1).cpu()
img = (img - img.min()) / (img.max() - img.min())
return img
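# Note: img_postprocessing below applies the same clamp + min-max rescale inline,
# keeping the tensor on the GPU until the final transfer.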
def img_postprocessing(img, w, h):
# Resize on GPU (using cached transform)
img = resize_back(img, w, h)
# Denormalize ON GPU (avoid early CPU transfer)
img = img.clamp(-1, 1) # Still on GPU
img = (img - img.min()) / (img.max() - img.min()) # Still on GPU
# Single optimized CPU transfer
img = img.squeeze(0).permute(1, 2, 0).contiguous() # contiguous() for fast transfer
img_output = (img.cpu().numpy() * 255).astype(np.uint8) # Single CPU transfer
    # Return the NumPy array directly; Gradio's Image component accepts it
return img_output
def img_edit(gen, device):
@torch.compile
def compiled_inference(image_tensor, selected_s):
"""Compiled version of just the model inference"""
return gen.edit_img(image_tensor, labels_v, selected_s)
# Pre-warm the compiled model with dummy data to reduce first-run compilation time
def _warmup_model():
"""Pre-warm the model compilation with representative shapes"""
print("[img_edit] Pre-warming model compilation...")
dummy_image = torch.randn(1, 3, 512, 512, device=device)
dummy_selected_s = [0.0] * len(labels_v)
try:
with torch.inference_mode():
_ = compiled_inference(dummy_image, dummy_selected_s)
print("[img_edit] Model pre-warming completed successfully")
except Exception as e:
print(f"[img_edit] Model pre-warming failed (will compile on first use): {e}")
# Pre-warm the model
_warmup_model()
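    # On Hugging Face Spaces, @spaces.GPU attaches GPU hardware to calls of the
    # decorated function (ZeroGPU); outside Spaces the decorator is a no-op.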
@spaces.GPU
@torch.inference_mode()
def edit_img(image, *selected_s):
# Start timing (outside compiled function)
start_time = time.time()
print(f"[edit_img] Starting image editing...")
# Image preprocessing timing
preprocess_start = time.time()
image_tensor, w, h = img_preprocessing(image, 512)
image_tensor = image_tensor.to(device)
preprocess_end = time.time()
print(f"[edit_img] Preprocessing took: {(preprocess_end - preprocess_start) * 1000:.2f} ms")
# Model inference timing (compile only the core computation)
inference_start = time.time()
edited_image_tensor = compiled_inference(image_tensor, selected_s)
inference_end = time.time()
print(f"[edit_img] Model inference took: {(inference_end - inference_start) * 1000:.2f} ms")
# Post-processing timing
postprocess_start = time.time()
edited_image = img_postprocessing(edited_image_tensor, w, h)
postprocess_end = time.time()
print(f"[edit_img] Post-processing took: {(postprocess_end - postprocess_start) * 1000:.2f} ms")
# Total time
end_time = time.time()
total_time_ms = (end_time - start_time) * 1000
print(f"[edit_img] Total execution time: {total_time_ms:.2f} ms")
print(f"[edit_img] ----------------------------------------")
return edited_image
def clear_media():
return None, *([0] * len(labels_k))
with gr.Tab("Image Editing"):
inputs_s = []
with gr.Row():
with gr.Column(scale=1):
with gr.Row():
with gr.Accordion(open=True, label="Image"):
image_input = gr.Image(type="filepath", width=512) # , height=550)
gr.Examples(
examples=[
["./data/source/macron.png"],
["./data/source/einstein.png"],
["./data/source/taylor.png"],
["./data/source/portrait1.png"],
["./data/source/portrait2.png"],
["./data/source/portrait3.png"],
],
inputs=[image_input],
#cache_mode="lazy",
visible=True,
)
with gr.Row():
with gr.Column(scale=1):
with gr.Row(): # Buttons now within a single Row
edit_btn = gr.Button("Edit")
clear_btn = gr.Button("Clear")
#with gr.Row():
# animate_btn = gr.Button("Generate")
with gr.Column(scale=1):
with gr.Row():
with gr.Accordion(open=True, label="Edited Image"):
image_output = gr.Image(label="Output Image", type='numpy', interactive=False, width=512)
sliders = []
with gr.Accordion("Control Panel", open=True):
with gr.Tab("Head"):
with gr.Row():
for k in labels_k[:3]:
slider = gr.Slider(minimum=-1.0, maximum=0.5, value=0, label=k)
inputs_s.append(slider)
with gr.Row():
for k in labels_k[3:6]:
slider = gr.Slider(minimum=-0.5, maximum=0.5, value=0, label=k)
inputs_s.append(slider)
with gr.Tab("Mouth"):
with gr.Row():
for k in labels_k[6:8]:
slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
inputs_s.append(slider)
with gr.Row():
for k in labels_k[8:10]:
slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
inputs_s.append(slider)
with gr.Tab("Eyes"):
with gr.Row():
for k in labels_k[10:12]:
slider = gr.Slider(minimum=-0.4, maximum=0.4, value=0, label=k)
inputs_s.append(slider)
with gr.Row():
for k in labels_k[12:14]:
slider = gr.Slider(minimum=-0.2, maximum=0.2, value=0, label=k)
inputs_s.append(slider)
for slider in inputs_s:
slider.change(
fn=edit_img,
inputs=[image_input] + inputs_s,
outputs=[image_output],
show_progress='hidden',
trigger_mode='always_last',
# currently we have a latency around 450ms
stream_every=0.5
)
edit_btn.click(
fn=edit_img,
inputs=[image_input] + inputs_s,
outputs=[image_output],
show_progress=True
)
clear_btn.click(
fn=clear_media,
outputs=[image_output] + inputs_s
)
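
# Minimal usage sketch (hypothetical wiring; the real entry point and the
# generator import path / constructor below are assumptions, not this repo's API):
#
#   from networks.generator import Generator  # assumed module path
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   gen = Generator().to(device).eval()        # assumed constructor
#   with gr.Blocks() as demo:
#       img_edit(gen, device)                  # builds the "Image Editing" tab
#   demo.launch()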