|
from typing import List, Union |
|
from PIL import Image |
|
import torch |
|
import numpy as np |
|
|
|
from diffusers.modular_pipelines import ( |
|
PipelineState, |
|
ModularPipelineBlocks, |
|
InputParam, |
|
ComponentSpec, |
|
OutputParam, |
|
) |
|
from controlnet_aux import CannyDetector |
|
import numpy as np |
|
|
|
|
|
class CannyBlock(ModularPipelineBlocks):
    """Pipeline block that computes a Canny edge map from an input image.

    Reads ``image`` plus the Canny/resizing parameters from the pipeline
    state, runs controlnet_aux's ``CannyDetector`` on it, and writes the
    result back to the state as ``control_image``.
    """

    @property
    def expected_components(self):
        # The detector is constructed locally in ``compute_canny``; no
        # pipeline-managed components are required.
        return []

    @property
    def description(self) -> str:
        return "Computes a Canny edge map from `image` and stores it as `control_image`."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[Image.Image, np.ndarray],
                required=True,
                description="Image to compute canny filter on",
            ),
            InputParam(
                "low_threshold",
                type_hint=int,
                default=50,
                description="Lower threshold passed to the Canny filter.",
            ),
            InputParam(
                "high_threshold",
                type_hint=int,
                default=200,
                description="Upper threshold passed to the Canny filter.",
            ),
            InputParam(
                "detect_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize to when running the Canny filtering process.",
            ),
            InputParam(
                "image_resolution",
                type_hint=int,
                default=1024,
                description="Resolution to resize the detected Canny edge map to.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "control_image",
                # Fixed: the original passed the PIL *module* (`Image`) here,
                # which is not a type. The edge map comes from CannyDetector;
                # NOTE(review): controlnet_aux appears to return an ndarray for
                # ndarray input and a PIL image otherwise — confirm for the
                # installed version.
                type_hint=Union[Image.Image, np.ndarray],
                description="Canny map for input image",
            )
        ]

    def compute_canny(
        self,
        image: Union[Image.Image, np.ndarray],
        low_threshold: int,
        high_threshold: int,
        detect_resolution: int,
        image_resolution: int,
    ):
        """Run controlnet_aux's ``CannyDetector`` on ``image``.

        Args:
            image: Source image to filter.
            low_threshold: Lower threshold forwarded to the detector.
            high_threshold: Upper threshold forwarded to the detector.
            detect_resolution: Resolution the image is resized to before filtering.
            image_resolution: Resolution the resulting edge map is resized to.

        Returns:
            Whatever ``CannyDetector.__call__`` returns for these arguments
            (the edge map).
        """
        detector = CannyDetector()
        return detector(
            input_image=image,
            low_threshold=low_threshold,
            high_threshold=high_threshold,
            detect_resolution=detect_resolution,
            image_resolution=image_resolution,
        )

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        """Compute ``control_image`` from the block's inputs and store it in ``state``.

        Args:
            components: Pipeline components; passed through unchanged.
            state: Pipeline state holding this block's inputs.

        Returns:
            The tuple ``(components, state)``, with ``state`` updated to hold
            ``control_image``. (The ``-> PipelineState`` annotation is kept
            as-is; the actual return value is this tuple.)
        """
        block_state = self.get_block_state(state)

        block_state.control_image = self.compute_canny(
            image=block_state.image,
            low_threshold=block_state.low_threshold,
            high_threshold=block_state.high_threshold,
            detect_resolution=block_state.detect_resolution,
            image_resolution=block_state.image_resolution,
        )

        self.set_block_state(state, block_state)
        return components, state
|
|