File size: 2,535 Bytes
6d3538a 66cadc5 6d3538a 17ad199 6d3538a 66cadc5 6d3538a 17ad199 6d3538a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from typing import List, Union
from PIL import Image
import torch
import numpy as np
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
ComponentSpec,
OutputParam,
)
from controlnet_aux import CannyDetector
import numpy as np
class CannyBlock(ModularPipelineBlocks):
    """Modular-pipeline block that computes a Canny edge map for an image.

    Reads ``image`` and the optional threshold/resolution settings from the
    pipeline state. It runs ``controlnet_aux.CannyDetector`` on them and
    writes the resulting edge map back to the state as ``control_image``.
    """

    @property
    def description(self) -> str:
        return "Computes a Canny edge map of `image` and stores it as `control_image`."

    @property
    def expected_components(self):
        # Canny filtering needs no pipeline components (models, schedulers, ...).
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, np.ndarray],
                required=True,
                description="Image to compute canny filter on",
            ),
            InputParam(
                "low_threshold",
                type_hint=int,
                default=50,
                description="Lower threshold passed to the Canny edge detector.",
            ),
            InputParam(
                "high_threshold",
                type_hint=int,
                default=200,
                description="Upper threshold passed to the Canny edge detector.",
            ),
            InputParam(
                "detect_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize to when running the Canny filtering process.",
            ),
            InputParam(
                "image_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize the detected Canny edge map to.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "control_image",
                # Bare `Image` is the PIL *module*, not a type. CannyDetector
                # returns a PIL image for PIL input and an ndarray for ndarray
                # input, so the hint mirrors the accepted input types.
                type_hint=Union[Image.Image, np.ndarray],
                description="Canny map for input image",
            )
        ]

    def compute_canny(self, image, low_threshold, high_threshold, detect_resolution, image_resolution):
        """Run ``controlnet_aux.CannyDetector`` on ``image`` and return the edge map.

        Args:
            image: PIL image or numpy array to filter.
            low_threshold: Lower Canny threshold.
            high_threshold: Upper Canny threshold.
            detect_resolution: Resolution the image is resized to before filtering.
            image_resolution: Resolution the edge map is resized to afterwards.

        Returns:
            The edge map produced by ``CannyDetector``. Its container type
            follows the input (PIL in -> PIL out, ndarray in -> ndarray out).
        """
        # CannyDetector holds no model weights, so building one per call is cheap.
        canny_detector = CannyDetector()
        return canny_detector(
            input_image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=detect_resolution,
            image_resolution=image_resolution,
        )

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Compute ``control_image`` from ``image`` and store it in ``state``.

        Note: despite the annotation, this returns a ``(components, state)``
        tuple. The annotation is kept for interface compatibility.
        """
        block_state = self.get_block_state(state)
        block_state.control_image = self.compute_canny(
            block_state.image,
            block_state.low_threshold,
            block_state.high_threshold,
            block_state.detect_resolution,
            block_state.image_resolution,
        )
        self.set_block_state(state, block_state)
        return components, state
|