Chain-of-Zoom / app.py
alexnasa's picture
Update app.py
f1cb021 verified
import gradio as gr
import subprocess
import os
import shutil
from pathlib import Path
import spaces
# import the updated recursive_multiscale_sr that expects a list of centers
from inference_coz_single import recursive_multiscale_sr
from PIL import Image, ImageDraw
# ------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# ------------------------------------------------------------------
# Directory holding the bundled sample inputs.
# NOTE(review): not referenced by the visible code in this file -- the
# gr.Examples rows further down hard-code "samples/..." paths instead.
INPUT_DIR = "samples"
# Intended results directory.
# NOTE(review): also unreferenced in this file; whether
# recursive_multiscale_sr writes here cannot be confirmed from this file.
OUTPUT_DIR = "inference_results/coz_vlmprompt"
# ------------------------------------------------------------------
# HELPER: Resize & center-crop to a size x size square, keeping aspect ratio
# ------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    """
    Resize `img` so its shorter side is exactly `size`, then center-crop it
    to a `size` x `size` square.

    Args:
        img: Source PIL image.
        size: Target side length in pixels.

    Returns:
        A new `size` x `size` PIL image.
    """
    w, h = img.size
    # Pin the shorter side to `size` and scale the longer side with exact
    # integer arithmetic. The float form `int(w * (size / min(w, h)))` can
    # truncate to size - 1 (e.g. a 49-px shorter side gives 511 for
    # size=512). That pushed the crop box one pixel outside the resized
    # image, leaving a black row/column in the result.
    if w <= h:
        new_w, new_h = size, h * size // w
    else:
        new_w, new_h = w * size // h, size
    img = img.resize((new_w, new_h), Image.LANCZOS)
    # Both offsets are >= 0 because each scaled side is >= size.
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
# ------------------------------------------------------------------
# HELPER: Draw four truly "nested" rectangles, one per zoom level
# (intended to mirror the SR pipeline's crop sequence)
# ------------------------------------------------------------------
def _nested_crop_sizes(scale_int: int, base_size: int, levels: int = 4) -> list[int]:
    """Return the side length in pixels of each nested crop level."""
    if scale_int <= 1:
        # 1x (or a non-positive factor): no real zoom, every level is the
        # whole frame.
        return [base_size] * levels
    # Each level shrinks by another factor of `scale_int`; e.g. base_size=512,
    # scale_int=4 -> [128, 32, 8, 2]. Clamp to 1 px so a large factor never
    # yields a zero-sized box.
    return [max(1, base_size // scale_int ** (i + 1)) for i in range(levels)]


def _nested_crop_origins(
    sizes: list[int],
    cx_norm: float,
    cy_norm: float,
    base_size: int,
) -> list[tuple[float, float]]:
    """
    Return the (x0, y0) top-left corner of each nested crop, in base-image
    pixel coordinates.

    Each crop is centered on the normalized point *relative to its parent
    crop* (the whole base image for level 0), then clamped so it lies
    entirely inside that parent.
    """
    origins = []
    prev_x0, prev_y0 = 0.0, 0.0
    prev_size = float(base_size)
    for crop_size in sizes:
        # Map the normalized center into the current parent region.
        center_x = prev_x0 + cx_norm * prev_size
        center_y = prev_y0 + cy_norm * prev_size
        # Clamp the top-left into [parent_x0, parent_x0 + parent_size - crop_size].
        x0 = max(prev_x0, min(center_x - crop_size / 2.0, prev_x0 + prev_size - crop_size))
        y0 = max(prev_y0, min(center_y - crop_size / 2.0, prev_y0 + prev_size - crop_size))
        origins.append((x0, y0))
        prev_x0, prev_y0, prev_size = x0, y0, crop_size
    return origins


def make_preview_with_boxes(
    image_path: str,
    scale_option: str,
    cx_norm: float,
    cy_norm: float,
    *,
    preview_size: int = 512,
) -> Image.Image:
    """
    Render a square preview of the input with one outline per zoom level.

    Steps:
      1) Open the image, then resize & center-crop it to
         `preview_size` x `preview_size`.
      2) Parse the integer factor from `scale_option` ("4x" -> 4). Level i
         (0-based) has side `preview_size // factor**(i + 1)` (at least 1 px).
         A factor <= 1 makes every level the full frame.
      3) For each level, center the crop on (cx_norm, cy_norm) within the
         previous level's crop, clamp it inside that crop, and outline it
         (red, lime, cyan, yellow).

    Args:
        image_path: Path to the uploaded image file.
        scale_option: Zoom factor string such as "1x", "2x" or "4x".
        cx_norm: Horizontal zoom center, normalized to each parent crop.
        cy_norm: Vertical zoom center, normalized to each parent crop.
        preview_size: Side length of the square preview (default 512).

    Returns:
        The preview image with the nested outlines drawn on it. If the image
        cannot be opened, a gray placeholder showing the error text instead.

    Raises:
        ValueError: If `scale_option` without its "x" is not an integer
            (raised by int()).
    """
    try:
        # The context manager closes the file handle promptly; convert()
        # returns an independent in-memory image.
        with Image.open(image_path) as src:
            orig = src.convert("RGB")
    except Exception as e:
        # UI boundary: show the error in the preview pane instead of
        # failing the change-event callback.
        fallback = Image.new("RGB", (preview_size, preview_size), (200, 200, 200))
        ImageDraw.Draw(fallback).text((20, 20), f"Error:\n{e}", fill="red")
        return fallback

    base = resize_and_center_crop(orig, preview_size)
    scale_int = int(scale_option.replace("x", ""))  # e.g. "4x" -> 4
    sizes = _nested_crop_sizes(scale_int, preview_size)
    origins = _nested_crop_origins(sizes, cx_norm, cy_norm, preview_size)

    draw = ImageDraw.Draw(base)
    colors = ["red", "lime", "cyan", "yellow"]
    for idx, ((x0, y0), crop_size) in enumerate(zip(origins, sizes)):
        left, top = int(x0), int(y0)
        # ImageDraw.rectangle's bounding box includes both endpoints, so the
        # far corner sits crop_size - 1 pixels away. Using x0 + crop_size
        # outlined crop_size + 1 pixels and, at 1x, ran past the frame edge.
        draw.rectangle(
            [(left, top), (left + crop_size - 1, top + crop_size - 1)],
            outline=colors[idx % len(colors)],
            width=3,
        )
    return base
# ------------------------------------------------------------------
# HELPER FUNCTION FOR INFERENCE (build a list of identical centers)
# ------------------------------------------------------------------
@spaces.GPU()
def run_with_upload(
    uploaded_image_path: str,
    upscale_option: str,
    cx_norm: float,
    cy_norm: float,
):
    """
    Perform chain-of-zoom super-resolution on a given image, using recursive multi-scale upscaling centered on a specific point.
    This function enhances a given image by progressively zooming into a specific point, using a recursive deep super-resolution model.

    Args:
        uploaded_image_path (str): Path to the input image file on disk.
        upscale_option (str): The desired upscale factor as a string. Valid options are "1x", "2x", and "4x".
            - "1x" means no upscaling.
            - "2x" means 2× enlargement per zoom step.
            - "4x" means 4× enlargement per zoom step.
        cx_norm (float): Normalized X-coordinate (0 to 1) of the zoom center. Out-of-range values are clamped.
        cy_norm (float): Normalized Y-coordinate (0 to 1) of the zoom center. Out-of-range values are clamped.

    Returns:
        list[PIL.Image.Image]: A list of progressively zoomed-in and super-resolved images at each recursion step (typically 4),
            centered around the user-specified point. An empty list when no image path is given.

    Raises:
        gr.Error: If `upscale_option` cannot be parsed as a positive integer factor.

    Note:
        The center point is repeated for each recursion level to maintain consistency during zooming.
        This function uses a modified version of the `recursive_multiscale_sr` pipeline for inference.
    """
    if not uploaded_image_path:
        # No image (None, or an empty path) -> empty gallery.
        return []

    # This handler is also API/MCP-visible (the app launches with
    # mcp_server=True), so inputs may not come from the UI widgets. Reject
    # malformed factors with a readable message instead of a bare ValueError.
    try:
        upscale_value = int(upscale_option.replace("x", ""))
    except (AttributeError, ValueError) as err:
        raise gr.Error(
            f"Invalid upscale option {upscale_option!r}; expected '1x', '2x' or '4x'."
        ) from err
    if upscale_value < 1:
        raise gr.Error(
            f"Invalid upscale option {upscale_option!r}; the factor must be at least 1."
        )

    # Keep the zoom center within the documented 0..1 range, matching the
    # UI sliders' bounds.
    cx_norm = min(max(cx_norm, 0.0), 1.0)
    cy_norm = min(max(cy_norm, 0.0), 1.0)

    rec_num = 4  # match the SR pipeline's default recursion depth
    centers = [(cx_norm, cy_norm)] * rec_num
    # Call the modified SR function; its second return value is unused here.
    sr_list, _ = recursive_multiscale_sr(
        uploaded_image_path,
        upscale=upscale_value,
        rec_num=rec_num,
        centers=centers,
    )
    # Return the list of PIL images (Gradio Gallery expects a list)
    return sr_list
# ------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE (two sliders + correct preview)
# ------------------------------------------------------------------
# Page-level CSS passed to gr.Blocks below: horizontally centers the element
# with elem_id="col-container" and caps its width at 1024px.
css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""
# Page header: title, subtitle and a link to the upstream repository.
_HEADER_HTML = """
<div style="text-align: center;">
<h1>Chain-of-Zoom</h1>
<p style="font-size:16px;">Extreme Super-Resolution via Scale Autoregression and Preference Alignment</p>
</div>
<br>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://github.com/bryanswkim/Chain-of-Zoom">
<img src='https://img.shields.io/badge/GitHub-Repo-blue'>
</a>
</div>
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: input, zoom controls, run button and live preview.
            with gr.Column():
                upload_image = gr.Image(label="Input image", type="filepath")
                # Zoom factor applied at every recursion step.
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    show_label=False,
                )
                # Zoom center, normalized to 0..1 along each axis.
                center_x = gr.Slider(
                    label="Center X (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                center_y = gr.Slider(
                    label="Center Y (normalized)",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=0.5,
                )
                run_button = gr.Button("Chain-of-Zoom it")
                # Read-only 512x512 preview with one outline per zoom level.
                preview_with_box = gr.Image(
                    label="Preview (512×512 with nested boxes)",
                    type="pil",
                    interactive=False,
                )

            # Right column: results gallery plus cached examples.
            with gr.Column():
                output_gallery = gr.Gallery(
                    label="Inference Results",
                    show_label=True,
                    elem_id="gallery",
                    columns=[2],
                    rows=[2],
                )
                # Each row: [input_image, scale, cx, cy]; outputs are cached.
                examples = gr.Examples(
                    examples=[
                        ["samples/0479.png", "4x", 0.5, 0.5],
                        ["samples/0064.png", "4x", 0.5, 0.5],
                        ["samples/0245.png", "4x", 0.5, 0.5],
                        ["samples/0393.png", "4x", 0.5, 0.5],
                    ],
                    inputs=[upload_image, upscale_radio, center_x, center_y],
                    outputs=[output_gallery],
                    fn=run_with_upload,
                    cache_examples=True,
                )

    # The same four widgets feed both the preview and the SR run.
    zoom_inputs = (upload_image, upscale_radio, center_x, center_y)

    def update_preview(
        img_path: str,
        scale_opt: str,
        cx: float,
        cy: float
    ) -> Image.Image | None:
        """
        Build the nested-box preview for the current inputs, or clear the
        preview (None) when no image is loaded.
        """
        if img_path is not None:
            return make_preview_with_boxes(img_path, scale_opt, cx, cy)
        return None

    # Refresh the preview whenever any input changes. These listeners are
    # kept off the public API (show_api=False) and are registered in the
    # order upload -> scale -> center X -> center Y.
    for component in zoom_inputs:
        component.change(
            fn=update_preview,
            inputs=list(zoom_inputs),
            outputs=[preview_with_box],
            show_api=False,
        )

    # Run the SR pipeline on button click (this listener stays API-visible).
    run_button.click(
        fn=run_with_upload,
        inputs=list(zoom_inputs),
        outputs=[output_gallery],
    )
# ------------------------------------------------------------------
# START THE GRADIO SERVER
# ------------------------------------------------------------------
if __name__ == "__main__":
    # Start the server only when run as a script (`python app.py`), so the
    # module can be imported (e.g. to reuse `demo` or the helpers) without
    # binding a port.
    demo.queue()
    # NOTE(review): share=True requests a public share link; Gradio is
    # expected to ignore it when hosted on Hugging Face Spaces -- confirm.
    demo.launch(share=True, mcp_server=True)