sergeipetrov commited on
Commit
1161d79
·
verified ·
1 Parent(s): 313f8f8

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +7 -8
handler.py CHANGED
@@ -1,5 +1,5 @@
1
  from typing import Dict, List, Any
2
- from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution
3
  import torch
4
  import base64
5
  import logging
@@ -10,21 +10,20 @@ from io import BytesIO
10
  logger = logging.getLogger()
11
  logger.setLevel(logging.DEBUG)
12
 
13
- # check for GPU
14
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
-
16
 
17
  class EndpointHandler:
18
  def __init__(self, path=""):
19
  # load the model
20
  self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
21
- self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
22
- # move model to device
23
- self.model.to(device)
 
 
24
 
25
  def __call__(self, data: Any):
26
  image = data["inputs"]
27
- inputs = self.processor(image, return_tensors="pt").to(device)
28
  with torch.no_grad():
29
  outputs = self.model(**inputs)
30
 
 
1
  from typing import Dict, List, Any
2
+ from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, Swin2SRModel
3
  import torch
4
  import base64
5
  import logging
 
10
  logger = logging.getLogger()
11
  logger.setLevel(logging.DEBUG)
12
 
 
 
 
13
 
14
  class EndpointHandler:
15
  def __init__(self, path=""):
16
  # load the model
17
  self.processor = AutoImageProcessor.from_pretrained("caidas/swin2SR-classical-sr-x2-64")
18
+ Swin2SRModel._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
19
+ Swin2SRForImageSuperResolution._no_split_modules = ["Swin2SREmbeddings", "Swin2SRStage"]
20
+ model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map="auto")
21
+ model.hf_device_map["upsample"] = model.hf_device_map["swin2sr.embeddings"]
22
+ self.model = Swin2SRForImageSuperResolution.from_pretrained("caidas/swin2SR-classical-sr-x2-64", device_map=model.hf_device_map)
23
 
24
  def __call__(self, data: Any):
25
  image = data["inputs"]
26
+ inputs = self.processor(image, return_tensors="pt")
27
  with torch.no_grad():
28
  outputs = self.model(**inputs)
29