import os
import tempfile
import uuid

import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler


# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
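
# Example (illustrative numbers): a 1024x768 input with size=512 is scaled by
# 512/768 ≈ 0.667 to 682x512, then the central 512x512 window is cropped out
# (left = (682 - 512) // 2 = 85, top = 0).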


# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model,
    vlm_processor,
    process_vision_info,
    prev_image_path: str,
    zoomed_image_path: str,
    device: str = "cuda",
) -> str:
    """
    Given two image file paths:
      - prev_image_path: the “full” image from the previous recursion step.
      - zoomed_image_path: the cropped-and-resized (zoomed) image for this step.
    Builds a single “recursive_multiscale” prompt via Qwen2.5-VL and returns a
    string of comma-separated tags, e.g. “cat on sofa, pet, indoor, living room”.
    """
    # (1) System message describing the recursive_multiscale task:
    message_text = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )
    # (2) Build the two-image “chat” payload:
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prev_image_path},
                {"type": "image", "image": zoomed_image_path},
            ],
        },
    ]
    # (3) Run the payload through the VL processor to get model inputs:
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)
    # (4) Generate tokens, then decode:
    generated = vlm_model.generate(**inputs, max_new_tokens=128)
    # Strip the prompt tokens from each generated sequence:
    trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    # (5) Return just the bare tag words (no extra user prompt is appended here).
    return out_text.strip()
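
# Example call (illustrative; assumes the Qwen2.5-VL model/processor created in
# recursive_multiscale_sr below are already loaded and the two PNGs exist):
#
#   tags = _generate_vlm_prompt(
#       vlm_model, vlm_processor, process_vision_info,
#       prev_image_path="/tmp/step0_prev.png",    # hypothetical paths
#       zoomed_image_path="/tmp/step1_zoom.png",
#   )
#   # tags -> e.g. "cat, sofa, cushion, indoor"   (actual output varies per run)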


# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr (with multiple centers)
# -------------------------------------------------------------------
def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
    rec_num: int = 4,
    centers: list[tuple[float, float]] | None = None,
) -> tuple[list[Image.Image], list[str]]:
    """
    Perform `rec_num` recursive_multiscale super-resolution steps on a single PNG.
      - input_png_path: path to a single .png file on disk.
      - upscale: integer up-scale factor per recursion step (e.g. 4).
      - rec_num: number of recursion steps to perform.
      - centers: a list of normalized (x, y) tuples in [0, 1], one per recursion
        step, giving the center of the low-res crop for that step. The list
        length must equal rec_num. If None, (0.5, 0.5) is used for every step.
    Returns a tuple (sr_pil_list, prompt_list), where:
      - sr_pil_list: list of PIL.Image outputs [SR1, SR2, …, SR_rec_num] in order.
      - prompt_list: list of the VLM prompts generated at each recursion step.
    """
    ###############################
    # 0. Validate / fill default centers
    ###############################
    if centers is None:
        # Default: use center (0.5, 0.5) for every recursion step
        centers = [(0.5, 0.5) for _ in range(rec_num)]
    elif not isinstance(centers, (list, tuple)) or len(centers) != rec_num:
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x, y) tuples, "
            f"but got {centers!r}."
        )
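    # Illustrative example: for rec_num=4, a valid argument would be
    #   centers=[(0.5, 0.5), (0.25, 0.25), (0.75, 0.40), (0.60, 0.60)]
    # (one normalized (x, y) pair per recursion step; the values are placeholders).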
    ###############################
    # 1. Fixed hyper-parameters
    ###############################
    device = "cuda"
    process_size = 512  # same as args.process_size
    # model checkpoint paths (hard-coded to your example)
    LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
    VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
    SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"
    # VLM model name (hard-coded)
    VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
    ###############################
    # 2. Build a dummy “args” namespace
    #    to satisfy the OSEDiff_SD3_TEST constructor.
    ###############################
    class _Args:
        pass

    args = _Args()
    args.upscale = upscale
    args.lora_path = LORA_PATH
    args.vae_path = VAE_PATH
    args.pretrained_model_name_or_path = SD3_MODEL
    args.merge_and_unload_lora = False
    args.lora_rank = 4
    args.vae_decoder_tiled_size = 224
    args.vae_encoder_tiled_size = 1024
    args.latent_tiled_size = 96
    args.latent_tiled_overlap = 32
    args.mixed_precision = "fp16"
    args.efficient_memory = False
    # (other flags are not used by OSEDiff_SD3_TEST, so we skip them)
    ###############################
    # 3. Load the SD3 SR model (non-efficient path)
    ###############################
    # 3.1 Instantiate the underlying SD3-Euler transformer / VAE / text encoders
    sd3 = SD3Euler()
    # Move all text encoders + transformer + VAE to CUDA:
    sd3.text_enc_1.to(device)
    sd3.text_enc_2.to(device)
    sd3.text_enc_3.to(device)
    sd3.transformer.to(device, dtype=torch.float32)
    sd3.vae.to(device, dtype=torch.float32)
    # Freeze all weights:
    for p in (
        sd3.text_enc_1,
        sd3.text_enc_2,
        sd3.text_enc_3,
        sd3.transformer,
        sd3.vae,
    ):
        p.requires_grad_(False)
    # 3.2 Wrap in the OSEDiff_SD3_TEST helper:
    model_test = OSEDiff_SD3_TEST(args, sd3)
    # (model_test(lq_tensor, prompt=str) returns a list of output tensors)
    ###############################
    # 4. Load the VLM (Qwen2.5-VL)
    ###############################
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        VLM_NAME,
        torch_dtype="auto",
        device_map="auto",  # immediately dispatches layers onto available GPUs
    )
    vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)
    ###############################
    # 5. Create a temporary directory to hold intermediate PNG files
    ###############################
    unique_id = uuid.uuid4().hex
    prefix = f"recms_{unique_id}_"
    with tempfile.TemporaryDirectory(prefix=prefix) as td:
        # (we write “prev.png” and “zoom.png” files at each step)
        ###############################
        # 6. Prepare the very first “full” image
        ###############################
        # 6.1 Load + resize + center-crop → img0 is a (512×512) PIL image on CPU
        img0 = Image.open(input_png_path).convert("RGB")
        img0 = resize_and_center_crop(img0, process_size)
        # 6.2 Save it once so the VLM can read it as “prev.png”
        prev_path = os.path.join(td, "step0_prev.png")
        img0.save(prev_path)
        # Maintain lists of PIL outputs and prompts:
        sr_pil_list: list[Image.Image] = []
        prompt_list: list[str] = []
        ###############################
        # 7. Recursion loop (run rec_num times)
        ###############################
        for rec in range(rec_num):
            # (A) Load the previous SR output (or the original) and compute the crop window
            prev_pil = Image.open(prev_path).convert("RGB")
            w, h = prev_pil.size  # (512×512) on every iteration
            # (1) Compute the “low-res” window size:
            new_w, new_h = w // upscale, h // upscale  # e.g. 128×128 for upscale=4
            # (2) Map the normalized center → pixel center:
            cx_norm, cy_norm = centers[rec]
            cx = int(cx_norm * w)
            cy = int(cy_norm * h)
            half_w = new_w // 2
            half_h = new_h // 2
            left = cx - half_w
            top = cy - half_h
            # Clamp left ∈ [0, w - new_w] and top ∈ [0, h - new_h] so the crop
            # window always stays inside the image:
            left = max(0, min(left, w - new_w))
            top = max(0, min(top, h - new_h))
            right = left + new_w
            bottom = top + new_h
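            # Worked example (illustrative numbers): with w = h = 512, upscale = 4,
            # and center (0.25, 0.5): new_w = new_h = 128, cx = 128, cy = 256,
            # half_w = half_h = 64, left = 64 (already within [0, 384]),
            # top = 192, so the crop box is (64, 192, 192, 320).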
            cropped = prev_pil.crop((left, top, right, bottom))
            # (B) Resize that crop back up to (512×512) via BICUBIC → zoomed
            zoomed = cropped.resize((w, h), Image.BICUBIC)
            zoom_path = os.path.join(td, f"step{rec+1}_zoom.png")
            zoomed.save(zoom_path)
            # (C) Generate a recursive_multiscale VLM “tag” prompt
            prompt_tag = _generate_vlm_prompt(
                vlm_model=vlm_model,
                vlm_processor=vlm_processor,
                process_vision_info=process_vision_info,
                prev_image_path=prev_path,
                zoomed_image_path=zoom_path,
                device=device,
            )
            # (No extra user prompt is appended.)
            # (D) Prepare the low-res tensor for SR: zoomed → tensor in [0, 1] → [−1, 1]
            to_tensor = transforms.ToTensor()
            lq = to_tensor(zoomed).unsqueeze(0).to(device)  # shape (1, 3, 512, 512)
            lq = (lq * 2.0) - 1.0
            # (E) Run SR inference:
            with torch.no_grad():
                out_tensor = model_test(lq, prompt=prompt_tag)[0]  # (3, 512, 512)
                out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
                # Map back to a PIL image in [0, 1]:
                out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)
            # (F) Save this step’s SR output as the next iteration’s “prev.png”:
            out_path = os.path.join(td, f"step{rec+1}_sr.png")
            out_pil.save(out_path)
            prev_path = out_path
            # (G) Append this step’s image and prompt to the result lists:
            sr_pil_list.append(out_pil)
            prompt_list.append(prompt_tag)
        # end for(rec)
        ###############################
        # 8. Return the SR outputs & prompts
        ###############################
        # sr_pil_list = [SR1, SR2, …, SR_rec_num] in order.
        return sr_pil_list, prompt_list
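

# -------------------------------------------------------------------
# Minimal usage sketch (not part of the pipeline above). The input path,
# upscale factor, and centers below are illustrative placeholders; adjust
# them to your own files and hardware before running.
# -------------------------------------------------------------------
if __name__ == "__main__":
    sr_images, prompts = recursive_multiscale_sr(
        input_png_path="examples/input.png",  # hypothetical path
        upscale=4,
        rec_num=4,
        centers=[(0.5, 0.5), (0.4, 0.6), (0.55, 0.45), (0.5, 0.5)],  # illustrative
    )
    for i, (img, prompt) in enumerate(zip(sr_images, prompts), start=1):
        img.save(f"sr_step{i}.png")
        print(f"[step {i}] VLM prompt: {prompt}")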