from typing import Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
import gc
from PIL import Image
from io import BytesIO
import subprocess
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)
class EndpointHandler:
    def __init__(self, path=""):
        # Load the image processor for the Swin2SR 2x classical SR checkpoint.
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
        # Mark the modules that accelerate must keep on a single device when
        # inferring a device map.
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        # First load: let accelerate infer a device map automatically.
        model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map="auto")
        logger.info(model.hf_device_map)
        # Pin the post-body convolution and the upsampler to the embeddings'
        # device to avoid cross-device tensor operations, then reload the
        # model with the patched map.
        model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
        model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
        self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map)
        # Log GPU memory usage after loading.
        print(subprocess.run(["nvidia-smi"]))
    def __call__(self, data: Any):
        # The inference toolkit deserializes the request; for image content
        # types, data["inputs"] is the image itself.
        image = data["inputs"]
        inputs = self.processor(image, return_tensors="pt")
        try:
            with torch.no_grad():
                outputs = self.model(**inputs)
            print(subprocess.run(["nvidia-smi"]))
            # Convert the reconstruction tensor (CHW, float in [0, 1]) to an
            # HWC uint8 array.
            output = outputs.reconstruction.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)
            # Encode the upscaled image as a base64 JPEG string.
            img = Image.fromarray(output)
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())
            return img_str.decode()
        except Exception as e:
            logger.error(str(e))
            # Release GPU memory before surfacing the error.
            del inputs
            gc.collect()
            torch.cuda.empty_cache()
            return {"error": str(e)}