File size: 2,454 Bytes
62cc7ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# import numpy as np
import PIL.Image
import torch
import gc
# from controlnet_aux_local import NormalBaeDetector#, CannyDetector
from controlnet_aux import NormalBaeDetector

# from controlnet_aux.util import HWC3
# import cv2
# from cv_utils import resize_image

class Preprocessor:
    """Lazily load a ControlNet annotator and apply it to images.

    Only the ``"NormalBae"`` annotator is currently supported. The loaded
    model is cached on the instance, so calling :meth:`load` again with the
    name that is already loaded is a no-op.
    """

    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self, device: str = "cuda") -> None:
        """Create a preprocessor with no annotator loaded yet.

        Args:
            device: torch device the annotator is moved to when loaded.
                Defaults to ``"cuda"`` (the previously hard-coded value).
        """
        self.model = None  # currently loaded annotator (a callable), or None
        self.name = ""  # name of the loaded annotator; "" when none is loaded
        self.device = device

    def load(self, name: str) -> None:
        """Load the annotator called *name*, replacing any previously loaded one.

        Args:
            name: annotator name; currently only ``"NormalBae"`` is supported.

        Raises:
            ValueError: if *name* is not a supported annotator. In that case
                the currently loaded model (if any) is left untouched.
        """
        if name == self.name:
            return
        # Validate before touching state so an unknown name cannot
        # discard a model that is already loaded.
        if name != "NormalBae":
            raise ValueError(f"Unknown preprocessor: {name!r}")

        # Drop the old model first so it is not kept alive on the device
        # while the new one is loaded; also keeps (model, name) consistent
        # if loading fails part-way.
        self.model = None
        self.name = ""

        print("Loading NormalBae")
        self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to(self.device)
        self.name = name

        # Collect garbage first so freed tensors go back to torch's caching
        # allocator, then release that cache. empty_cache() does nothing
        # if CUDA was never initialised.
        gc.collect()
        torch.cuda.empty_cache()

    def __call__(self, image: "PIL.Image.Image", **kwargs) -> "PIL.Image.Image":
        """Run the loaded annotator on *image*.

        Args:
            image: input image.
            **kwargs: forwarded unchanged to the annotator.

        Returns:
            Whatever the loaded annotator returns for *image* (annotated as a
            PIL image; the exact type depends on the annotator's options).

        Raises:
            RuntimeError: if no annotator has been loaded via :meth:`load`.
        """
        if self.model is None:
            raise RuntimeError("No preprocessor loaded; call load() first.")
        return self.model(image, **kwargs)