import os
import tempfile
import uuid

import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler


# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))


# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model: Qwen2_5_VLForConditionalGeneration,
    vlm_processor: AutoProcessor,
    process_vision_info,       # helper that turns `messages` -> image_inputs / video_inputs
    prev_pil: Image.Image,     # <- pass PIL instead of a path
    zoomed_pil: Image.Image,   # <- pass PIL instead of a path
    device: str = "cuda",
) -> str:
    """
    Given two PIL.Image inputs:
      - prev_pil:   the "full" image at the previous recursion step.
      - zoomed_pil: the cropped + resized (zoomed) image for this step.
    Returns a single "recursive_multiscale" prompt string.
    """
    # (1) System message
    message_text = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )

    # (2) Build the two-image "chat" payload.
    #
    # Instead of passing a filename, we pass the actual PIL.Image.
    # `process_vision_info` knows how to turn a message of the form
    # {"type": "image", "image": PIL_IMAGE} into tensors.
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prev_pil},
                {"type": "image", "image": zoomed_pil},
            ],
        },
    ]

    # (3) Run the "chat" through the VL processor.
    #
    # - `apply_chat_template` builds the tokenized prompt (without running it yet).
    # - `process_vision_info` inspects the same `messages` list and returns
    #   `image_inputs` and `video_inputs` (tensors) for any attached PIL images.
    text = vlm_processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # (4) Generate and decode
    generated = vlm_model.generate(**inputs, max_new_tokens=128)
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return out_text.strip()


# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr (with multiple centers)
# -------------------------------------------------------------------
def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
    rec_num: int = 4,
    centers: list[tuple[float, float]] | None = None,
) -> tuple[list[Image.Image], list[str]]:
    """
    Perform `rec_num` recursive_multiscale super-resolution steps on a single PNG.

    - input_png_path: path to a single .png file on disk.
    - upscale: integer up-scale factor per recursion (e.g. 4).
    - rec_num: how many recursion steps to perform.
    - centers: a list of normalized (x, y) tuples in [0, 1], one per recursion
      step, indicating where to center the low-res crop for that step. The list
      length must equal rec_num. If centers is None, every step defaults to the
      image center (0.5, 0.5).

    Returns a tuple (sr_pil_list, prompt_list), where:
      - sr_pil_list: list of PIL.Image outputs [SR1, SR2, ..., SR_rec_num] in order.
      - prompt_list: list of the VLM prompts generated at each recursion.
    """
    ###############################
    # 0. Validate / fill default centers
    ###############################
    if centers is None:
        # Default: use the image center (0.5, 0.5) for every recursion
        centers = [(0.5, 0.5) for _ in range(rec_num)]
    elif not isinstance(centers, (list, tuple)) or len(centers) != rec_num:
        raise ValueError(
            f"`centers` must be a list of {rec_num} (x, y) tuples, "
            f"but got {centers!r}."
        )

    ###############################
    # 1. Fixed hyper-parameters
    ###############################
    device = "cuda"
    process_size = 512  # same as args.process_size

    # model checkpoint paths (hard-coded to this example)
    LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
    VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
    SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"

    # VLM model name (hard-coded)
    VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"

    ###############################
    # 2. Build a dummy "args" namespace
    #    to satisfy the OSEDiff_SD3_TEST constructor.
    ###############################
    class _Args:
        pass

    args = _Args()
    args.upscale = upscale
    args.lora_path = LORA_PATH
    args.vae_path = VAE_PATH
    args.pretrained_model_name_or_path = SD3_MODEL
    args.merge_and_unload_lora = False
    args.lora_rank = 4
    args.vae_decoder_tiled_size = 224
    args.vae_encoder_tiled_size = 1024
    args.latent_tiled_size = 96
    args.latent_tiled_overlap = 32
    args.mixed_precision = "fp16"
    args.efficient_memory = False
    # (other flags are not used by OSEDiff_SD3_TEST, so we skip them)

    ###############################
    # 3. Load the SD3 SR model (non-efficient)
    ###############################
    # 3.1 Instantiate the underlying SD3-Euler transformer / VAE / text encoders
    sd3 = SD3Euler()
    # move all text encoders + transformer + VAE to CUDA:
    sd3.text_enc_1.to(device)
    sd3.text_enc_2.to(device)
    sd3.text_enc_3.to(device)
    sd3.transformer.to(device, dtype=torch.float32)
    sd3.vae.to(device, dtype=torch.float32)
    # freeze all weights
    for module in (
        sd3.text_enc_1,
        sd3.text_enc_2,
        sd3.text_enc_3,
        sd3.transformer,
        sd3.vae,
    ):
        module.requires_grad_(False)

    # 3.2 Wrap in the OSEDiff_SD3_TEST helper:
    model_test = OSEDiff_SD3_TEST(args, sd3)
    # (model_test(lq_tensor, prompt=str) returns a list of output tensors)

    ###############################
    # 4. Load the VLM (Qwen2.5-VL)
    ###############################
    vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        VLM_NAME,
        torch_dtype="auto",
        device_map="auto",  # immediately dispatches layers onto available GPUs
    )
    vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)

    ###############################
    # 5. Pre-allocate a temporary scratch directory.
    #    (Intermediate images are kept in memory, so nothing is actually
    #    written here; the directory only exists for the lifetime of the call.)
    ###############################
    unique_id = uuid.uuid4().hex
    prefix = f"recms_{unique_id}_"
    with tempfile.TemporaryDirectory(prefix=prefix):
        ###############################
        # 6. Prepare the very first "full" image
        ###############################
        # (6.1) Load + center-crop → first image (512×512)
        img0 = Image.open(input_png_path).convert("RGB")
        img0 = resize_and_center_crop(img0, process_size)

        # Note: we no longer need to write "prev.png" to disk; just keep it in memory.
        prev_pil = img0.copy()

        sr_pil_list: list[Image.Image] = []
        prompt_list: list[str] = []

        for rec in range(rec_num):
            # (A) Compute the low-res crop window on prev_pil
            w, h = prev_pil.size  # (512×512)
            new_w, new_h = w // upscale, h // upscale

            cx_norm, cy_norm = centers[rec]
            cx = int(cx_norm * w)
            cy = int(cy_norm * h)

            half_w, half_h = new_w // 2, new_h // 2
            left = max(0, min(cx - half_w, w - new_w))
            top = max(0, min(cy - half_h, h - new_h))
            right, bottom = left + new_w, top + new_h
            cropped = prev_pil.crop((left, top, right, bottom))

            # (B) Upsample that crop back to (512×512)
            zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

            # (C) Generate the VLM prompt by passing PILs directly:
            prompt_tag = _generate_vlm_prompt(
                vlm_model=vlm_model,
                vlm_processor=vlm_processor,
                process_vision_info=process_vision_info,
                prev_pil=prev_pil,      # <- PIL
                zoomed_pil=zoomed_pil,  # <- PIL
                device=device,
            )

            # (D) Prepare zoomed_pil → tensor in [-1, 1]
            to_tensor = transforms.ToTensor()
            lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)  # (1, 3, 512, 512)
            lq = (lq * 2.0) - 1.0

            # (E) Run SR inference
            with torch.no_grad():
                out_tensor = model_test(lq, prompt=prompt_tag)[0]
                out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
                out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

            # (F) Bookkeeping: set prev_pil = out_pil for the next iteration
            prev_pil = out_pil

            # (G) Append to results
            sr_pil_list.append(out_pil)
            prompt_list.append(prompt_tag)

    return sr_pil_list, prompt_list
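

# -------------------------------------------------------------------
# Example usage (illustrative sketch).
# The input path and the zoom centers below are assumptions for the sake
# of the example; adjust them to your own data. The checkpoint paths are
# the hard-coded ones expected by recursive_multiscale_sr above.
# -------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical input image; replace with a real .png on disk.
    example_input = "inputs/example.png"

    # Zoom slightly off-center for the first two steps, then re-center.
    example_centers = [(0.6, 0.4), (0.55, 0.45), (0.5, 0.5), (0.5, 0.5)]

    sr_images, prompts = recursive_multiscale_sr(
        input_png_path=example_input,
        upscale=4,
        rec_num=4,
        centers=example_centers,
    )

    # Save each recursion's output and print the VLM prompt that guided it.
    for i, (img, prompt) in enumerate(zip(sr_images, prompts), start=1):
        img.save(f"sr_step_{i}.png")
        print(f"[step {i}] prompt: {prompt}")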