import base64
import gc
import logging
from io import BytesIO
from typing import Dict, List, Any

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoImageProcessor,
    Swin2SRForImageSuperResolution,
    Swin2SRModel,
)

logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# Hub checkpoint used for both the image processor and the model weights.
MODEL_ID = "caidas/swin2SR-classical-sr-x2-64"

# Module classes that automatic device placement must keep on a single device.
_NO_SPLIT_MODULES = ["Swin2SREmbeddings", "Swin2SRStage"]

# Entry of the device map whose device the pinned modules are moved to.
_ANCHOR_MODULE = "swin2sr.embeddings"

# Entries of the device map that are forced onto the anchor module's device.
_PINNED_MODULES = ("swin2sr.conv_after_body", "upsample")


class EndpointHandler:
    """Serve Swin2SR image super-resolution requests.

    The handler loads the ``MODEL_ID`` checkpoint once at construction time
    and, per call, returns the reconstructed image as a base64-encoded JPEG.
    """

    def __init__(self, path=""):
        """Load the image processor and the model.

        Args:
            path: Accepted for interface compatibility; it is not used, the
                checkpoint is always loaded from ``MODEL_ID``.
        """
        self.processor = AutoImageProcessor.from_pretrained(MODEL_ID)

        # Tell the automatic placement which modules must not be split.
        Swin2SRModel._no_split_modules = list(_NO_SPLIT_MODULES)
        Swin2SRForImageSuperResolution._no_split_modules = list(_NO_SPLIT_MODULES)

        # First load: let the library propose a placement.
        model = Swin2SRForImageSuperResolution.from_pretrained(
            MODEL_ID, device_map="auto"
        )
        logger.info(model.hf_device_map)

        device_map = self._pin_to_anchor(model.hf_device_map)
        if device_map != model.hf_device_map:
            # Release the probe before reloading so that two copies of the
            # weights are not resident at the same time.
            del model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            model = Swin2SRForImageSuperResolution.from_pretrained(
                MODEL_ID, device_map=device_map
            )
        self.model = model

    @staticmethod
    def _pin_to_anchor(device_map: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of *device_map* with the pinned modules co-located.

        Every name in ``_PINNED_MODULES`` is mapped to the device of
        ``_ANCHOR_MODULE``. When the anchor has no entry of its own (for
        example the placement collapsed to a single ``""`` entry because the
        whole model sits on one device) the map is returned unchanged instead
        of raising ``KeyError``.

        NOTE(review): the pinned modules presumably consume tensors produced
        next to the embeddings and therefore need the same device — confirm
        against the model's forward pass.
        """
        pinned = dict(device_map)
        anchor = pinned.get(_ANCHOR_MODULE)
        # Compare against None explicitly: device index 0 is a valid anchor.
        if anchor is None:
            return pinned
        for module_name in _PINNED_MODULES:
            pinned[module_name] = anchor
        return pinned

    def __call__(self, data: Any) -> str:
        """Upscale the image in ``data["inputs"]``.

        Args:
            data: Mapping whose ``"inputs"`` entry is handed to the image
                processor as-is (presumably a PIL image — verify against the
                caller).

        Returns:
            The reconstructed image, JPEG-encoded and then base64-encoded, as
            a ``str``.

        Raises:
            KeyError: If ``data`` has no ``"inputs"`` entry.
        """
        image = data["inputs"]
        inputs = self.processor(image, return_tensors="pt")

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Drop singleton dimensions and clamp to [0, 1] on the CPU.
        output = (
            outputs.reconstruction.detach()
            .squeeze()
            .float()
            .cpu()
            .clamp(0, 1)
            .numpy()
        )
        # Move the leading axis last, the layout Image.fromarray expects.
        output = np.moveaxis(output, source=0, destination=-1)
        output = (output * 255.0).round().astype(np.uint8)

        img = Image.fromarray(output)
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
        img_str = base64.b64encode(buffered.getvalue())
        return img_str.decode()