# Hugging Face Spaces status banner captured with the page: "Running on Zero" (page residue, not part of the program).
import gradio as gr | |
import subprocess | |
import os | |
import shutil | |
from pathlib import Path | |
import spaces | |
# import the updated recursive_multiscale_sr that expects a list of centers | |
from inference_coz_single import recursive_multiscale_sr | |
from PIL import Image, ImageDraw | |
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# Directory name for input samples; the example rows of the UI in this file
# use paths under "samples/".
# NOTE(review): neither INPUT_DIR nor OUTPUT_DIR is referenced by the code
# visible in this file -- confirm whether another module reads them before
# changing or removing them.
INPUT_DIR = "samples"
# Presumably where inference results are written -- TODO confirm against the
# SR pipeline; nothing visible here writes to this path.
OUTPUT_DIR = "inference_results/coz_vlmprompt"
# ------------------------------------------------------------------ | |
# HELPER: Resize & center-crop to 512, preserving aspect ratio | |
# ------------------------------------------------------------------ | |
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """Resize so the shorter side equals ``size``, then center-crop to a square.

    Args:
        img: Source PIL image.
        size: Edge length, in pixels, of the square result.

    Returns:
        A new ``size`` x ``size`` PIL image taken from the center of the
        resized input.

    Raises:
        ZeroDivisionError: If the image has a zero-length side.
    """
    w, h = img.size
    shorter = min(w, h)

    # Use exact integer arithmetic instead of ``int(w * (size / shorter))``.
    # The float form can land just below the target (e.g. 512 / 49 * 49
    # evaluates to 511.99999999999994), which truncates the shorter side to
    # ``size - 1`` and makes the crop box below extend past the image edge.
    # Floor division keeps the truncating behaviour for the longer side while
    # guaranteeing the shorter side is exactly ``size``.
    new_w = w * size // shorter
    new_h = h * size // shorter
    img = img.resize((new_w, new_h), Image.LANCZOS)

    # Center the square crop window inside the resized image.
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------ | |
# HELPER: Draw four true “nested” rectangles, matching the SR logic | |
# ------------------------------------------------------------------ | |
def make_preview_with_boxes(
    image_path: str,
    scale_option: str,
    cx_norm: float,
    cy_norm: float,
) -> Image.Image:
    """Render a 512x512 preview with four nested crop outlines drawn on it.

    The image at ``image_path`` is resized and center-cropped to 512x512.
    With ``k = int(scale_option.replace("x", ""))`` the four box edge lengths
    are ``512 // k**1 .. 512 // k**4`` (all 512 when ``k <= 1``).

    Each box is positioned relative to the previous one (the first relative
    to the full 512x512 frame): its center is placed at the normalized point
    ``(cx_norm, cy_norm)`` of the previous box, and its top-left corner is
    then clamped so the box stays inside the previous box.

    If the image cannot be opened, a gray 512x512 placeholder carrying the
    error text is returned instead.

    Args:
        image_path: Path of the image file to preview.
        scale_option: Per-step zoom factor written like ``"4x"``.
        cx_norm: Normalized horizontal position of the zoom center.
        cy_norm: Normalized vertical position of the zoom center.

    Returns:
        The annotated preview image, or the gray error placeholder.
    """
    try:
        source_img = Image.open(image_path).convert("RGB")
    except Exception as err:
        # Best-effort preview: show the failure inside the image itself.
        placeholder = Image.new("RGB", (512, 512), (200, 200, 200))
        ImageDraw.Draw(placeholder).text((20, 20), f"Error:\n{err}", fill="red")
        return placeholder

    preview = resize_and_center_crop(source_img, 512)

    # Edge length of each nested box, e.g. "4x" -> [128, 32, 8, 2].
    factor = int(scale_option.replace("x", ""))
    if factor > 1:
        box_sizes = [512 // (factor ** level) for level in range(1, 5)]
    else:
        box_sizes = [512, 512, 512, 512]

    def clamped_start(outer_start, outer_extent, fraction, side):
        """Top-left coordinate of a ``side``-long span kept inside the outer span."""
        wanted = outer_start + (fraction * outer_extent) - (side / 2.0)
        upper = outer_start + outer_extent - side
        return max(outer_start, min(wanted, upper))

    pen = ImageDraw.Draw(preview)
    palette = ["red", "lime", "cyan", "yellow"]
    stroke = 3

    # The enclosing region starts as the whole frame and shrinks to each box.
    outer_x, outer_y, outer_extent = 0.0, 0.0, 512.0
    for side, hue in zip(box_sizes, palette):
        left = clamped_start(outer_x, outer_extent, cx_norm, side)
        top = clamped_start(outer_y, outer_extent, cy_norm, side)
        right = left + side
        bottom = top + side

        # Pixel coordinates are truncated to ints only at draw time.
        pen.rectangle(
            [(int(left), int(top)), (int(right), int(bottom))],
            outline=hue,
            width=stroke,
        )

        outer_x, outer_y = left, top
        outer_extent = side

    return preview
# ------------------------------------------------------------------ | |
# HELPER FUNCTION FOR INFERENCE (build a list of identical centers) | |
# ------------------------------------------------------------------ | |
def run_with_upload(
    uploaded_image_path: str,
    upscale_option: str,
    cx_norm: float,
    cy_norm: float,
):
    """
    Perform chain-of-zoom super-resolution on a given image, using recursive multi-scale upscaling centered on a specific point.
    This function enhances a given image by progressively zooming into a specific point, using a recursive deep super-resolution model.
    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): The desired upscale factor as a string. Valid options are "1x", "2x", and "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        cx_norm (float): Normalized X-coordinate (0 to 1) of the zoom center.
        cy_norm (float): Normalized Y-coordinate (0 to 1) of the zoom center.
    Returns:
        list[PIL.Image.Image]: A list of progressively zoomed-in and super-resolved images at each recursion step (typically 4),
        centered around the user-specified point.
    Note:
        The center point is repeated for each recursion level to maintain consistency during zooming.
        This function uses a modified version of the `recursive_multiscale_sr` pipeline for inference.
    """
    # NOTE(review): the docstring above is kept word-for-word on purpose; the
    # app is launched with mcp_server=True, so Gradio presumably exposes this
    # text as the tool description -- confirm before rewording it.

    # Nothing uploaded yet: hand the gallery an empty list.
    if uploaded_image_path is None:
        return []

    depth = 4  # number of zoom steps requested from the SR pipeline
    factor = int(upscale_option.replace("x", ""))

    # One (cx, cy) pair per zoom step, all identical.
    zoom_centers = [(cx_norm, cy_norm) for _ in range(depth)]

    # Only the first element of the returned pair is used.
    zoomed_images, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=factor,
        rec_num=depth,
        centers=zoom_centers,
    )
    return zoomed_images
