from typing import Any
from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
import torch
import base64
import logging
import numpy as np
import gc
from PIL import Image
from io import BytesIO
import subprocess


logger = logging.getLogger()
logger.setLevel(logging.DEBUG)


class EndpointHandler:
    def __init__(self, path=""):
        # load the model
        self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
        Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
        model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map="auto")
        logger.info(model.hf_device_map)
        model.hf_device_map["swin2sr.conv_after_body"] = model.hf_device_map["swin2sr.embeddings"]
        model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
        self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map)
        print(subprocess.run(["nvidia-smi"]))

    def __call__(self, data: Any):
        # Expect a PIL-compatible image under "inputs".
        image = data["inputs"]
        inputs = self.processor(image, return_tensors="pt")

        try:
            # Run super-resolution without tracking gradients.
            with torch.no_grad():
                outputs = self.model(**inputs)

            # Log GPU memory usage after inference.
            print(subprocess.run(["nvidia-smi"]))

            # Convert the CHW float tensor in [0, 1] to an HWC uint8 array.
            output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
            output = np.moveaxis(output, source=0, destination=-1)
            output = (output * 255.0).round().astype(np.uint8)

            # Return the upscaled image as a base64-encoded JPEG string.
            img = Image.fromarray(output)
            buffered = BytesIO()
            img.save(buffered, format="JPEG")
            img_str = base64.b64encode(buffered.getvalue())

            return img_str.decode()

        except Exception as e:
            # Free intermediate tensors and cached GPU memory (e.g. after an OOM),
            # then report the error to the caller.
            logger.error(str(e))
            del inputs
            gc.collect()
            torch.cuda.empty_cache()

            return {"error": str(e)}