sergeipetrov commited on
Commit
313f8f8
·
verified ·
0 Parent(s):

Duplicate from sergeipetrov/swin2SR-classical-sr-x2-64-IE

Browse files
Files changed (2) hide show
  1. handler.py +40 -0
  2. requirements.txt +1 -0
handler.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Any
2
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
3
+ import torch
4
+ import base64
5
+ import logging
6
+ import numpy as np
7
+ from PIL import Image
8
+ from io import BytesIO
9
+
10
+ logger = logging.getLogger()
11
+ logger.setLevel(logging.DEBUG)
12
+
13
+ # check for GPU
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+
17
+ class EndpointHandler:
18
+ def __init__(self, path=""):
19
+ # load the model
20
+ self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
21
+ self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
22
+ # move model to device
23
+ self.model.to(device)
24
+
25
+ def __call__(self, data: Any):
26
+ image = data["inputs"]
27
+ inputs = self.processor(image, return_tensors="pt").to(device)
28
+ with torch.no_grad():
29
+ outputs = self.model(**inputs)
30
+
31
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
32
+ output = np.moveaxis(output, source=0, destination=-1)
33
+ output = (output * 255.0).round().astype(np.uint8)
34
+
35
+ img = Image.fromarray(output)
36
+ buffered = BytesIO()
37
+ img.save(buffered, format="JPEG")
38
+ img_str = base64.b64encode(buffered.getvalue())
39
+
40
+ return img_str.decode()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ pillow