# ------------------------------------------------------------------ | |
# BUILD THE GRADIO INTERFACE (two sliders + correct preview) | |
# ------------------------------------------------------------------ | |
# Page-level CSS: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

# NOTE(review): the nesting below is reconstructed from the order of the
# widgets; confirm the Row/Column layout against the running app.
with gr.Blocks(css=css) as demo:
    # Title, tagline and repository badge.
    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>Chain-of-Zoom</h1>
            <p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
        </div>
        <br>
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
            <a href="https://github.com/bryanswkim/Chain-of-Zoom">
                <img src='https://img.shields.io/badge/GitHub-Repo-blue'>
            </a>
        </div>
        """
    )

    # Settings shared by both normalized-center sliders.
    _unit_slider = dict(minimum=0.0, maximum=1.0, step=0.01, value=0.5)

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: inputs, run button and the crop preview.
            with gr.Column():
                upload_image = gr.Image(label="Input image", type="filepath")
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False,
                )
                center_x = gr.Slider(label="Center X (normalized)", **_unit_slider)
                center_y = gr.Slider(label="Center Y (normalized)", **_unit_slider)
                run_button = gr.Button("Chain-of-Zoom it")
                preview_with_box = gr.Image(
                    label="Preview (512×512 with nested boxes)",
                    type="pil",
                    interactive=False,
                )

            # Right column: gallery of super-resolved outputs.
            with gr.Column():
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                )

        # The same four controls feed the examples, the preview and the run.
        _controls = [upload_image, upscale_radio, center_x, center_y]

        # Each example row is [input_image, scale, cx, cy].
        examples = gr.Examples(
            examples=[
                ["samples/0479.png", "4x", 0.5, 0.5],
                ["samples/0064.png", "4x", 0.5, 0.5],
                ["samples/0245.png", "4x", 0.5, 0.5],
                ["samples/0393.png", "4x", 0.5, 0.5],
            ],
            inputs=_controls,
            outputs=[output_gallery],
            fn=run_with_upload,
            cache_examples=True,
        )

    # ------------------------------------------------------------------
    # CALLBACK #1: update the preview whenever inputs change
    # ------------------------------------------------------------------
    def update_preview(
        img_path: str,
        scale_opt: str,
        cx: float,
        cy: float
    ) -> Image.Image | None:
        """
        Return None while no image is uploaded; otherwise return the
        512x512 preview with the four nested boxes drawn for the current
        scale and center.
        """
        return (
            None
            if img_path is None
            else make_preview_with_boxes(img_path, scale_opt, cx, cy)
        )

    # Any change to the image, the scale or either slider redraws the preview.
    for _control in _controls:
        _control.change(
            fn=update_preview,
            inputs=_controls,
            outputs=[preview_with_box],
            show_api=False,
        )

    # ------------------------------------------------------------------
    # CALLBACK #2: on button-click, run the SR pipeline
    # ------------------------------------------------------------------
    run_button.click(
        fn=run_with_upload,
        inputs=_controls,
        outputs=[output_gallery],
    )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
# NOTE(review): both calls run at import time; there is no
# `if __name__ == "__main__":` guard around them.
# Turn on request queuing for the app before it is launched.
demo.queue()
# share=True asks Gradio for a public share link and mcp_server=True asks it
# to also expose the app as an MCP server -- confirm both against the
# installed Gradio version's launch() documentation.
demo.launch(share=True, mcp_server=True)