Spaces:
Running
on
T4
Running
on
T4
DeepBeepMeep
commited on
Commit
·
5747c0d
1
Parent(s):
12652e0
Simplified Vace, added auto open pose and depth extrators
Browse files- README.md +9 -8
- preprocessing/dwpose/__init__.py +2 -0
- preprocessing/dwpose/onnxdet.py +127 -0
- preprocessing/dwpose/onnxpose.py +362 -0
- preprocessing/dwpose/pose.py +183 -0
- preprocessing/dwpose/util.py +299 -0
- preprocessing/dwpose/wholebody.py +80 -0
- preprocessing/gray.py +35 -0
- preprocessing/midas/__init__.py +2 -0
- preprocessing/midas/api.py +166 -0
- preprocessing/midas/base_model.py +18 -0
- preprocessing/midas/blocks.py +391 -0
- preprocessing/midas/depth.py +84 -0
- preprocessing/midas/dpt_depth.py +107 -0
- preprocessing/midas/midas_net.py +80 -0
- preprocessing/midas/midas_net_custom.py +167 -0
- preprocessing/midas/transforms.py +231 -0
- preprocessing/midas/utils.py +193 -0
- preprocessing/midas/vit.py +510 -0
- wan/text2video.py +5 -3
- wan/utils/utils.py +24 -0
- wan/utils/vace_preprocessor.py +39 -34
- wgp.py +204 -61
README.md
CHANGED
@@ -14,7 +14,7 @@
|
|
14 |
|
15 |
|
16 |
## 🔥 Latest News!!
|
17 |
-
* April
|
18 |
- A new queuing system that lets you stack in a queue as many text2video and imag2video tasks as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...).
|
19 |
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge you video by x2 or x4. Check these new advanced options.
|
20 |
- Wan Vace Control Net support : with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
|
@@ -272,23 +272,24 @@ You can define multiple lines of macros. If there is only one macro line, the ap
|
|
272 |
|
273 |
### VACE ControlNet introduction
|
274 |
|
275 |
-
Vace is a ControlNet 1.3B text2video model that allows you
|
276 |
|
277 |
First you need to select the Vace 1.3B model in the Drop Down box at the top. Please note that Vace works well for the moment only with videos up to 5s (81 frames).
|
278 |
|
279 |
Beside the usual Text Prompt, three new types of visual hints can be provided (and combined !):
|
280 |
-
-
|
281 |
|
282 |
-
-
|
283 |
|
284 |
- a Video Mask
|
285 |
-
This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can do as well inpainting / outpainting, fill the missing part of a video more efficientlty with just the video hint.
|
286 |
|
287 |
|
288 |
Examples:
|
289 |
-
- Inject people and / objects into a scene describe by a text
|
290 |
-
- Animate a character described in a text prompt:
|
291 |
-
- Animate a character of your choice : Ref Images +
|
|
|
292 |
|
293 |
|
294 |
There are lots of possible combinations. Some of them require to prepare some materials (masks on top of video, full masks, etc...).
|
|
|
14 |
|
15 |
|
16 |
## 🔥 Latest News!!
|
17 |
+
* April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
|
18 |
- A new queuing system that lets you stack in a queue as many text2video and imag2video tasks as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...).
|
19 |
- Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge you video by x2 or x4. Check these new advanced options.
|
20 |
- Wan Vace Control Net support : with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
|
|
|
272 |
|
273 |
### VACE ControlNet introduction
|
274 |
|
275 |
+
Vace is a ControlNet 1.3B text2video model that allows you to do Video to Video and Reference to Video (inject your own images into the output video). So with Vace you can inject in the scene people or objects of your choice, animate a person, perform inpainting or outpainting, continue a video, ...
|
276 |
|
277 |
First you need to select the Vace 1.3B model in the Drop Down box at the top. Please note that Vace works well for the moment only with videos up to 5s (81 frames).
|
278 |
|
279 |
Beside the usual Text Prompt, three new types of visual hints can be provided (and combined !):
|
280 |
+
- a Control Video: Based on your choice, you can decide to transfer the motion, the depth in a new Video. You can tell WanGP to use only the first n frames of Control Video and to extrapolate the rest. You can also do inpainting ). If the video contains area of grey color 127, they will be considered as masks and will be filled based on the Text prompt of the reference Images.
|
281 |
|
282 |
+
- reference Images: Use this to inject people or objects of your choice in the video. You can select multiple reference Images. The integration of the image is more efficient if the background is replaced by the full white color. You can do that with your preferred background remover or use the built in background remover by checking the box *Remove background*
|
283 |
|
284 |
- a Video Mask
|
285 |
+
This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can do as well inpainting / outpainting, fill the missing part of a video more efficientlty with just the video hint. If a video mask is white, it will be generated so with black frames at the beginning and at the end and the rest white, you could generate the missing frames in between.
|
286 |
|
287 |
|
288 |
Examples:
|
289 |
+
- Inject people and / objects into a scene describe by a text prompt: Ref. Images + text Prompt
|
290 |
+
- Animate a character described in a text prompt: a Video of person moving + text Prompt
|
291 |
+
- Animate a character of your choice (pose transfer) : Ref Images + a Video of person moving + text Prompt
|
292 |
+
- Change the style of a scene (depth transfer): a Video that contains objects / person at differen depths + text Prompt
|
293 |
|
294 |
|
295 |
There are lots of possible combinations. Some of them require to prepare some materials (masks on top of video, full masks, etc...).
|
preprocessing/dwpose/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
preprocessing/dwpose/onnxdet.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import onnxruntime
|
7 |
+
|
8 |
+
def nms(boxes, scores, nms_thr):
|
9 |
+
"""Single class NMS implemented in Numpy."""
|
10 |
+
x1 = boxes[:, 0]
|
11 |
+
y1 = boxes[:, 1]
|
12 |
+
x2 = boxes[:, 2]
|
13 |
+
y2 = boxes[:, 3]
|
14 |
+
|
15 |
+
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
16 |
+
order = scores.argsort()[::-1]
|
17 |
+
|
18 |
+
keep = []
|
19 |
+
while order.size > 0:
|
20 |
+
i = order[0]
|
21 |
+
keep.append(i)
|
22 |
+
xx1 = np.maximum(x1[i], x1[order[1:]])
|
23 |
+
yy1 = np.maximum(y1[i], y1[order[1:]])
|
24 |
+
xx2 = np.minimum(x2[i], x2[order[1:]])
|
25 |
+
yy2 = np.minimum(y2[i], y2[order[1:]])
|
26 |
+
|
27 |
+
w = np.maximum(0.0, xx2 - xx1 + 1)
|
28 |
+
h = np.maximum(0.0, yy2 - yy1 + 1)
|
29 |
+
inter = w * h
|
30 |
+
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
31 |
+
|
32 |
+
inds = np.where(ovr <= nms_thr)[0]
|
33 |
+
order = order[inds + 1]
|
34 |
+
|
35 |
+
return keep
|
36 |
+
|
37 |
+
def multiclass_nms(boxes, scores, nms_thr, score_thr):
|
38 |
+
"""Multiclass NMS implemented in Numpy. Class-aware version."""
|
39 |
+
final_dets = []
|
40 |
+
num_classes = scores.shape[1]
|
41 |
+
for cls_ind in range(num_classes):
|
42 |
+
cls_scores = scores[:, cls_ind]
|
43 |
+
valid_score_mask = cls_scores > score_thr
|
44 |
+
if valid_score_mask.sum() == 0:
|
45 |
+
continue
|
46 |
+
else:
|
47 |
+
valid_scores = cls_scores[valid_score_mask]
|
48 |
+
valid_boxes = boxes[valid_score_mask]
|
49 |
+
keep = nms(valid_boxes, valid_scores, nms_thr)
|
50 |
+
if len(keep) > 0:
|
51 |
+
cls_inds = np.ones((len(keep), 1)) * cls_ind
|
52 |
+
dets = np.concatenate(
|
53 |
+
[valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
|
54 |
+
)
|
55 |
+
final_dets.append(dets)
|
56 |
+
if len(final_dets) == 0:
|
57 |
+
return None
|
58 |
+
return np.concatenate(final_dets, 0)
|
59 |
+
|
60 |
+
def demo_postprocess(outputs, img_size, p6=False):
|
61 |
+
grids = []
|
62 |
+
expanded_strides = []
|
63 |
+
strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
|
64 |
+
|
65 |
+
hsizes = [img_size[0] // stride for stride in strides]
|
66 |
+
wsizes = [img_size[1] // stride for stride in strides]
|
67 |
+
|
68 |
+
for hsize, wsize, stride in zip(hsizes, wsizes, strides):
|
69 |
+
xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
|
70 |
+
grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
|
71 |
+
grids.append(grid)
|
72 |
+
shape = grid.shape[:2]
|
73 |
+
expanded_strides.append(np.full((*shape, 1), stride))
|
74 |
+
|
75 |
+
grids = np.concatenate(grids, 1)
|
76 |
+
expanded_strides = np.concatenate(expanded_strides, 1)
|
77 |
+
outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
|
78 |
+
outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
|
79 |
+
|
80 |
+
return outputs
|
81 |
+
|
82 |
+
def preprocess(img, input_size, swap=(2, 0, 1)):
|
83 |
+
if len(img.shape) == 3:
|
84 |
+
padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
|
85 |
+
else:
|
86 |
+
padded_img = np.ones(input_size, dtype=np.uint8) * 114
|
87 |
+
|
88 |
+
r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
|
89 |
+
resized_img = cv2.resize(
|
90 |
+
img,
|
91 |
+
(int(img.shape[1] * r), int(img.shape[0] * r)),
|
92 |
+
interpolation=cv2.INTER_LINEAR,
|
93 |
+
).astype(np.uint8)
|
94 |
+
padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
|
95 |
+
|
96 |
+
padded_img = padded_img.transpose(swap)
|
97 |
+
padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
|
98 |
+
return padded_img, r
|
99 |
+
|
100 |
+
def inference_detector(session, oriImg):
|
101 |
+
input_shape = (640,640)
|
102 |
+
img, ratio = preprocess(oriImg, input_shape)
|
103 |
+
|
104 |
+
ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
|
105 |
+
output = session.run(None, ort_inputs)
|
106 |
+
predictions = demo_postprocess(output[0], input_shape)[0]
|
107 |
+
|
108 |
+
boxes = predictions[:, :4]
|
109 |
+
scores = predictions[:, 4:5] * predictions[:, 5:]
|
110 |
+
|
111 |
+
boxes_xyxy = np.ones_like(boxes)
|
112 |
+
boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
|
113 |
+
boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
|
114 |
+
boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
|
115 |
+
boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
|
116 |
+
boxes_xyxy /= ratio
|
117 |
+
dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
|
118 |
+
if dets is not None:
|
119 |
+
final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
|
120 |
+
isscore = final_scores>0.3
|
121 |
+
iscat = final_cls_inds == 0
|
122 |
+
isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
|
123 |
+
final_boxes = final_boxes[isbbox]
|
124 |
+
else:
|
125 |
+
final_boxes = np.array([])
|
126 |
+
|
127 |
+
return final_boxes
|
preprocessing/dwpose/onnxpose.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import onnxruntime as ort
|
8 |
+
|
9 |
+
def preprocess(
|
10 |
+
img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
|
11 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
12 |
+
"""Do preprocessing for RTMPose model inference.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
img (np.ndarray): Input image in shape.
|
16 |
+
input_size (tuple): Input image size in shape (w, h).
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
tuple:
|
20 |
+
- resized_img (np.ndarray): Preprocessed image.
|
21 |
+
- center (np.ndarray): Center of image.
|
22 |
+
- scale (np.ndarray): Scale of image.
|
23 |
+
"""
|
24 |
+
# get shape of image
|
25 |
+
img_shape = img.shape[:2]
|
26 |
+
out_img, out_center, out_scale = [], [], []
|
27 |
+
if len(out_bbox) == 0:
|
28 |
+
out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
|
29 |
+
for i in range(len(out_bbox)):
|
30 |
+
x0 = out_bbox[i][0]
|
31 |
+
y0 = out_bbox[i][1]
|
32 |
+
x1 = out_bbox[i][2]
|
33 |
+
y1 = out_bbox[i][3]
|
34 |
+
bbox = np.array([x0, y0, x1, y1])
|
35 |
+
|
36 |
+
# get center and scale
|
37 |
+
center, scale = bbox_xyxy2cs(bbox, padding=1.25)
|
38 |
+
|
39 |
+
# do affine transformation
|
40 |
+
resized_img, scale = top_down_affine(input_size, scale, center, img)
|
41 |
+
|
42 |
+
# normalize image
|
43 |
+
mean = np.array([123.675, 116.28, 103.53])
|
44 |
+
std = np.array([58.395, 57.12, 57.375])
|
45 |
+
resized_img = (resized_img - mean) / std
|
46 |
+
|
47 |
+
out_img.append(resized_img)
|
48 |
+
out_center.append(center)
|
49 |
+
out_scale.append(scale)
|
50 |
+
|
51 |
+
return out_img, out_center, out_scale
|
52 |
+
|
53 |
+
|
54 |
+
def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
|
55 |
+
"""Inference RTMPose model.
|
56 |
+
|
57 |
+
Args:
|
58 |
+
sess (ort.InferenceSession): ONNXRuntime session.
|
59 |
+
img (np.ndarray): Input image in shape.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
outputs (np.ndarray): Output of RTMPose model.
|
63 |
+
"""
|
64 |
+
all_out = []
|
65 |
+
# build input
|
66 |
+
for i in range(len(img)):
|
67 |
+
input = [img[i].transpose(2, 0, 1)]
|
68 |
+
|
69 |
+
# build output
|
70 |
+
sess_input = {sess.get_inputs()[0].name: input}
|
71 |
+
sess_output = []
|
72 |
+
for out in sess.get_outputs():
|
73 |
+
sess_output.append(out.name)
|
74 |
+
|
75 |
+
# run model
|
76 |
+
outputs = sess.run(sess_output, sess_input)
|
77 |
+
all_out.append(outputs)
|
78 |
+
|
79 |
+
return all_out
|
80 |
+
|
81 |
+
|
82 |
+
def postprocess(outputs: List[np.ndarray],
|
83 |
+
model_input_size: Tuple[int, int],
|
84 |
+
center: Tuple[int, int],
|
85 |
+
scale: Tuple[int, int],
|
86 |
+
simcc_split_ratio: float = 2.0
|
87 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
88 |
+
"""Postprocess for RTMPose model output.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
outputs (np.ndarray): Output of RTMPose model.
|
92 |
+
model_input_size (tuple): RTMPose model Input image size.
|
93 |
+
center (tuple): Center of bbox in shape (x, y).
|
94 |
+
scale (tuple): Scale of bbox in shape (w, h).
|
95 |
+
simcc_split_ratio (float): Split ratio of simcc.
|
96 |
+
|
97 |
+
Returns:
|
98 |
+
tuple:
|
99 |
+
- keypoints (np.ndarray): Rescaled keypoints.
|
100 |
+
- scores (np.ndarray): Model predict scores.
|
101 |
+
"""
|
102 |
+
all_key = []
|
103 |
+
all_score = []
|
104 |
+
for i in range(len(outputs)):
|
105 |
+
# use simcc to decode
|
106 |
+
simcc_x, simcc_y = outputs[i]
|
107 |
+
keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
|
108 |
+
|
109 |
+
# rescale keypoints
|
110 |
+
keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
|
111 |
+
all_key.append(keypoints[0])
|
112 |
+
all_score.append(scores[0])
|
113 |
+
|
114 |
+
return np.array(all_key), np.array(all_score)
|
115 |
+
|
116 |
+
|
117 |
+
def bbox_xyxy2cs(bbox: np.ndarray,
|
118 |
+
padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
|
119 |
+
"""Transform the bbox format from (x,y,w,h) into (center, scale)
|
120 |
+
|
121 |
+
Args:
|
122 |
+
bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
|
123 |
+
as (left, top, right, bottom)
|
124 |
+
padding (float): BBox padding factor that will be multilied to scale.
|
125 |
+
Default: 1.0
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
tuple: A tuple containing center and scale.
|
129 |
+
- np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
|
130 |
+
(n, 2)
|
131 |
+
- np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
|
132 |
+
(n, 2)
|
133 |
+
"""
|
134 |
+
# convert single bbox from (4, ) to (1, 4)
|
135 |
+
dim = bbox.ndim
|
136 |
+
if dim == 1:
|
137 |
+
bbox = bbox[None, :]
|
138 |
+
|
139 |
+
# get bbox center and scale
|
140 |
+
x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
|
141 |
+
center = np.hstack([x1 + x2, y1 + y2]) * 0.5
|
142 |
+
scale = np.hstack([x2 - x1, y2 - y1]) * padding
|
143 |
+
|
144 |
+
if dim == 1:
|
145 |
+
center = center[0]
|
146 |
+
scale = scale[0]
|
147 |
+
|
148 |
+
return center, scale
|
149 |
+
|
150 |
+
|
151 |
+
def _fix_aspect_ratio(bbox_scale: np.ndarray,
|
152 |
+
aspect_ratio: float) -> np.ndarray:
|
153 |
+
"""Extend the scale to match the given aspect ratio.
|
154 |
+
|
155 |
+
Args:
|
156 |
+
scale (np.ndarray): The image scale (w, h) in shape (2, )
|
157 |
+
aspect_ratio (float): The ratio of ``w/h``
|
158 |
+
|
159 |
+
Returns:
|
160 |
+
np.ndarray: The reshaped image scale in (2, )
|
161 |
+
"""
|
162 |
+
w, h = np.hsplit(bbox_scale, [1])
|
163 |
+
bbox_scale = np.where(w > h * aspect_ratio,
|
164 |
+
np.hstack([w, w / aspect_ratio]),
|
165 |
+
np.hstack([h * aspect_ratio, h]))
|
166 |
+
return bbox_scale
|
167 |
+
|
168 |
+
|
169 |
+
def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
|
170 |
+
"""Rotate a point by an angle.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
|
174 |
+
angle_rad (float): rotation angle in radian
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
np.ndarray: Rotated point in shape (2, )
|
178 |
+
"""
|
179 |
+
sn, cs = np.sin(angle_rad), np.cos(angle_rad)
|
180 |
+
rot_mat = np.array([[cs, -sn], [sn, cs]])
|
181 |
+
return rot_mat @ pt
|
182 |
+
|
183 |
+
|
184 |
+
def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
|
185 |
+
"""To calculate the affine matrix, three pairs of points are required. This
|
186 |
+
function is used to get the 3rd point, given 2D points a & b.
|
187 |
+
|
188 |
+
The 3rd point is defined by rotating vector `a - b` by 90 degrees
|
189 |
+
anticlockwise, using b as the rotation center.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
a (np.ndarray): The 1st point (x,y) in shape (2, )
|
193 |
+
b (np.ndarray): The 2nd point (x,y) in shape (2, )
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
np.ndarray: The 3rd point.
|
197 |
+
"""
|
198 |
+
direction = a - b
|
199 |
+
c = b + np.r_[-direction[1], direction[0]]
|
200 |
+
return c
|
201 |
+
|
202 |
+
|
203 |
+
def get_warp_matrix(center: np.ndarray,
|
204 |
+
scale: np.ndarray,
|
205 |
+
rot: float,
|
206 |
+
output_size: Tuple[int, int],
|
207 |
+
shift: Tuple[float, float] = (0., 0.),
|
208 |
+
inv: bool = False) -> np.ndarray:
|
209 |
+
"""Calculate the affine transformation matrix that can warp the bbox area
|
210 |
+
in the input image to the output size.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
center (np.ndarray[2, ]): Center of the bounding box (x, y).
|
214 |
+
scale (np.ndarray[2, ]): Scale of the bounding box
|
215 |
+
wrt [width, height].
|
216 |
+
rot (float): Rotation angle (degree).
|
217 |
+
output_size (np.ndarray[2, ] | list(2,)): Size of the
|
218 |
+
destination heatmaps.
|
219 |
+
shift (0-100%): Shift translation ratio wrt the width/height.
|
220 |
+
Default (0., 0.).
|
221 |
+
inv (bool): Option to inverse the affine transform direction.
|
222 |
+
(inv=False: src->dst or inv=True: dst->src)
|
223 |
+
|
224 |
+
Returns:
|
225 |
+
np.ndarray: A 2x3 transformation matrix
|
226 |
+
"""
|
227 |
+
shift = np.array(shift)
|
228 |
+
src_w = scale[0]
|
229 |
+
dst_w = output_size[0]
|
230 |
+
dst_h = output_size[1]
|
231 |
+
|
232 |
+
# compute transformation matrix
|
233 |
+
rot_rad = np.deg2rad(rot)
|
234 |
+
src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
|
235 |
+
dst_dir = np.array([0., dst_w * -0.5])
|
236 |
+
|
237 |
+
# get four corners of the src rectangle in the original image
|
238 |
+
src = np.zeros((3, 2), dtype=np.float32)
|
239 |
+
src[0, :] = center + scale * shift
|
240 |
+
src[1, :] = center + src_dir + scale * shift
|
241 |
+
src[2, :] = _get_3rd_point(src[0, :], src[1, :])
|
242 |
+
|
243 |
+
# get four corners of the dst rectangle in the input image
|
244 |
+
dst = np.zeros((3, 2), dtype=np.float32)
|
245 |
+
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
|
246 |
+
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
|
247 |
+
dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
|
248 |
+
|
249 |
+
if inv:
|
250 |
+
warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
|
251 |
+
else:
|
252 |
+
warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
|
253 |
+
|
254 |
+
return warp_mat
|
255 |
+
|
256 |
+
|
257 |
+
def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
|
258 |
+
img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
259 |
+
"""Get the bbox image as the model input by affine transform.
|
260 |
+
|
261 |
+
Args:
|
262 |
+
input_size (dict): The input size of the model.
|
263 |
+
bbox_scale (dict): The bbox scale of the img.
|
264 |
+
bbox_center (dict): The bbox center of the img.
|
265 |
+
img (np.ndarray): The original image.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
tuple: A tuple containing center and scale.
|
269 |
+
- np.ndarray[float32]: img after affine transform.
|
270 |
+
- np.ndarray[float32]: bbox scale after affine transform.
|
271 |
+
"""
|
272 |
+
w, h = input_size
|
273 |
+
warp_size = (int(w), int(h))
|
274 |
+
|
275 |
+
# reshape bbox to fixed aspect ratio
|
276 |
+
bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
|
277 |
+
|
278 |
+
# get the affine matrix
|
279 |
+
center = bbox_center
|
280 |
+
scale = bbox_scale
|
281 |
+
rot = 0
|
282 |
+
warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
|
283 |
+
|
284 |
+
# do affine transform
|
285 |
+
img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
|
286 |
+
|
287 |
+
return img, bbox_scale
|
288 |
+
|
289 |
+
|
290 |
+
def get_simcc_maximum(simcc_x: np.ndarray,
|
291 |
+
simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
292 |
+
"""Get maximum response location and value from simcc representations.
|
293 |
+
|
294 |
+
Note:
|
295 |
+
instance number: N
|
296 |
+
num_keypoints: K
|
297 |
+
heatmap height: H
|
298 |
+
heatmap width: W
|
299 |
+
|
300 |
+
Args:
|
301 |
+
simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
|
302 |
+
simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
|
303 |
+
|
304 |
+
Returns:
|
305 |
+
tuple:
|
306 |
+
- locs (np.ndarray): locations of maximum heatmap responses in shape
|
307 |
+
(K, 2) or (N, K, 2)
|
308 |
+
- vals (np.ndarray): values of maximum heatmap responses in shape
|
309 |
+
(K,) or (N, K)
|
310 |
+
"""
|
311 |
+
N, K, Wx = simcc_x.shape
|
312 |
+
simcc_x = simcc_x.reshape(N * K, -1)
|
313 |
+
simcc_y = simcc_y.reshape(N * K, -1)
|
314 |
+
|
315 |
+
# get maximum value locations
|
316 |
+
x_locs = np.argmax(simcc_x, axis=1)
|
317 |
+
y_locs = np.argmax(simcc_y, axis=1)
|
318 |
+
locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
|
319 |
+
max_val_x = np.amax(simcc_x, axis=1)
|
320 |
+
max_val_y = np.amax(simcc_y, axis=1)
|
321 |
+
|
322 |
+
# get maximum value across x and y axis
|
323 |
+
mask = max_val_x > max_val_y
|
324 |
+
max_val_x[mask] = max_val_y[mask]
|
325 |
+
vals = max_val_x
|
326 |
+
locs[vals <= 0.] = -1
|
327 |
+
|
328 |
+
# reshape
|
329 |
+
locs = locs.reshape(N, K, 2)
|
330 |
+
vals = vals.reshape(N, K)
|
331 |
+
|
332 |
+
return locs, vals
|
333 |
+
|
334 |
+
|
335 |
+
def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
|
336 |
+
simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
|
337 |
+
"""Modulate simcc distribution with Gaussian.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
|
341 |
+
simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
|
342 |
+
simcc_split_ratio (int): The split ratio of simcc.
|
343 |
+
|
344 |
+
Returns:
|
345 |
+
tuple: A tuple containing center and scale.
|
346 |
+
- np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
|
347 |
+
- np.ndarray[float32]: scores in shape (K,) or (n, K)
|
348 |
+
"""
|
349 |
+
keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
|
350 |
+
keypoints /= simcc_split_ratio
|
351 |
+
|
352 |
+
return keypoints, scores
|
353 |
+
|
354 |
+
|
355 |
+
def inference_pose(session, out_bbox, oriImg):
|
356 |
+
h, w = session.get_inputs()[0].shape[2:]
|
357 |
+
model_input_size = (w, h)
|
358 |
+
resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
|
359 |
+
outputs = inference(session, resized_img)
|
360 |
+
keypoints, scores = postprocess(outputs, model_input_size, center, scale)
|
361 |
+
|
362 |
+
return keypoints, scores
|
preprocessing/dwpose/pose.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import os
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
from . import util
|
10 |
+
from .wholebody import Wholebody, HWC3, resize_image
|
11 |
+
from PIL import Image
|
12 |
+
|
13 |
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
14 |
+
|
15 |
+
def convert_to_numpy(image):
|
16 |
+
if isinstance(image, Image.Image):
|
17 |
+
image = np.array(image)
|
18 |
+
elif isinstance(image, torch.Tensor):
|
19 |
+
image = image.detach().cpu().numpy()
|
20 |
+
elif isinstance(image, np.ndarray):
|
21 |
+
image = image.copy()
|
22 |
+
else:
|
23 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
24 |
+
return image
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
|
29 |
+
bodies = pose['bodies']
|
30 |
+
faces = pose['faces']
|
31 |
+
hands = pose['hands']
|
32 |
+
candidate = bodies['candidate']
|
33 |
+
subset = bodies['subset']
|
34 |
+
canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
|
35 |
+
|
36 |
+
if use_body:
|
37 |
+
canvas = util.draw_bodypose(canvas, candidate, subset)
|
38 |
+
if use_hand:
|
39 |
+
canvas = util.draw_handpose(canvas, hands)
|
40 |
+
if use_face:
|
41 |
+
canvas = util.draw_facepose(canvas, faces)
|
42 |
+
|
43 |
+
return canvas
|
44 |
+
|
45 |
+
|
46 |
+
class PoseAnnotator:
|
47 |
+
def __init__(self, cfg, device=None):
|
48 |
+
onnx_det = cfg['DETECTION_MODEL']
|
49 |
+
onnx_pose = cfg['POSE_MODEL']
|
50 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
51 |
+
self.pose_estimation = Wholebody(onnx_det, onnx_pose, device=self.device)
|
52 |
+
self.resize_size = cfg.get("RESIZE_SIZE", 1024)
|
53 |
+
self.use_body = cfg.get('USE_BODY', True)
|
54 |
+
self.use_face = cfg.get('USE_FACE', True)
|
55 |
+
self.use_hand = cfg.get('USE_HAND', True)
|
56 |
+
|
57 |
+
@torch.no_grad()
|
58 |
+
@torch.inference_mode
|
59 |
+
def forward(self, image):
|
60 |
+
image = convert_to_numpy(image)
|
61 |
+
input_image = HWC3(image[..., ::-1])
|
62 |
+
return self.process(resize_image(input_image, self.resize_size), image.shape[:2])
|
63 |
+
|
64 |
+
def process(self, ori_img, ori_shape):
|
65 |
+
ori_h, ori_w = ori_shape
|
66 |
+
ori_img = ori_img.copy()
|
67 |
+
H, W, C = ori_img.shape
|
68 |
+
with torch.no_grad():
|
69 |
+
candidate, subset, det_result = self.pose_estimation(ori_img)
|
70 |
+
nums, keys, locs = candidate.shape
|
71 |
+
candidate[..., 0] /= float(W)
|
72 |
+
candidate[..., 1] /= float(H)
|
73 |
+
body = candidate[:, :18].copy()
|
74 |
+
body = body.reshape(nums * 18, locs)
|
75 |
+
score = subset[:, :18]
|
76 |
+
for i in range(len(score)):
|
77 |
+
for j in range(len(score[i])):
|
78 |
+
if score[i][j] > 0.3:
|
79 |
+
score[i][j] = int(18 * i + j)
|
80 |
+
else:
|
81 |
+
score[i][j] = -1
|
82 |
+
|
83 |
+
un_visible = subset < 0.3
|
84 |
+
candidate[un_visible] = -1
|
85 |
+
|
86 |
+
foot = candidate[:, 18:24]
|
87 |
+
|
88 |
+
faces = candidate[:, 24:92]
|
89 |
+
|
90 |
+
hands = candidate[:, 92:113]
|
91 |
+
hands = np.vstack([hands, candidate[:, 113:]])
|
92 |
+
|
93 |
+
bodies = dict(candidate=body, subset=score)
|
94 |
+
pose = dict(bodies=bodies, hands=hands, faces=faces)
|
95 |
+
|
96 |
+
ret_data = {}
|
97 |
+
if self.use_body:
|
98 |
+
detected_map_body = draw_pose(pose, H, W, use_body=True)
|
99 |
+
detected_map_body = cv2.resize(detected_map_body[..., ::-1], (ori_w, ori_h),
|
100 |
+
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
101 |
+
ret_data["detected_map_body"] = detected_map_body
|
102 |
+
|
103 |
+
if self.use_face:
|
104 |
+
detected_map_face = draw_pose(pose, H, W, use_face=True)
|
105 |
+
detected_map_face = cv2.resize(detected_map_face[..., ::-1], (ori_w, ori_h),
|
106 |
+
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
107 |
+
ret_data["detected_map_face"] = detected_map_face
|
108 |
+
|
109 |
+
if self.use_body and self.use_face:
|
110 |
+
detected_map_bodyface = draw_pose(pose, H, W, use_body=True, use_face=True)
|
111 |
+
detected_map_bodyface = cv2.resize(detected_map_bodyface[..., ::-1], (ori_w, ori_h),
|
112 |
+
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
113 |
+
ret_data["detected_map_bodyface"] = detected_map_bodyface
|
114 |
+
|
115 |
+
if self.use_hand and self.use_body and self.use_face:
|
116 |
+
detected_map_handbodyface = draw_pose(pose, H, W, use_hand=True, use_body=True, use_face=True)
|
117 |
+
detected_map_handbodyface = cv2.resize(detected_map_handbodyface[..., ::-1], (ori_w, ori_h),
|
118 |
+
interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
|
119 |
+
ret_data["detected_map_handbodyface"] = detected_map_handbodyface
|
120 |
+
|
121 |
+
# convert_size
|
122 |
+
if det_result.shape[0] > 0:
|
123 |
+
w_ratio, h_ratio = ori_w / W, ori_h / H
|
124 |
+
det_result[..., ::2] *= h_ratio
|
125 |
+
det_result[..., 1::2] *= w_ratio
|
126 |
+
det_result = det_result.astype(np.int32)
|
127 |
+
return ret_data, det_result
|
128 |
+
|
129 |
+
|
130 |
+
class PoseBodyFaceAnnotator(PoseAnnotator):
|
131 |
+
def __init__(self, cfg):
|
132 |
+
super().__init__(cfg)
|
133 |
+
self.use_body, self.use_face, self.use_hand = True, True, False
|
134 |
+
@torch.no_grad()
|
135 |
+
@torch.inference_mode
|
136 |
+
def forward(self, image):
|
137 |
+
ret_data, det_result = super().forward(image)
|
138 |
+
return ret_data['detected_map_bodyface']
|
139 |
+
|
140 |
+
|
141 |
+
class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
|
142 |
+
def forward(self, frames):
|
143 |
+
ret_frames = []
|
144 |
+
for frame in frames:
|
145 |
+
anno_frame = super().forward(np.array(frame))
|
146 |
+
ret_frames.append(anno_frame)
|
147 |
+
return ret_frames
|
148 |
+
|
149 |
+
import imageio
|
150 |
+
|
151 |
+
def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
|
152 |
+
try:
|
153 |
+
video_writer = imageio.get_writer(file_path, fps=fps, codec='libx264', quality=quality, macro_block_size=macro_block_size)
|
154 |
+
for frame in videos:
|
155 |
+
video_writer.append_data(frame)
|
156 |
+
video_writer.close()
|
157 |
+
return True
|
158 |
+
except Exception as e:
|
159 |
+
print(f"Video save error: {e}")
|
160 |
+
return False
|
161 |
+
|
162 |
+
def get_frames(video_path):
|
163 |
+
frames = []
|
164 |
+
|
165 |
+
|
166 |
+
# Opens the Video file with CV2
|
167 |
+
cap = cv2.VideoCapture(video_path)
|
168 |
+
|
169 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
170 |
+
print("video fps: " + str(fps))
|
171 |
+
i = 0
|
172 |
+
while cap.isOpened():
|
173 |
+
ret, frame = cap.read()
|
174 |
+
if ret == False:
|
175 |
+
break
|
176 |
+
frames.append(frame)
|
177 |
+
i += 1
|
178 |
+
|
179 |
+
cap.release()
|
180 |
+
cv2.destroyAllWindows()
|
181 |
+
|
182 |
+
return frames, fps
|
183 |
+
|
preprocessing/dwpose/util.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import math
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib
|
6 |
+
import cv2
|
7 |
+
|
8 |
+
|
9 |
+
eps = 0.01
|
10 |
+
|
11 |
+
|
12 |
+
def smart_resize(x, s):
|
13 |
+
Ht, Wt = s
|
14 |
+
if x.ndim == 2:
|
15 |
+
Ho, Wo = x.shape
|
16 |
+
Co = 1
|
17 |
+
else:
|
18 |
+
Ho, Wo, Co = x.shape
|
19 |
+
if Co == 3 or Co == 1:
|
20 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
21 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
22 |
+
else:
|
23 |
+
return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
|
24 |
+
|
25 |
+
|
26 |
+
def smart_resize_k(x, fx, fy):
|
27 |
+
if x.ndim == 2:
|
28 |
+
Ho, Wo = x.shape
|
29 |
+
Co = 1
|
30 |
+
else:
|
31 |
+
Ho, Wo, Co = x.shape
|
32 |
+
Ht, Wt = Ho * fy, Wo * fx
|
33 |
+
if Co == 3 or Co == 1:
|
34 |
+
k = float(Ht + Wt) / float(Ho + Wo)
|
35 |
+
return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
|
36 |
+
else:
|
37 |
+
return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
|
38 |
+
|
39 |
+
|
40 |
+
def padRightDownCorner(img, stride, padValue):
|
41 |
+
h = img.shape[0]
|
42 |
+
w = img.shape[1]
|
43 |
+
|
44 |
+
pad = 4 * [None]
|
45 |
+
pad[0] = 0 # up
|
46 |
+
pad[1] = 0 # left
|
47 |
+
pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
|
48 |
+
pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
|
49 |
+
|
50 |
+
img_padded = img
|
51 |
+
pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
|
52 |
+
img_padded = np.concatenate((pad_up, img_padded), axis=0)
|
53 |
+
pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
|
54 |
+
img_padded = np.concatenate((pad_left, img_padded), axis=1)
|
55 |
+
pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
|
56 |
+
img_padded = np.concatenate((img_padded, pad_down), axis=0)
|
57 |
+
pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
|
58 |
+
img_padded = np.concatenate((img_padded, pad_right), axis=1)
|
59 |
+
|
60 |
+
return img_padded, pad
|
61 |
+
|
62 |
+
|
63 |
+
def transfer(model, model_weights):
|
64 |
+
transfered_model_weights = {}
|
65 |
+
for weights_name in model.state_dict().keys():
|
66 |
+
transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
|
67 |
+
return transfered_model_weights
|
68 |
+
|
69 |
+
|
70 |
+
def draw_bodypose(canvas, candidate, subset):
|
71 |
+
H, W, C = canvas.shape
|
72 |
+
candidate = np.array(candidate)
|
73 |
+
subset = np.array(subset)
|
74 |
+
|
75 |
+
stickwidth = 4
|
76 |
+
|
77 |
+
limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
|
78 |
+
[10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
|
79 |
+
[1, 16], [16, 18], [3, 17], [6, 18]]
|
80 |
+
|
81 |
+
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
82 |
+
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
83 |
+
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
84 |
+
|
85 |
+
for i in range(17):
|
86 |
+
for n in range(len(subset)):
|
87 |
+
index = subset[n][np.array(limbSeq[i]) - 1]
|
88 |
+
if -1 in index:
|
89 |
+
continue
|
90 |
+
Y = candidate[index.astype(int), 0] * float(W)
|
91 |
+
X = candidate[index.astype(int), 1] * float(H)
|
92 |
+
mX = np.mean(X)
|
93 |
+
mY = np.mean(Y)
|
94 |
+
length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
|
95 |
+
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
96 |
+
polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
|
97 |
+
cv2.fillConvexPoly(canvas, polygon, colors[i])
|
98 |
+
|
99 |
+
canvas = (canvas * 0.6).astype(np.uint8)
|
100 |
+
|
101 |
+
for i in range(18):
|
102 |
+
for n in range(len(subset)):
|
103 |
+
index = int(subset[n][i])
|
104 |
+
if index == -1:
|
105 |
+
continue
|
106 |
+
x, y = candidate[index][0:2]
|
107 |
+
x = int(x * W)
|
108 |
+
y = int(y * H)
|
109 |
+
cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
|
110 |
+
|
111 |
+
return canvas
|
112 |
+
|
113 |
+
|
114 |
+
def draw_handpose(canvas, all_hand_peaks):
|
115 |
+
H, W, C = canvas.shape
|
116 |
+
|
117 |
+
edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
|
118 |
+
[10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
|
119 |
+
|
120 |
+
for peaks in all_hand_peaks:
|
121 |
+
peaks = np.array(peaks)
|
122 |
+
|
123 |
+
for ie, e in enumerate(edges):
|
124 |
+
x1, y1 = peaks[e[0]]
|
125 |
+
x2, y2 = peaks[e[1]]
|
126 |
+
x1 = int(x1 * W)
|
127 |
+
y1 = int(y1 * H)
|
128 |
+
x2 = int(x2 * W)
|
129 |
+
y2 = int(y2 * H)
|
130 |
+
if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
|
131 |
+
cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
|
132 |
+
|
133 |
+
for i, keyponit in enumerate(peaks):
|
134 |
+
x, y = keyponit
|
135 |
+
x = int(x * W)
|
136 |
+
y = int(y * H)
|
137 |
+
if x > eps and y > eps:
|
138 |
+
cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
|
139 |
+
return canvas
|
140 |
+
|
141 |
+
|
142 |
+
def draw_facepose(canvas, all_lmks):
|
143 |
+
H, W, C = canvas.shape
|
144 |
+
for lmks in all_lmks:
|
145 |
+
lmks = np.array(lmks)
|
146 |
+
for lmk in lmks:
|
147 |
+
x, y = lmk
|
148 |
+
x = int(x * W)
|
149 |
+
y = int(y * H)
|
150 |
+
if x > eps and y > eps:
|
151 |
+
cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
|
152 |
+
return canvas
|
153 |
+
|
154 |
+
|
155 |
+
# detect hand according to body pose keypoints
|
156 |
+
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
|
157 |
+
def handDetect(candidate, subset, oriImg):
|
158 |
+
# right hand: wrist 4, elbow 3, shoulder 2
|
159 |
+
# left hand: wrist 7, elbow 6, shoulder 5
|
160 |
+
ratioWristElbow = 0.33
|
161 |
+
detect_result = []
|
162 |
+
image_height, image_width = oriImg.shape[0:2]
|
163 |
+
for person in subset.astype(int):
|
164 |
+
# if any of three not detected
|
165 |
+
has_left = np.sum(person[[5, 6, 7]] == -1) == 0
|
166 |
+
has_right = np.sum(person[[2, 3, 4]] == -1) == 0
|
167 |
+
if not (has_left or has_right):
|
168 |
+
continue
|
169 |
+
hands = []
|
170 |
+
#left hand
|
171 |
+
if has_left:
|
172 |
+
left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
|
173 |
+
x1, y1 = candidate[left_shoulder_index][:2]
|
174 |
+
x2, y2 = candidate[left_elbow_index][:2]
|
175 |
+
x3, y3 = candidate[left_wrist_index][:2]
|
176 |
+
hands.append([x1, y1, x2, y2, x3, y3, True])
|
177 |
+
# right hand
|
178 |
+
if has_right:
|
179 |
+
right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
|
180 |
+
x1, y1 = candidate[right_shoulder_index][:2]
|
181 |
+
x2, y2 = candidate[right_elbow_index][:2]
|
182 |
+
x3, y3 = candidate[right_wrist_index][:2]
|
183 |
+
hands.append([x1, y1, x2, y2, x3, y3, False])
|
184 |
+
|
185 |
+
for x1, y1, x2, y2, x3, y3, is_left in hands:
|
186 |
+
# pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
|
187 |
+
# handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
|
188 |
+
# handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
|
189 |
+
# const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
|
190 |
+
# const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
|
191 |
+
# handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
|
192 |
+
x = x3 + ratioWristElbow * (x3 - x2)
|
193 |
+
y = y3 + ratioWristElbow * (y3 - y2)
|
194 |
+
distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
|
195 |
+
distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
|
196 |
+
width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
|
197 |
+
# x-y refers to the center --> offset to topLeft point
|
198 |
+
# handRectangle.x -= handRectangle.width / 2.f;
|
199 |
+
# handRectangle.y -= handRectangle.height / 2.f;
|
200 |
+
x -= width / 2
|
201 |
+
y -= width / 2 # width = height
|
202 |
+
# overflow the image
|
203 |
+
if x < 0: x = 0
|
204 |
+
if y < 0: y = 0
|
205 |
+
width1 = width
|
206 |
+
width2 = width
|
207 |
+
if x + width > image_width: width1 = image_width - x
|
208 |
+
if y + width > image_height: width2 = image_height - y
|
209 |
+
width = min(width1, width2)
|
210 |
+
# the max hand box value is 20 pixels
|
211 |
+
if width >= 20:
|
212 |
+
detect_result.append([int(x), int(y), int(width), is_left])
|
213 |
+
|
214 |
+
'''
|
215 |
+
return value: [[x, y, w, True if left hand else False]].
|
216 |
+
width=height since the network require squared input.
|
217 |
+
x, y is the coordinate of top left
|
218 |
+
'''
|
219 |
+
return detect_result
|
220 |
+
|
221 |
+
|
222 |
+
# Written by Lvmin
|
223 |
+
def faceDetect(candidate, subset, oriImg):
|
224 |
+
# left right eye ear 14 15 16 17
|
225 |
+
detect_result = []
|
226 |
+
image_height, image_width = oriImg.shape[0:2]
|
227 |
+
for person in subset.astype(int):
|
228 |
+
has_head = person[0] > -1
|
229 |
+
if not has_head:
|
230 |
+
continue
|
231 |
+
|
232 |
+
has_left_eye = person[14] > -1
|
233 |
+
has_right_eye = person[15] > -1
|
234 |
+
has_left_ear = person[16] > -1
|
235 |
+
has_right_ear = person[17] > -1
|
236 |
+
|
237 |
+
if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
|
238 |
+
continue
|
239 |
+
|
240 |
+
head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
|
241 |
+
|
242 |
+
width = 0.0
|
243 |
+
x0, y0 = candidate[head][:2]
|
244 |
+
|
245 |
+
if has_left_eye:
|
246 |
+
x1, y1 = candidate[left_eye][:2]
|
247 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
248 |
+
width = max(width, d * 3.0)
|
249 |
+
|
250 |
+
if has_right_eye:
|
251 |
+
x1, y1 = candidate[right_eye][:2]
|
252 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
253 |
+
width = max(width, d * 3.0)
|
254 |
+
|
255 |
+
if has_left_ear:
|
256 |
+
x1, y1 = candidate[left_ear][:2]
|
257 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
258 |
+
width = max(width, d * 1.5)
|
259 |
+
|
260 |
+
if has_right_ear:
|
261 |
+
x1, y1 = candidate[right_ear][:2]
|
262 |
+
d = max(abs(x0 - x1), abs(y0 - y1))
|
263 |
+
width = max(width, d * 1.5)
|
264 |
+
|
265 |
+
x, y = x0, y0
|
266 |
+
|
267 |
+
x -= width
|
268 |
+
y -= width
|
269 |
+
|
270 |
+
if x < 0:
|
271 |
+
x = 0
|
272 |
+
|
273 |
+
if y < 0:
|
274 |
+
y = 0
|
275 |
+
|
276 |
+
width1 = width * 2
|
277 |
+
width2 = width * 2
|
278 |
+
|
279 |
+
if x + width > image_width:
|
280 |
+
width1 = image_width - x
|
281 |
+
|
282 |
+
if y + width > image_height:
|
283 |
+
width2 = image_height - y
|
284 |
+
|
285 |
+
width = min(width1, width2)
|
286 |
+
|
287 |
+
if width >= 20:
|
288 |
+
detect_result.append([int(x), int(y), int(width)])
|
289 |
+
|
290 |
+
return detect_result
|
291 |
+
|
292 |
+
|
293 |
+
# get max index of 2d array
|
294 |
+
def npmax(array):
|
295 |
+
arrayindex = array.argmax(1)
|
296 |
+
arrayvalue = array.max(1)
|
297 |
+
i = arrayvalue.argmax()
|
298 |
+
j = arrayindex[i]
|
299 |
+
return i, j
|
preprocessing/dwpose/wholebody.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import onnxruntime as ort
|
6 |
+
from .onnxdet import inference_detector
|
7 |
+
from .onnxpose import inference_pose
|
8 |
+
|
9 |
+
def HWC3(x):
|
10 |
+
assert x.dtype == np.uint8
|
11 |
+
if x.ndim == 2:
|
12 |
+
x = x[:, :, None]
|
13 |
+
assert x.ndim == 3
|
14 |
+
H, W, C = x.shape
|
15 |
+
assert C == 1 or C == 3 or C == 4
|
16 |
+
if C == 3:
|
17 |
+
return x
|
18 |
+
if C == 1:
|
19 |
+
return np.concatenate([x, x, x], axis=2)
|
20 |
+
if C == 4:
|
21 |
+
color = x[:, :, 0:3].astype(np.float32)
|
22 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
23 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
24 |
+
y = y.clip(0, 255).astype(np.uint8)
|
25 |
+
return y
|
26 |
+
|
27 |
+
|
28 |
+
def resize_image(input_image, resolution):
|
29 |
+
H, W, C = input_image.shape
|
30 |
+
H = float(H)
|
31 |
+
W = float(W)
|
32 |
+
k = float(resolution) / min(H, W)
|
33 |
+
H *= k
|
34 |
+
W *= k
|
35 |
+
H = int(np.round(H / 64.0)) * 64
|
36 |
+
W = int(np.round(W / 64.0)) * 64
|
37 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
38 |
+
return img
|
39 |
+
|
40 |
+
class Wholebody:
|
41 |
+
def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
|
42 |
+
|
43 |
+
providers = ['CPUExecutionProvider'
|
44 |
+
] if device == 'cpu' else ['CUDAExecutionProvider']
|
45 |
+
# onnx_det = 'annotator/ckpts/yolox_l.onnx'
|
46 |
+
# onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
|
47 |
+
|
48 |
+
self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
|
49 |
+
self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
|
50 |
+
|
51 |
+
def __call__(self, ori_img):
|
52 |
+
det_result = inference_detector(self.session_det, ori_img)
|
53 |
+
keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
|
54 |
+
|
55 |
+
keypoints_info = np.concatenate(
|
56 |
+
(keypoints, scores[..., None]), axis=-1)
|
57 |
+
# compute neck joint
|
58 |
+
neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
|
59 |
+
# neck score when visualizing pred
|
60 |
+
neck[:, 2:4] = np.logical_and(
|
61 |
+
keypoints_info[:, 5, 2:4] > 0.3,
|
62 |
+
keypoints_info[:, 6, 2:4] > 0.3).astype(int)
|
63 |
+
new_keypoints_info = np.insert(
|
64 |
+
keypoints_info, 17, neck, axis=1)
|
65 |
+
mmpose_idx = [
|
66 |
+
17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
|
67 |
+
]
|
68 |
+
openpose_idx = [
|
69 |
+
1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
|
70 |
+
]
|
71 |
+
new_keypoints_info[:, openpose_idx] = \
|
72 |
+
new_keypoints_info[:, mmpose_idx]
|
73 |
+
keypoints_info = new_keypoints_info
|
74 |
+
|
75 |
+
keypoints, scores = keypoints_info[
|
76 |
+
..., :2], keypoints_info[..., 2]
|
77 |
+
|
78 |
+
return keypoints, scores, det_result
|
79 |
+
|
80 |
+
|
preprocessing/gray.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
import torch
|
8 |
+
|
9 |
+
def convert_to_numpy(image):
|
10 |
+
if isinstance(image, Image.Image):
|
11 |
+
image = np.array(image)
|
12 |
+
elif isinstance(image, torch.Tensor):
|
13 |
+
image = image.detach().cpu().numpy()
|
14 |
+
elif isinstance(image, np.ndarray):
|
15 |
+
image = image.copy()
|
16 |
+
else:
|
17 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
18 |
+
return image
|
19 |
+
|
20 |
+
class GrayAnnotator:
|
21 |
+
def __init__(self, cfg):
|
22 |
+
pass
|
23 |
+
def forward(self, image):
|
24 |
+
image = convert_to_numpy(image)
|
25 |
+
gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
26 |
+
return gray_map[..., None].repeat(3, axis=2)
|
27 |
+
|
28 |
+
|
29 |
+
class GrayVideoAnnotator(GrayAnnotator):
|
30 |
+
def forward(self, frames):
|
31 |
+
ret_frames = []
|
32 |
+
for frame in frames:
|
33 |
+
anno_frame = super().forward(np.array(frame))
|
34 |
+
ret_frames.append(anno_frame)
|
35 |
+
return ret_frames
|
preprocessing/midas/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
preprocessing/midas/api.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
# based on https://github.com/isl-org/MiDaS
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torchvision.transforms import Compose
|
9 |
+
|
10 |
+
from .dpt_depth import DPTDepthModel
|
11 |
+
from .midas_net import MidasNet
|
12 |
+
from .midas_net_custom import MidasNet_small
|
13 |
+
from .transforms import NormalizeImage, PrepareForNet, Resize
|
14 |
+
|
15 |
+
# ISL_PATHS = {
|
16 |
+
# "dpt_large": "dpt_large-midas-2f21e586.pt",
|
17 |
+
# "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
|
18 |
+
# "midas_v21": "",
|
19 |
+
# "midas_v21_small": "",
|
20 |
+
# }
|
21 |
+
|
22 |
+
# remote_model_path =
|
23 |
+
# "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
|
24 |
+
|
25 |
+
|
26 |
+
def disabled_train(self, mode=True):
|
27 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
28 |
+
does not change anymore."""
|
29 |
+
return self
|
30 |
+
|
31 |
+
|
32 |
+
def load_midas_transform(model_type):
|
33 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
34 |
+
# load transform only
|
35 |
+
if model_type == 'dpt_large': # DPT-Large
|
36 |
+
net_w, net_h = 384, 384
|
37 |
+
resize_mode = 'minimal'
|
38 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
39 |
+
std=[0.5, 0.5, 0.5])
|
40 |
+
|
41 |
+
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
42 |
+
net_w, net_h = 384, 384
|
43 |
+
resize_mode = 'minimal'
|
44 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
45 |
+
std=[0.5, 0.5, 0.5])
|
46 |
+
|
47 |
+
elif model_type == 'midas_v21':
|
48 |
+
net_w, net_h = 384, 384
|
49 |
+
resize_mode = 'upper_bound'
|
50 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
51 |
+
std=[0.229, 0.224, 0.225])
|
52 |
+
|
53 |
+
elif model_type == 'midas_v21_small':
|
54 |
+
net_w, net_h = 256, 256
|
55 |
+
resize_mode = 'upper_bound'
|
56 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
57 |
+
std=[0.229, 0.224, 0.225])
|
58 |
+
|
59 |
+
else:
|
60 |
+
assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
|
61 |
+
|
62 |
+
transform = Compose([
|
63 |
+
Resize(
|
64 |
+
net_w,
|
65 |
+
net_h,
|
66 |
+
resize_target=None,
|
67 |
+
keep_aspect_ratio=True,
|
68 |
+
ensure_multiple_of=32,
|
69 |
+
resize_method=resize_mode,
|
70 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
71 |
+
),
|
72 |
+
normalization,
|
73 |
+
PrepareForNet(),
|
74 |
+
])
|
75 |
+
|
76 |
+
return transform
|
77 |
+
|
78 |
+
|
79 |
+
def load_model(model_type, model_path):
|
80 |
+
# https://github.com/isl-org/MiDaS/blob/master/run.py
|
81 |
+
# load network
|
82 |
+
# model_path = ISL_PATHS[model_type]
|
83 |
+
if model_type == 'dpt_large': # DPT-Large
|
84 |
+
model = DPTDepthModel(
|
85 |
+
path=model_path,
|
86 |
+
backbone='vitl16_384',
|
87 |
+
non_negative=True,
|
88 |
+
)
|
89 |
+
net_w, net_h = 384, 384
|
90 |
+
resize_mode = 'minimal'
|
91 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
92 |
+
std=[0.5, 0.5, 0.5])
|
93 |
+
|
94 |
+
elif model_type == 'dpt_hybrid': # DPT-Hybrid
|
95 |
+
model = DPTDepthModel(
|
96 |
+
path=model_path,
|
97 |
+
backbone='vitb_rn50_384',
|
98 |
+
non_negative=True,
|
99 |
+
)
|
100 |
+
net_w, net_h = 384, 384
|
101 |
+
resize_mode = 'minimal'
|
102 |
+
normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
|
103 |
+
std=[0.5, 0.5, 0.5])
|
104 |
+
|
105 |
+
elif model_type == 'midas_v21':
|
106 |
+
model = MidasNet(model_path, non_negative=True)
|
107 |
+
net_w, net_h = 384, 384
|
108 |
+
resize_mode = 'upper_bound'
|
109 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
110 |
+
std=[0.229, 0.224, 0.225])
|
111 |
+
|
112 |
+
elif model_type == 'midas_v21_small':
|
113 |
+
model = MidasNet_small(model_path,
|
114 |
+
features=64,
|
115 |
+
backbone='efficientnet_lite3',
|
116 |
+
exportable=True,
|
117 |
+
non_negative=True,
|
118 |
+
blocks={'expand': True})
|
119 |
+
net_w, net_h = 256, 256
|
120 |
+
resize_mode = 'upper_bound'
|
121 |
+
normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
|
122 |
+
std=[0.229, 0.224, 0.225])
|
123 |
+
|
124 |
+
else:
|
125 |
+
print(
|
126 |
+
f"model_type '{model_type}' not implemented, use: --model_type large"
|
127 |
+
)
|
128 |
+
assert False
|
129 |
+
|
130 |
+
transform = Compose([
|
131 |
+
Resize(
|
132 |
+
net_w,
|
133 |
+
net_h,
|
134 |
+
resize_target=None,
|
135 |
+
keep_aspect_ratio=True,
|
136 |
+
ensure_multiple_of=32,
|
137 |
+
resize_method=resize_mode,
|
138 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
139 |
+
),
|
140 |
+
normalization,
|
141 |
+
PrepareForNet(),
|
142 |
+
])
|
143 |
+
|
144 |
+
return model.eval(), transform
|
145 |
+
|
146 |
+
|
147 |
+
class MiDaSInference(nn.Module):
|
148 |
+
MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
|
149 |
+
MODEL_TYPES_ISL = [
|
150 |
+
'dpt_large',
|
151 |
+
'dpt_hybrid',
|
152 |
+
'midas_v21',
|
153 |
+
'midas_v21_small',
|
154 |
+
]
|
155 |
+
|
156 |
+
def __init__(self, model_type, model_path):
|
157 |
+
super().__init__()
|
158 |
+
assert (model_type in self.MODEL_TYPES_ISL)
|
159 |
+
model, _ = load_model(model_type, model_path)
|
160 |
+
self.model = model
|
161 |
+
self.model.train = disabled_train
|
162 |
+
|
163 |
+
def forward(self, x):
|
164 |
+
with torch.no_grad():
|
165 |
+
prediction = self.model(x)
|
166 |
+
return prediction
|
preprocessing/midas/base_model.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class BaseModel(torch.nn.Module):
|
7 |
+
def load(self, path):
|
8 |
+
"""Load model from file.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
path (str): file path
|
12 |
+
"""
|
13 |
+
parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
|
14 |
+
|
15 |
+
if 'optimizer' in parameters:
|
16 |
+
parameters = parameters['model']
|
17 |
+
|
18 |
+
self.load_state_dict(parameters)
|
preprocessing/midas/blocks.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
|
7 |
+
_make_pretrained_vitl16_384)
|
8 |
+
|
9 |
+
|
10 |
+
def _make_encoder(
|
11 |
+
backbone,
|
12 |
+
features,
|
13 |
+
use_pretrained,
|
14 |
+
groups=1,
|
15 |
+
expand=False,
|
16 |
+
exportable=True,
|
17 |
+
hooks=None,
|
18 |
+
use_vit_only=False,
|
19 |
+
use_readout='ignore',
|
20 |
+
):
|
21 |
+
if backbone == 'vitl16_384':
|
22 |
+
pretrained = _make_pretrained_vitl16_384(use_pretrained,
|
23 |
+
hooks=hooks,
|
24 |
+
use_readout=use_readout)
|
25 |
+
scratch = _make_scratch(
|
26 |
+
[256, 512, 1024, 1024], features, groups=groups,
|
27 |
+
expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
|
28 |
+
elif backbone == 'vitb_rn50_384':
|
29 |
+
pretrained = _make_pretrained_vitb_rn50_384(
|
30 |
+
use_pretrained,
|
31 |
+
hooks=hooks,
|
32 |
+
use_vit_only=use_vit_only,
|
33 |
+
use_readout=use_readout,
|
34 |
+
)
|
35 |
+
scratch = _make_scratch(
|
36 |
+
[256, 512, 768, 768], features, groups=groups,
|
37 |
+
expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
|
38 |
+
elif backbone == 'vitb16_384':
|
39 |
+
pretrained = _make_pretrained_vitb16_384(use_pretrained,
|
40 |
+
hooks=hooks,
|
41 |
+
use_readout=use_readout)
|
42 |
+
scratch = _make_scratch(
|
43 |
+
[96, 192, 384, 768], features, groups=groups,
|
44 |
+
expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
|
45 |
+
elif backbone == 'resnext101_wsl':
|
46 |
+
pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
|
47 |
+
scratch = _make_scratch([256, 512, 1024, 2048],
|
48 |
+
features,
|
49 |
+
groups=groups,
|
50 |
+
expand=expand) # efficientnet_lite3
|
51 |
+
elif backbone == 'efficientnet_lite3':
|
52 |
+
pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
|
53 |
+
exportable=exportable)
|
54 |
+
scratch = _make_scratch([32, 48, 136, 384],
|
55 |
+
features,
|
56 |
+
groups=groups,
|
57 |
+
expand=expand) # efficientnet_lite3
|
58 |
+
else:
|
59 |
+
print(f"Backbone '{backbone}' not implemented")
|
60 |
+
assert False
|
61 |
+
|
62 |
+
return pretrained, scratch
|
63 |
+
|
64 |
+
|
65 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
66 |
+
scratch = nn.Module()
|
67 |
+
|
68 |
+
out_shape1 = out_shape
|
69 |
+
out_shape2 = out_shape
|
70 |
+
out_shape3 = out_shape
|
71 |
+
out_shape4 = out_shape
|
72 |
+
if expand is True:
|
73 |
+
out_shape1 = out_shape
|
74 |
+
out_shape2 = out_shape * 2
|
75 |
+
out_shape3 = out_shape * 4
|
76 |
+
out_shape4 = out_shape * 8
|
77 |
+
|
78 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0],
|
79 |
+
out_shape1,
|
80 |
+
kernel_size=3,
|
81 |
+
stride=1,
|
82 |
+
padding=1,
|
83 |
+
bias=False,
|
84 |
+
groups=groups)
|
85 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1],
|
86 |
+
out_shape2,
|
87 |
+
kernel_size=3,
|
88 |
+
stride=1,
|
89 |
+
padding=1,
|
90 |
+
bias=False,
|
91 |
+
groups=groups)
|
92 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2],
|
93 |
+
out_shape3,
|
94 |
+
kernel_size=3,
|
95 |
+
stride=1,
|
96 |
+
padding=1,
|
97 |
+
bias=False,
|
98 |
+
groups=groups)
|
99 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3],
|
100 |
+
out_shape4,
|
101 |
+
kernel_size=3,
|
102 |
+
stride=1,
|
103 |
+
padding=1,
|
104 |
+
bias=False,
|
105 |
+
groups=groups)
|
106 |
+
|
107 |
+
return scratch
|
108 |
+
|
109 |
+
|
110 |
+
def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
|
111 |
+
efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
|
112 |
+
'tf_efficientnet_lite3',
|
113 |
+
pretrained=use_pretrained,
|
114 |
+
exportable=exportable)
|
115 |
+
return _make_efficientnet_backbone(efficientnet)
|
116 |
+
|
117 |
+
|
118 |
+
def _make_efficientnet_backbone(effnet):
|
119 |
+
pretrained = nn.Module()
|
120 |
+
|
121 |
+
pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
|
122 |
+
effnet.act1, *effnet.blocks[0:2])
|
123 |
+
pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
|
124 |
+
pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
|
125 |
+
pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
|
126 |
+
|
127 |
+
return pretrained
|
128 |
+
|
129 |
+
|
130 |
+
def _make_resnet_backbone(resnet):
|
131 |
+
pretrained = nn.Module()
|
132 |
+
pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
|
133 |
+
resnet.maxpool, resnet.layer1)
|
134 |
+
|
135 |
+
pretrained.layer2 = resnet.layer2
|
136 |
+
pretrained.layer3 = resnet.layer3
|
137 |
+
pretrained.layer4 = resnet.layer4
|
138 |
+
|
139 |
+
return pretrained
|
140 |
+
|
141 |
+
|
142 |
+
def _make_pretrained_resnext101_wsl(use_pretrained):
|
143 |
+
resnet = torch.hub.load('facebookresearch/WSL-Images',
|
144 |
+
'resnext101_32x8d_wsl')
|
145 |
+
return _make_resnet_backbone(resnet)
|
146 |
+
|
147 |
+
|
148 |
+
class Interpolate(nn.Module):
|
149 |
+
"""Interpolation module.
|
150 |
+
"""
|
151 |
+
def __init__(self, scale_factor, mode, align_corners=False):
|
152 |
+
"""Init.
|
153 |
+
|
154 |
+
Args:
|
155 |
+
scale_factor (float): scaling
|
156 |
+
mode (str): interpolation mode
|
157 |
+
"""
|
158 |
+
super(Interpolate, self).__init__()
|
159 |
+
|
160 |
+
self.interp = nn.functional.interpolate
|
161 |
+
self.scale_factor = scale_factor
|
162 |
+
self.mode = mode
|
163 |
+
self.align_corners = align_corners
|
164 |
+
|
165 |
+
def forward(self, x):
|
166 |
+
"""Forward pass.
|
167 |
+
|
168 |
+
Args:
|
169 |
+
x (tensor): input
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
tensor: interpolated data
|
173 |
+
"""
|
174 |
+
|
175 |
+
x = self.interp(x,
|
176 |
+
scale_factor=self.scale_factor,
|
177 |
+
mode=self.mode,
|
178 |
+
align_corners=self.align_corners)
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class ResidualConvUnit(nn.Module):
|
184 |
+
"""Residual convolution module.
|
185 |
+
"""
|
186 |
+
def __init__(self, features):
|
187 |
+
"""Init.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
features (int): number of features
|
191 |
+
"""
|
192 |
+
super().__init__()
|
193 |
+
|
194 |
+
self.conv1 = nn.Conv2d(features,
|
195 |
+
features,
|
196 |
+
kernel_size=3,
|
197 |
+
stride=1,
|
198 |
+
padding=1,
|
199 |
+
bias=True)
|
200 |
+
|
201 |
+
self.conv2 = nn.Conv2d(features,
|
202 |
+
features,
|
203 |
+
kernel_size=3,
|
204 |
+
stride=1,
|
205 |
+
padding=1,
|
206 |
+
bias=True)
|
207 |
+
|
208 |
+
self.relu = nn.ReLU(inplace=True)
|
209 |
+
|
210 |
+
def forward(self, x):
|
211 |
+
"""Forward pass.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
x (tensor): input
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
tensor: output
|
218 |
+
"""
|
219 |
+
out = self.relu(x)
|
220 |
+
out = self.conv1(out)
|
221 |
+
out = self.relu(out)
|
222 |
+
out = self.conv2(out)
|
223 |
+
|
224 |
+
return out + x
|
225 |
+
|
226 |
+
|
227 |
+
class FeatureFusionBlock(nn.Module):
|
228 |
+
"""Feature fusion block.
|
229 |
+
"""
|
230 |
+
def __init__(self, features):
|
231 |
+
"""Init.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
features (int): number of features
|
235 |
+
"""
|
236 |
+
super(FeatureFusionBlock, self).__init__()
|
237 |
+
|
238 |
+
self.resConfUnit1 = ResidualConvUnit(features)
|
239 |
+
self.resConfUnit2 = ResidualConvUnit(features)
|
240 |
+
|
241 |
+
def forward(self, *xs):
|
242 |
+
"""Forward pass.
|
243 |
+
|
244 |
+
Returns:
|
245 |
+
tensor: output
|
246 |
+
"""
|
247 |
+
output = xs[0]
|
248 |
+
|
249 |
+
if len(xs) == 2:
|
250 |
+
output += self.resConfUnit1(xs[1])
|
251 |
+
|
252 |
+
output = self.resConfUnit2(output)
|
253 |
+
|
254 |
+
output = nn.functional.interpolate(output,
|
255 |
+
scale_factor=2,
|
256 |
+
mode='bilinear',
|
257 |
+
align_corners=True)
|
258 |
+
|
259 |
+
return output
|
260 |
+
|
261 |
+
|
262 |
+
class ResidualConvUnit_custom(nn.Module):
|
263 |
+
"""Residual convolution module.
|
264 |
+
"""
|
265 |
+
def __init__(self, features, activation, bn):
|
266 |
+
"""Init.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
features (int): number of features
|
270 |
+
"""
|
271 |
+
super().__init__()
|
272 |
+
|
273 |
+
self.bn = bn
|
274 |
+
|
275 |
+
self.groups = 1
|
276 |
+
|
277 |
+
self.conv1 = nn.Conv2d(features,
|
278 |
+
features,
|
279 |
+
kernel_size=3,
|
280 |
+
stride=1,
|
281 |
+
padding=1,
|
282 |
+
bias=True,
|
283 |
+
groups=self.groups)
|
284 |
+
|
285 |
+
self.conv2 = nn.Conv2d(features,
|
286 |
+
features,
|
287 |
+
kernel_size=3,
|
288 |
+
stride=1,
|
289 |
+
padding=1,
|
290 |
+
bias=True,
|
291 |
+
groups=self.groups)
|
292 |
+
|
293 |
+
if self.bn is True:
|
294 |
+
self.bn1 = nn.BatchNorm2d(features)
|
295 |
+
self.bn2 = nn.BatchNorm2d(features)
|
296 |
+
|
297 |
+
self.activation = activation
|
298 |
+
|
299 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
300 |
+
|
301 |
+
def forward(self, x):
|
302 |
+
"""Forward pass.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
x (tensor): input
|
306 |
+
|
307 |
+
Returns:
|
308 |
+
tensor: output
|
309 |
+
"""
|
310 |
+
|
311 |
+
out = self.activation(x)
|
312 |
+
out = self.conv1(out)
|
313 |
+
if self.bn is True:
|
314 |
+
out = self.bn1(out)
|
315 |
+
|
316 |
+
out = self.activation(out)
|
317 |
+
out = self.conv2(out)
|
318 |
+
if self.bn is True:
|
319 |
+
out = self.bn2(out)
|
320 |
+
|
321 |
+
if self.groups > 1:
|
322 |
+
out = self.conv_merge(out)
|
323 |
+
|
324 |
+
return self.skip_add.add(out, x)
|
325 |
+
|
326 |
+
# return out + x
|
327 |
+
|
328 |
+
|
329 |
+
class FeatureFusionBlock_custom(nn.Module):
|
330 |
+
"""Feature fusion block.
|
331 |
+
"""
|
332 |
+
def __init__(self,
|
333 |
+
features,
|
334 |
+
activation,
|
335 |
+
deconv=False,
|
336 |
+
bn=False,
|
337 |
+
expand=False,
|
338 |
+
align_corners=True):
|
339 |
+
"""Init.
|
340 |
+
|
341 |
+
Args:
|
342 |
+
features (int): number of features
|
343 |
+
"""
|
344 |
+
super(FeatureFusionBlock_custom, self).__init__()
|
345 |
+
|
346 |
+
self.deconv = deconv
|
347 |
+
self.align_corners = align_corners
|
348 |
+
|
349 |
+
self.groups = 1
|
350 |
+
|
351 |
+
self.expand = expand
|
352 |
+
out_features = features
|
353 |
+
if self.expand is True:
|
354 |
+
out_features = features // 2
|
355 |
+
|
356 |
+
self.out_conv = nn.Conv2d(features,
|
357 |
+
out_features,
|
358 |
+
kernel_size=1,
|
359 |
+
stride=1,
|
360 |
+
padding=0,
|
361 |
+
bias=True,
|
362 |
+
groups=1)
|
363 |
+
|
364 |
+
self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
|
365 |
+
self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
|
366 |
+
|
367 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
368 |
+
|
369 |
+
def forward(self, *xs):
|
370 |
+
"""Forward pass.
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
tensor: output
|
374 |
+
"""
|
375 |
+
output = xs[0]
|
376 |
+
|
377 |
+
if len(xs) == 2:
|
378 |
+
res = self.resConfUnit1(xs[1])
|
379 |
+
output = self.skip_add.add(output, res)
|
380 |
+
# output += res
|
381 |
+
|
382 |
+
output = self.resConfUnit2(output)
|
383 |
+
|
384 |
+
output = nn.functional.interpolate(output,
|
385 |
+
scale_factor=2,
|
386 |
+
mode='bilinear',
|
387 |
+
align_corners=self.align_corners)
|
388 |
+
|
389 |
+
output = self.out_conv(output)
|
390 |
+
|
391 |
+
return output
|
preprocessing/midas/depth.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from einops import rearrange
|
7 |
+
from PIL import Image
|
8 |
+
import cv2
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
def convert_to_numpy(image):
|
13 |
+
if isinstance(image, Image.Image):
|
14 |
+
image = np.array(image)
|
15 |
+
elif isinstance(image, torch.Tensor):
|
16 |
+
image = image.detach().cpu().numpy()
|
17 |
+
elif isinstance(image, np.ndarray):
|
18 |
+
image = image.copy()
|
19 |
+
else:
|
20 |
+
raise f'Unsurpport datatype{type(image)}, only surpport np.ndarray, torch.Tensor, Pillow Image.'
|
21 |
+
return image
|
22 |
+
|
23 |
+
def resize_image(input_image, resolution):
|
24 |
+
H, W, C = input_image.shape
|
25 |
+
H = float(H)
|
26 |
+
W = float(W)
|
27 |
+
k = float(resolution) / min(H, W)
|
28 |
+
H *= k
|
29 |
+
W *= k
|
30 |
+
H = int(np.round(H / 64.0)) * 64
|
31 |
+
W = int(np.round(W / 64.0)) * 64
|
32 |
+
img = cv2.resize(
|
33 |
+
input_image, (W, H),
|
34 |
+
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
35 |
+
return img, k
|
36 |
+
|
37 |
+
|
38 |
+
def resize_image_ori(h, w, image, k):
|
39 |
+
img = cv2.resize(
|
40 |
+
image, (w, h),
|
41 |
+
interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
42 |
+
return img
|
43 |
+
|
44 |
+
class DepthAnnotator:
|
45 |
+
def __init__(self, cfg, device=None):
|
46 |
+
from .api import MiDaSInference
|
47 |
+
pretrained_model = cfg['PRETRAINED_MODEL']
|
48 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
|
49 |
+
self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
|
50 |
+
self.a = cfg.get('A', np.pi * 2.0)
|
51 |
+
self.bg_th = cfg.get('BG_TH', 0.1)
|
52 |
+
|
53 |
+
@torch.no_grad()
|
54 |
+
@torch.inference_mode()
|
55 |
+
@torch.autocast('cuda', enabled=False)
|
56 |
+
def forward(self, image):
|
57 |
+
image = convert_to_numpy(image)
|
58 |
+
image_depth = image
|
59 |
+
h, w, c = image.shape
|
60 |
+
image_depth, k = resize_image(image_depth,
|
61 |
+
1024 if min(h, w) > 1024 else min(h, w))
|
62 |
+
image_depth = torch.from_numpy(image_depth).float().to(self.device)
|
63 |
+
image_depth = image_depth / 127.5 - 1.0
|
64 |
+
image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
|
65 |
+
depth = self.model(image_depth)[0]
|
66 |
+
|
67 |
+
depth_pt = depth.clone()
|
68 |
+
depth_pt -= torch.min(depth_pt)
|
69 |
+
depth_pt /= torch.max(depth_pt)
|
70 |
+
depth_pt = depth_pt.cpu().numpy()
|
71 |
+
depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
|
72 |
+
depth_image = depth_image[..., None].repeat(3, 2)
|
73 |
+
|
74 |
+
depth_image = resize_image_ori(h, w, depth_image, k)
|
75 |
+
return depth_image
|
76 |
+
|
77 |
+
|
78 |
+
class DepthVideoAnnotator(DepthAnnotator):
|
79 |
+
def forward(self, frames):
|
80 |
+
ret_frames = []
|
81 |
+
for frame in frames:
|
82 |
+
anno_frame = super().forward(np.array(frame))
|
83 |
+
ret_frames.append(anno_frame)
|
84 |
+
return ret_frames
|
preprocessing/midas/dpt_depth.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .base_model import BaseModel
|
7 |
+
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
|
8 |
+
from .vit import forward_vit
|
9 |
+
|
10 |
+
|
11 |
+
def _make_fusion_block(features, use_bn):
|
12 |
+
return FeatureFusionBlock_custom(
|
13 |
+
features,
|
14 |
+
nn.ReLU(False),
|
15 |
+
deconv=False,
|
16 |
+
bn=use_bn,
|
17 |
+
expand=False,
|
18 |
+
align_corners=True,
|
19 |
+
)
|
20 |
+
|
21 |
+
|
22 |
+
class DPT(BaseModel):
|
23 |
+
def __init__(
|
24 |
+
self,
|
25 |
+
head,
|
26 |
+
features=256,
|
27 |
+
backbone='vitb_rn50_384',
|
28 |
+
readout='project',
|
29 |
+
channels_last=False,
|
30 |
+
use_bn=False,
|
31 |
+
):
|
32 |
+
|
33 |
+
super(DPT, self).__init__()
|
34 |
+
|
35 |
+
self.channels_last = channels_last
|
36 |
+
|
37 |
+
hooks = {
|
38 |
+
'vitb_rn50_384': [0, 1, 8, 11],
|
39 |
+
'vitb16_384': [2, 5, 8, 11],
|
40 |
+
'vitl16_384': [5, 11, 17, 23],
|
41 |
+
}
|
42 |
+
|
43 |
+
# Instantiate backbone and reassemble blocks
|
44 |
+
self.pretrained, self.scratch = _make_encoder(
|
45 |
+
backbone,
|
46 |
+
features,
|
47 |
+
False, # Set to true of you want to train from scratch, uses ImageNet weights
|
48 |
+
groups=1,
|
49 |
+
expand=False,
|
50 |
+
exportable=False,
|
51 |
+
hooks=hooks[backbone],
|
52 |
+
use_readout=readout,
|
53 |
+
)
|
54 |
+
|
55 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
56 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
57 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
58 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
59 |
+
|
60 |
+
self.scratch.output_conv = head
|
61 |
+
|
62 |
+
def forward(self, x):
|
63 |
+
if self.channels_last is True:
|
64 |
+
x.contiguous(memory_format=torch.channels_last)
|
65 |
+
|
66 |
+
layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
|
67 |
+
|
68 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
69 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
70 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
71 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
72 |
+
|
73 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
74 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
75 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
76 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
77 |
+
|
78 |
+
out = self.scratch.output_conv(path_1)
|
79 |
+
|
80 |
+
return out
|
81 |
+
|
82 |
+
|
83 |
+
class DPTDepthModel(DPT):
|
84 |
+
def __init__(self, path=None, non_negative=True, **kwargs):
|
85 |
+
features = kwargs['features'] if 'features' in kwargs else 256
|
86 |
+
|
87 |
+
head = nn.Sequential(
|
88 |
+
nn.Conv2d(features,
|
89 |
+
features // 2,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1),
|
93 |
+
Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
|
94 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
95 |
+
nn.ReLU(True),
|
96 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
97 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
98 |
+
nn.Identity(),
|
99 |
+
)
|
100 |
+
|
101 |
+
super().__init__(head, **kwargs)
|
102 |
+
|
103 |
+
if path is not None:
|
104 |
+
self.load(path)
|
105 |
+
|
106 |
+
def forward(self, x):
|
107 |
+
return super().forward(x).squeeze(dim=1)
|
preprocessing/midas/midas_net.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
4 |
+
This file contains code that is adapted from
|
5 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from .base_model import BaseModel
|
11 |
+
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
|
12 |
+
|
13 |
+
|
14 |
+
class MidasNet(BaseModel):
|
15 |
+
"""Network for monocular depth estimation.
|
16 |
+
"""
|
17 |
+
def __init__(self, path=None, features=256, non_negative=True):
|
18 |
+
"""Init.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
path (str, optional): Path to saved model. Defaults to None.
|
22 |
+
features (int, optional): Number of features. Defaults to 256.
|
23 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
24 |
+
"""
|
25 |
+
print('Loading weights: ', path)
|
26 |
+
|
27 |
+
super(MidasNet, self).__init__()
|
28 |
+
|
29 |
+
use_pretrained = False if path is None else True
|
30 |
+
|
31 |
+
self.pretrained, self.scratch = _make_encoder(
|
32 |
+
backbone='resnext101_wsl',
|
33 |
+
features=features,
|
34 |
+
use_pretrained=use_pretrained)
|
35 |
+
|
36 |
+
self.scratch.refinenet4 = FeatureFusionBlock(features)
|
37 |
+
self.scratch.refinenet3 = FeatureFusionBlock(features)
|
38 |
+
self.scratch.refinenet2 = FeatureFusionBlock(features)
|
39 |
+
self.scratch.refinenet1 = FeatureFusionBlock(features)
|
40 |
+
|
41 |
+
self.scratch.output_conv = nn.Sequential(
|
42 |
+
nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
|
43 |
+
Interpolate(scale_factor=2, mode='bilinear'),
|
44 |
+
nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
|
45 |
+
nn.ReLU(True),
|
46 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
47 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
48 |
+
)
|
49 |
+
|
50 |
+
if path:
|
51 |
+
self.load(path)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
"""Forward pass.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
x (tensor): input data (image)
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
tensor: depth
|
61 |
+
"""
|
62 |
+
|
63 |
+
layer_1 = self.pretrained.layer1(x)
|
64 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
65 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
66 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
67 |
+
|
68 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
69 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
70 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
71 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
72 |
+
|
73 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
74 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
75 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
76 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
77 |
+
|
78 |
+
out = self.scratch.output_conv(path_1)
|
79 |
+
|
80 |
+
return torch.squeeze(out, dim=1)
|
preprocessing/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
"""MidashNet: Network for monocular depth estimation trained by mixing several datasets.
|
4 |
+
This file contains code that is adapted from
|
5 |
+
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
|
6 |
+
"""
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
from .base_model import BaseModel
|
11 |
+
from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
|
12 |
+
|
13 |
+
|
14 |
+
class MidasNet_small(BaseModel):
|
15 |
+
"""Network for monocular depth estimation.
|
16 |
+
"""
|
17 |
+
def __init__(self,
|
18 |
+
path=None,
|
19 |
+
features=64,
|
20 |
+
backbone='efficientnet_lite3',
|
21 |
+
non_negative=True,
|
22 |
+
exportable=True,
|
23 |
+
channels_last=False,
|
24 |
+
align_corners=True,
|
25 |
+
blocks={'expand': True}):
|
26 |
+
"""Init.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
path (str, optional): Path to saved model. Defaults to None.
|
30 |
+
features (int, optional): Number of features. Defaults to 256.
|
31 |
+
backbone (str, optional): Backbone network for encoder. Defaults to resnet50
|
32 |
+
"""
|
33 |
+
print('Loading weights: ', path)
|
34 |
+
|
35 |
+
super(MidasNet_small, self).__init__()
|
36 |
+
|
37 |
+
use_pretrained = False if path else True
|
38 |
+
|
39 |
+
self.channels_last = channels_last
|
40 |
+
self.blocks = blocks
|
41 |
+
self.backbone = backbone
|
42 |
+
|
43 |
+
self.groups = 1
|
44 |
+
|
45 |
+
features1 = features
|
46 |
+
features2 = features
|
47 |
+
features3 = features
|
48 |
+
features4 = features
|
49 |
+
self.expand = False
|
50 |
+
if 'expand' in self.blocks and self.blocks['expand'] is True:
|
51 |
+
self.expand = True
|
52 |
+
features1 = features
|
53 |
+
features2 = features * 2
|
54 |
+
features3 = features * 4
|
55 |
+
features4 = features * 8
|
56 |
+
|
57 |
+
self.pretrained, self.scratch = _make_encoder(self.backbone,
|
58 |
+
features,
|
59 |
+
use_pretrained,
|
60 |
+
groups=self.groups,
|
61 |
+
expand=self.expand,
|
62 |
+
exportable=exportable)
|
63 |
+
|
64 |
+
self.scratch.activation = nn.ReLU(False)
|
65 |
+
|
66 |
+
self.scratch.refinenet4 = FeatureFusionBlock_custom(
|
67 |
+
features4,
|
68 |
+
self.scratch.activation,
|
69 |
+
deconv=False,
|
70 |
+
bn=False,
|
71 |
+
expand=self.expand,
|
72 |
+
align_corners=align_corners)
|
73 |
+
self.scratch.refinenet3 = FeatureFusionBlock_custom(
|
74 |
+
features3,
|
75 |
+
self.scratch.activation,
|
76 |
+
deconv=False,
|
77 |
+
bn=False,
|
78 |
+
expand=self.expand,
|
79 |
+
align_corners=align_corners)
|
80 |
+
self.scratch.refinenet2 = FeatureFusionBlock_custom(
|
81 |
+
features2,
|
82 |
+
self.scratch.activation,
|
83 |
+
deconv=False,
|
84 |
+
bn=False,
|
85 |
+
expand=self.expand,
|
86 |
+
align_corners=align_corners)
|
87 |
+
self.scratch.refinenet1 = FeatureFusionBlock_custom(
|
88 |
+
features1,
|
89 |
+
self.scratch.activation,
|
90 |
+
deconv=False,
|
91 |
+
bn=False,
|
92 |
+
align_corners=align_corners)
|
93 |
+
|
94 |
+
self.scratch.output_conv = nn.Sequential(
|
95 |
+
nn.Conv2d(features,
|
96 |
+
features // 2,
|
97 |
+
kernel_size=3,
|
98 |
+
stride=1,
|
99 |
+
padding=1,
|
100 |
+
groups=self.groups),
|
101 |
+
Interpolate(scale_factor=2, mode='bilinear'),
|
102 |
+
nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
|
103 |
+
self.scratch.activation,
|
104 |
+
nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
|
105 |
+
nn.ReLU(True) if non_negative else nn.Identity(),
|
106 |
+
nn.Identity(),
|
107 |
+
)
|
108 |
+
|
109 |
+
if path:
|
110 |
+
self.load(path)
|
111 |
+
|
112 |
+
def forward(self, x):
|
113 |
+
"""Forward pass.
|
114 |
+
|
115 |
+
Args:
|
116 |
+
x (tensor): input data (image)
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
tensor: depth
|
120 |
+
"""
|
121 |
+
if self.channels_last is True:
|
122 |
+
print('self.channels_last = ', self.channels_last)
|
123 |
+
x.contiguous(memory_format=torch.channels_last)
|
124 |
+
|
125 |
+
layer_1 = self.pretrained.layer1(x)
|
126 |
+
layer_2 = self.pretrained.layer2(layer_1)
|
127 |
+
layer_3 = self.pretrained.layer3(layer_2)
|
128 |
+
layer_4 = self.pretrained.layer4(layer_3)
|
129 |
+
|
130 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
131 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
132 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
133 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
134 |
+
|
135 |
+
path_4 = self.scratch.refinenet4(layer_4_rn)
|
136 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
|
137 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
|
138 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
139 |
+
|
140 |
+
out = self.scratch.output_conv(path_1)
|
141 |
+
|
142 |
+
return torch.squeeze(out, dim=1)
|
143 |
+
|
144 |
+
|
145 |
+
def fuse_model(m):
|
146 |
+
prev_previous_type = nn.Identity()
|
147 |
+
prev_previous_name = ''
|
148 |
+
previous_type = nn.Identity()
|
149 |
+
previous_name = ''
|
150 |
+
for name, module in m.named_modules():
|
151 |
+
if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
|
152 |
+
module) == nn.ReLU:
|
153 |
+
# print("FUSED ", prev_previous_name, previous_name, name)
|
154 |
+
torch.quantization.fuse_modules(
|
155 |
+
m, [prev_previous_name, previous_name, name], inplace=True)
|
156 |
+
elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
|
157 |
+
# print("FUSED ", prev_previous_name, previous_name)
|
158 |
+
torch.quantization.fuse_modules(
|
159 |
+
m, [prev_previous_name, previous_name], inplace=True)
|
160 |
+
# elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
|
161 |
+
# print("FUSED ", previous_name, name)
|
162 |
+
# torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
|
163 |
+
|
164 |
+
prev_previous_type = previous_type
|
165 |
+
prev_previous_name = previous_name
|
166 |
+
previous_type = type(module)
|
167 |
+
previous_name = name
|
preprocessing/midas/transforms.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import math
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
|
8 |
+
|
9 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
10 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
sample (dict): sample
|
14 |
+
size (tuple): image size
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
tuple: new size
|
18 |
+
"""
|
19 |
+
shape = list(sample['disparity'].shape)
|
20 |
+
|
21 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
22 |
+
return sample
|
23 |
+
|
24 |
+
scale = [0, 0]
|
25 |
+
scale[0] = size[0] / shape[0]
|
26 |
+
scale[1] = size[1] / shape[1]
|
27 |
+
|
28 |
+
scale = max(scale)
|
29 |
+
|
30 |
+
shape[0] = math.ceil(scale * shape[0])
|
31 |
+
shape[1] = math.ceil(scale * shape[1])
|
32 |
+
|
33 |
+
# resize
|
34 |
+
sample['image'] = cv2.resize(sample['image'],
|
35 |
+
tuple(shape[::-1]),
|
36 |
+
interpolation=image_interpolation_method)
|
37 |
+
|
38 |
+
sample['disparity'] = cv2.resize(sample['disparity'],
|
39 |
+
tuple(shape[::-1]),
|
40 |
+
interpolation=cv2.INTER_NEAREST)
|
41 |
+
sample['mask'] = cv2.resize(
|
42 |
+
sample['mask'].astype(np.float32),
|
43 |
+
tuple(shape[::-1]),
|
44 |
+
interpolation=cv2.INTER_NEAREST,
|
45 |
+
)
|
46 |
+
sample['mask'] = sample['mask'].astype(bool)
|
47 |
+
|
48 |
+
return tuple(shape)
|
49 |
+
|
50 |
+
|
51 |
+
class Resize(object):
|
52 |
+
"""Resize sample to given size (width, height).
|
53 |
+
"""
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
width,
|
57 |
+
height,
|
58 |
+
resize_target=True,
|
59 |
+
keep_aspect_ratio=False,
|
60 |
+
ensure_multiple_of=1,
|
61 |
+
resize_method='lower_bound',
|
62 |
+
image_interpolation_method=cv2.INTER_AREA,
|
63 |
+
):
|
64 |
+
"""Init.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
width (int): desired output width
|
68 |
+
height (int): desired output height
|
69 |
+
resize_target (bool, optional):
|
70 |
+
True: Resize the full sample (image, mask, target).
|
71 |
+
False: Resize image only.
|
72 |
+
Defaults to True.
|
73 |
+
keep_aspect_ratio (bool, optional):
|
74 |
+
True: Keep the aspect ratio of the input sample.
|
75 |
+
Output sample might not have the given width and height, and
|
76 |
+
resize behaviour depends on the parameter 'resize_method'.
|
77 |
+
Defaults to False.
|
78 |
+
ensure_multiple_of (int, optional):
|
79 |
+
Output width and height is constrained to be multiple of this parameter.
|
80 |
+
Defaults to 1.
|
81 |
+
resize_method (str, optional):
|
82 |
+
"lower_bound": Output will be at least as large as the given size.
|
83 |
+
"upper_bound": Output will be at max as large as the given size. "
|
84 |
+
"(Output size might be smaller than given size.)"
|
85 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
86 |
+
Defaults to "lower_bound".
|
87 |
+
"""
|
88 |
+
self.__width = width
|
89 |
+
self.__height = height
|
90 |
+
|
91 |
+
self.__resize_target = resize_target
|
92 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
93 |
+
self.__multiple_of = ensure_multiple_of
|
94 |
+
self.__resize_method = resize_method
|
95 |
+
self.__image_interpolation_method = image_interpolation_method
|
96 |
+
|
97 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
98 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
99 |
+
|
100 |
+
if max_val is not None and y > max_val:
|
101 |
+
y = (np.floor(x / self.__multiple_of) *
|
102 |
+
self.__multiple_of).astype(int)
|
103 |
+
|
104 |
+
if y < min_val:
|
105 |
+
y = (np.ceil(x / self.__multiple_of) *
|
106 |
+
self.__multiple_of).astype(int)
|
107 |
+
|
108 |
+
return y
|
109 |
+
|
110 |
+
def get_size(self, width, height):
|
111 |
+
# determine new height and width
|
112 |
+
scale_height = self.__height / height
|
113 |
+
scale_width = self.__width / width
|
114 |
+
|
115 |
+
if self.__keep_aspect_ratio:
|
116 |
+
if self.__resize_method == 'lower_bound':
|
117 |
+
# scale such that output size is lower bound
|
118 |
+
if scale_width > scale_height:
|
119 |
+
# fit width
|
120 |
+
scale_height = scale_width
|
121 |
+
else:
|
122 |
+
# fit height
|
123 |
+
scale_width = scale_height
|
124 |
+
elif self.__resize_method == 'upper_bound':
|
125 |
+
# scale such that output size is upper bound
|
126 |
+
if scale_width < scale_height:
|
127 |
+
# fit width
|
128 |
+
scale_height = scale_width
|
129 |
+
else:
|
130 |
+
# fit height
|
131 |
+
scale_width = scale_height
|
132 |
+
elif self.__resize_method == 'minimal':
|
133 |
+
# scale as least as possbile
|
134 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
135 |
+
# fit width
|
136 |
+
scale_height = scale_width
|
137 |
+
else:
|
138 |
+
# fit height
|
139 |
+
scale_width = scale_height
|
140 |
+
else:
|
141 |
+
raise ValueError(
|
142 |
+
f'resize_method {self.__resize_method} not implemented')
|
143 |
+
|
144 |
+
if self.__resize_method == 'lower_bound':
|
145 |
+
new_height = self.constrain_to_multiple_of(scale_height * height,
|
146 |
+
min_val=self.__height)
|
147 |
+
new_width = self.constrain_to_multiple_of(scale_width * width,
|
148 |
+
min_val=self.__width)
|
149 |
+
elif self.__resize_method == 'upper_bound':
|
150 |
+
new_height = self.constrain_to_multiple_of(scale_height * height,
|
151 |
+
max_val=self.__height)
|
152 |
+
new_width = self.constrain_to_multiple_of(scale_width * width,
|
153 |
+
max_val=self.__width)
|
154 |
+
elif self.__resize_method == 'minimal':
|
155 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
156 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
157 |
+
else:
|
158 |
+
raise ValueError(
|
159 |
+
f'resize_method {self.__resize_method} not implemented')
|
160 |
+
|
161 |
+
return (new_width, new_height)
|
162 |
+
|
163 |
+
def __call__(self, sample):
|
164 |
+
width, height = self.get_size(sample['image'].shape[1],
|
165 |
+
sample['image'].shape[0])
|
166 |
+
|
167 |
+
# resize sample
|
168 |
+
sample['image'] = cv2.resize(
|
169 |
+
sample['image'],
|
170 |
+
(width, height),
|
171 |
+
interpolation=self.__image_interpolation_method,
|
172 |
+
)
|
173 |
+
|
174 |
+
if self.__resize_target:
|
175 |
+
if 'disparity' in sample:
|
176 |
+
sample['disparity'] = cv2.resize(
|
177 |
+
sample['disparity'],
|
178 |
+
(width, height),
|
179 |
+
interpolation=cv2.INTER_NEAREST,
|
180 |
+
)
|
181 |
+
|
182 |
+
if 'depth' in sample:
|
183 |
+
sample['depth'] = cv2.resize(sample['depth'], (width, height),
|
184 |
+
interpolation=cv2.INTER_NEAREST)
|
185 |
+
|
186 |
+
sample['mask'] = cv2.resize(
|
187 |
+
sample['mask'].astype(np.float32),
|
188 |
+
(width, height),
|
189 |
+
interpolation=cv2.INTER_NEAREST,
|
190 |
+
)
|
191 |
+
sample['mask'] = sample['mask'].astype(bool)
|
192 |
+
|
193 |
+
return sample
|
194 |
+
|
195 |
+
|
196 |
+
class NormalizeImage(object):
|
197 |
+
"""Normlize image by given mean and std.
|
198 |
+
"""
|
199 |
+
def __init__(self, mean, std):
|
200 |
+
self.__mean = mean
|
201 |
+
self.__std = std
|
202 |
+
|
203 |
+
def __call__(self, sample):
|
204 |
+
sample['image'] = (sample['image'] - self.__mean) / self.__std
|
205 |
+
|
206 |
+
return sample
|
207 |
+
|
208 |
+
|
209 |
+
class PrepareForNet(object):
|
210 |
+
"""Prepare sample for usage as network input.
|
211 |
+
"""
|
212 |
+
def __init__(self):
|
213 |
+
pass
|
214 |
+
|
215 |
+
def __call__(self, sample):
|
216 |
+
image = np.transpose(sample['image'], (2, 0, 1))
|
217 |
+
sample['image'] = np.ascontiguousarray(image).astype(np.float32)
|
218 |
+
|
219 |
+
if 'mask' in sample:
|
220 |
+
sample['mask'] = sample['mask'].astype(np.float32)
|
221 |
+
sample['mask'] = np.ascontiguousarray(sample['mask'])
|
222 |
+
|
223 |
+
if 'disparity' in sample:
|
224 |
+
disparity = sample['disparity'].astype(np.float32)
|
225 |
+
sample['disparity'] = np.ascontiguousarray(disparity)
|
226 |
+
|
227 |
+
if 'depth' in sample:
|
228 |
+
depth = sample['depth'].astype(np.float32)
|
229 |
+
sample['depth'] = np.ascontiguousarray(depth)
|
230 |
+
|
231 |
+
return sample
|
preprocessing/midas/utils.py
ADDED
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
"""Utils for monoDepth."""
|
4 |
+
import re
|
5 |
+
import sys
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def read_pfm(path):
|
13 |
+
"""Read pfm file.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
path (str): path to file
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
tuple: (data, scale)
|
20 |
+
"""
|
21 |
+
with open(path, 'rb') as file:
|
22 |
+
|
23 |
+
color = None
|
24 |
+
width = None
|
25 |
+
height = None
|
26 |
+
scale = None
|
27 |
+
endian = None
|
28 |
+
|
29 |
+
header = file.readline().rstrip()
|
30 |
+
if header.decode('ascii') == 'PF':
|
31 |
+
color = True
|
32 |
+
elif header.decode('ascii') == 'Pf':
|
33 |
+
color = False
|
34 |
+
else:
|
35 |
+
raise Exception('Not a PFM file: ' + path)
|
36 |
+
|
37 |
+
dim_match = re.match(r'^(\d+)\s(\d+)\s$',
|
38 |
+
file.readline().decode('ascii'))
|
39 |
+
if dim_match:
|
40 |
+
width, height = list(map(int, dim_match.groups()))
|
41 |
+
else:
|
42 |
+
raise Exception('Malformed PFM header.')
|
43 |
+
|
44 |
+
scale = float(file.readline().decode('ascii').rstrip())
|
45 |
+
if scale < 0:
|
46 |
+
# little-endian
|
47 |
+
endian = '<'
|
48 |
+
scale = -scale
|
49 |
+
else:
|
50 |
+
# big-endian
|
51 |
+
endian = '>'
|
52 |
+
|
53 |
+
data = np.fromfile(file, endian + 'f')
|
54 |
+
shape = (height, width, 3) if color else (height, width)
|
55 |
+
|
56 |
+
data = np.reshape(data, shape)
|
57 |
+
data = np.flipud(data)
|
58 |
+
|
59 |
+
return data, scale
|
60 |
+
|
61 |
+
|
62 |
+
def write_pfm(path, image, scale=1):
|
63 |
+
"""Write pfm file.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
path (str): pathto file
|
67 |
+
image (array): data
|
68 |
+
scale (int, optional): Scale. Defaults to 1.
|
69 |
+
"""
|
70 |
+
|
71 |
+
with open(path, 'wb') as file:
|
72 |
+
color = None
|
73 |
+
|
74 |
+
if image.dtype.name != 'float32':
|
75 |
+
raise Exception('Image dtype must be float32.')
|
76 |
+
|
77 |
+
image = np.flipud(image)
|
78 |
+
|
79 |
+
if len(image.shape) == 3 and image.shape[2] == 3: # color image
|
80 |
+
color = True
|
81 |
+
elif (len(image.shape) == 2
|
82 |
+
or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
|
83 |
+
color = False
|
84 |
+
else:
|
85 |
+
raise Exception(
|
86 |
+
'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
|
87 |
+
|
88 |
+
file.write('PF\n' if color else 'Pf\n'.encode())
|
89 |
+
file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
|
90 |
+
|
91 |
+
endian = image.dtype.byteorder
|
92 |
+
|
93 |
+
if endian == '<' or endian == '=' and sys.byteorder == 'little':
|
94 |
+
scale = -scale
|
95 |
+
|
96 |
+
file.write('%f\n'.encode() % scale)
|
97 |
+
|
98 |
+
image.tofile(file)
|
99 |
+
|
100 |
+
|
101 |
+
def read_image(path):
|
102 |
+
"""Read image and output RGB image (0-1).
|
103 |
+
|
104 |
+
Args:
|
105 |
+
path (str): path to file
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
array: RGB image (0-1)
|
109 |
+
"""
|
110 |
+
img = cv2.imread(path)
|
111 |
+
|
112 |
+
if img.ndim == 2:
|
113 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
114 |
+
|
115 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
|
116 |
+
|
117 |
+
return img
|
118 |
+
|
119 |
+
|
120 |
+
def resize_image(img):
|
121 |
+
"""Resize image and make it fit for network.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
img (array): image
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
tensor: data ready for network
|
128 |
+
"""
|
129 |
+
height_orig = img.shape[0]
|
130 |
+
width_orig = img.shape[1]
|
131 |
+
|
132 |
+
if width_orig > height_orig:
|
133 |
+
scale = width_orig / 384
|
134 |
+
else:
|
135 |
+
scale = height_orig / 384
|
136 |
+
|
137 |
+
height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
|
138 |
+
width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
|
139 |
+
|
140 |
+
img_resized = cv2.resize(img, (width, height),
|
141 |
+
interpolation=cv2.INTER_AREA)
|
142 |
+
|
143 |
+
img_resized = (torch.from_numpy(np.transpose(
|
144 |
+
img_resized, (2, 0, 1))).contiguous().float())
|
145 |
+
img_resized = img_resized.unsqueeze(0)
|
146 |
+
|
147 |
+
return img_resized
|
148 |
+
|
149 |
+
|
150 |
+
def resize_depth(depth, width, height):
|
151 |
+
"""Resize depth map and bring to CPU (numpy).
|
152 |
+
|
153 |
+
Args:
|
154 |
+
depth (tensor): depth
|
155 |
+
width (int): image width
|
156 |
+
height (int): image height
|
157 |
+
|
158 |
+
Returns:
|
159 |
+
array: processed depth
|
160 |
+
"""
|
161 |
+
depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
|
162 |
+
|
163 |
+
depth_resized = cv2.resize(depth.numpy(), (width, height),
|
164 |
+
interpolation=cv2.INTER_CUBIC)
|
165 |
+
|
166 |
+
return depth_resized
|
167 |
+
|
168 |
+
|
169 |
+
def write_depth(path, depth, bits=1):
|
170 |
+
"""Write depth map to pfm and png file.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
path (str): filepath without extension
|
174 |
+
depth (array): depth
|
175 |
+
"""
|
176 |
+
write_pfm(path + '.pfm', depth.astype(np.float32))
|
177 |
+
|
178 |
+
depth_min = depth.min()
|
179 |
+
depth_max = depth.max()
|
180 |
+
|
181 |
+
max_val = (2**(8 * bits)) - 1
|
182 |
+
|
183 |
+
if depth_max - depth_min > np.finfo('float').eps:
|
184 |
+
out = max_val * (depth - depth_min) / (depth_max - depth_min)
|
185 |
+
else:
|
186 |
+
out = np.zeros(depth.shape, dtype=depth.type)
|
187 |
+
|
188 |
+
if bits == 1:
|
189 |
+
cv2.imwrite(path + '.png', out.astype('uint8'))
|
190 |
+
elif bits == 2:
|
191 |
+
cv2.imwrite(path + '.png', out.astype('uint16'))
|
192 |
+
|
193 |
+
return
|
preprocessing/midas/vit.py
ADDED
@@ -0,0 +1,510 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
+
import math
|
4 |
+
import types
|
5 |
+
|
6 |
+
import timm
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class Slice(nn.Module):
|
13 |
+
def __init__(self, start_index=1):
|
14 |
+
super(Slice, self).__init__()
|
15 |
+
self.start_index = start_index
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
return x[:, self.start_index:]
|
19 |
+
|
20 |
+
|
21 |
+
class AddReadout(nn.Module):
|
22 |
+
def __init__(self, start_index=1):
|
23 |
+
super(AddReadout, self).__init__()
|
24 |
+
self.start_index = start_index
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
if self.start_index == 2:
|
28 |
+
readout = (x[:, 0] + x[:, 1]) / 2
|
29 |
+
else:
|
30 |
+
readout = x[:, 0]
|
31 |
+
return x[:, self.start_index:] + readout.unsqueeze(1)
|
32 |
+
|
33 |
+
|
34 |
+
class ProjectReadout(nn.Module):
|
35 |
+
def __init__(self, in_features, start_index=1):
|
36 |
+
super(ProjectReadout, self).__init__()
|
37 |
+
self.start_index = start_index
|
38 |
+
|
39 |
+
self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
|
40 |
+
nn.GELU())
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
|
44 |
+
features = torch.cat((x[:, self.start_index:], readout), -1)
|
45 |
+
|
46 |
+
return self.project(features)
|
47 |
+
|
48 |
+
|
49 |
+
class Transpose(nn.Module):
|
50 |
+
def __init__(self, dim0, dim1):
|
51 |
+
super(Transpose, self).__init__()
|
52 |
+
self.dim0 = dim0
|
53 |
+
self.dim1 = dim1
|
54 |
+
|
55 |
+
def forward(self, x):
|
56 |
+
x = x.transpose(self.dim0, self.dim1)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
def forward_vit(pretrained, x):
|
61 |
+
b, c, h, w = x.shape
|
62 |
+
|
63 |
+
_ = pretrained.model.forward_flex(x)
|
64 |
+
|
65 |
+
layer_1 = pretrained.activations['1']
|
66 |
+
layer_2 = pretrained.activations['2']
|
67 |
+
layer_3 = pretrained.activations['3']
|
68 |
+
layer_4 = pretrained.activations['4']
|
69 |
+
|
70 |
+
layer_1 = pretrained.act_postprocess1[0:2](layer_1)
|
71 |
+
layer_2 = pretrained.act_postprocess2[0:2](layer_2)
|
72 |
+
layer_3 = pretrained.act_postprocess3[0:2](layer_3)
|
73 |
+
layer_4 = pretrained.act_postprocess4[0:2](layer_4)
|
74 |
+
|
75 |
+
unflatten = nn.Sequential(
|
76 |
+
nn.Unflatten(
|
77 |
+
2,
|
78 |
+
torch.Size([
|
79 |
+
h // pretrained.model.patch_size[1],
|
80 |
+
w // pretrained.model.patch_size[0],
|
81 |
+
]),
|
82 |
+
))
|
83 |
+
|
84 |
+
if layer_1.ndim == 3:
|
85 |
+
layer_1 = unflatten(layer_1)
|
86 |
+
if layer_2.ndim == 3:
|
87 |
+
layer_2 = unflatten(layer_2)
|
88 |
+
if layer_3.ndim == 3:
|
89 |
+
layer_3 = unflatten(layer_3)
|
90 |
+
if layer_4.ndim == 3:
|
91 |
+
layer_4 = unflatten(layer_4)
|
92 |
+
|
93 |
+
layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
|
94 |
+
layer_1)
|
95 |
+
layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
|
96 |
+
layer_2)
|
97 |
+
layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
|
98 |
+
layer_3)
|
99 |
+
layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
|
100 |
+
layer_4)
|
101 |
+
|
102 |
+
return layer_1, layer_2, layer_3, layer_4
|
103 |
+
|
104 |
+
|
105 |
+
def _resize_pos_embed(self, posemb, gs_h, gs_w):
|
106 |
+
posemb_tok, posemb_grid = (
|
107 |
+
posemb[:, :self.start_index],
|
108 |
+
posemb[0, self.start_index:],
|
109 |
+
)
|
110 |
+
|
111 |
+
gs_old = int(math.sqrt(len(posemb_grid)))
|
112 |
+
|
113 |
+
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
|
114 |
+
-1).permute(0, 3, 1, 2)
|
115 |
+
posemb_grid = F.interpolate(posemb_grid,
|
116 |
+
size=(gs_h, gs_w),
|
117 |
+
mode='bilinear')
|
118 |
+
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
|
119 |
+
|
120 |
+
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
121 |
+
|
122 |
+
return posemb
|
123 |
+
|
124 |
+
|
125 |
+
def forward_flex(self, x):
|
126 |
+
b, c, h, w = x.shape
|
127 |
+
|
128 |
+
pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
|
129 |
+
w // self.patch_size[0])
|
130 |
+
|
131 |
+
B = x.shape[0]
|
132 |
+
|
133 |
+
if hasattr(self.patch_embed, 'backbone'):
|
134 |
+
x = self.patch_embed.backbone(x)
|
135 |
+
if isinstance(x, (list, tuple)):
|
136 |
+
x = x[
|
137 |
+
-1] # last feature if backbone outputs list/tuple of features
|
138 |
+
|
139 |
+
x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
|
140 |
+
|
141 |
+
if getattr(self, 'dist_token', None) is not None:
|
142 |
+
cls_tokens = self.cls_token.expand(
|
143 |
+
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
144 |
+
dist_token = self.dist_token.expand(B, -1, -1)
|
145 |
+
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
146 |
+
else:
|
147 |
+
cls_tokens = self.cls_token.expand(
|
148 |
+
B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
149 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
150 |
+
|
151 |
+
x = x + pos_embed
|
152 |
+
x = self.pos_drop(x)
|
153 |
+
|
154 |
+
for blk in self.blocks:
|
155 |
+
x = blk(x)
|
156 |
+
|
157 |
+
x = self.norm(x)
|
158 |
+
|
159 |
+
return x
|
160 |
+
|
161 |
+
|
162 |
+
activations = {}
|
163 |
+
|
164 |
+
|
165 |
+
def get_activation(name):
|
166 |
+
def hook(model, input, output):
|
167 |
+
activations[name] = output
|
168 |
+
|
169 |
+
return hook
|
170 |
+
|
171 |
+
|
172 |
+
def get_readout_oper(vit_features, features, use_readout, start_index=1):
|
173 |
+
if use_readout == 'ignore':
|
174 |
+
readout_oper = [Slice(start_index)] * len(features)
|
175 |
+
elif use_readout == 'add':
|
176 |
+
readout_oper = [AddReadout(start_index)] * len(features)
|
177 |
+
elif use_readout == 'project':
|
178 |
+
readout_oper = [
|
179 |
+
ProjectReadout(vit_features, start_index) for out_feat in features
|
180 |
+
]
|
181 |
+
else:
|
182 |
+
assert (
|
183 |
+
False
|
184 |
+
), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
|
185 |
+
|
186 |
+
return readout_oper
|
187 |
+
|
188 |
+
|
189 |
+
def _make_vit_b16_backbone(
|
190 |
+
model,
|
191 |
+
features=[96, 192, 384, 768],
|
192 |
+
size=[384, 384],
|
193 |
+
hooks=[2, 5, 8, 11],
|
194 |
+
vit_features=768,
|
195 |
+
use_readout='ignore',
|
196 |
+
start_index=1,
|
197 |
+
):
|
198 |
+
pretrained = nn.Module()
|
199 |
+
|
200 |
+
pretrained.model = model
|
201 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(
|
202 |
+
get_activation('1'))
|
203 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(
|
204 |
+
get_activation('2'))
|
205 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(
|
206 |
+
get_activation('3'))
|
207 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(
|
208 |
+
get_activation('4'))
|
209 |
+
|
210 |
+
pretrained.activations = activations
|
211 |
+
|
212 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout,
|
213 |
+
start_index)
|
214 |
+
|
215 |
+
# 32, 48, 136, 384
|
216 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
217 |
+
readout_oper[0],
|
218 |
+
Transpose(1, 2),
|
219 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
220 |
+
nn.Conv2d(
|
221 |
+
in_channels=vit_features,
|
222 |
+
out_channels=features[0],
|
223 |
+
kernel_size=1,
|
224 |
+
stride=1,
|
225 |
+
padding=0,
|
226 |
+
),
|
227 |
+
nn.ConvTranspose2d(
|
228 |
+
in_channels=features[0],
|
229 |
+
out_channels=features[0],
|
230 |
+
kernel_size=4,
|
231 |
+
stride=4,
|
232 |
+
padding=0,
|
233 |
+
bias=True,
|
234 |
+
dilation=1,
|
235 |
+
groups=1,
|
236 |
+
),
|
237 |
+
)
|
238 |
+
|
239 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
240 |
+
readout_oper[1],
|
241 |
+
Transpose(1, 2),
|
242 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
243 |
+
nn.Conv2d(
|
244 |
+
in_channels=vit_features,
|
245 |
+
out_channels=features[1],
|
246 |
+
kernel_size=1,
|
247 |
+
stride=1,
|
248 |
+
padding=0,
|
249 |
+
),
|
250 |
+
nn.ConvTranspose2d(
|
251 |
+
in_channels=features[1],
|
252 |
+
out_channels=features[1],
|
253 |
+
kernel_size=2,
|
254 |
+
stride=2,
|
255 |
+
padding=0,
|
256 |
+
bias=True,
|
257 |
+
dilation=1,
|
258 |
+
groups=1,
|
259 |
+
),
|
260 |
+
)
|
261 |
+
|
262 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
263 |
+
readout_oper[2],
|
264 |
+
Transpose(1, 2),
|
265 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
266 |
+
nn.Conv2d(
|
267 |
+
in_channels=vit_features,
|
268 |
+
out_channels=features[2],
|
269 |
+
kernel_size=1,
|
270 |
+
stride=1,
|
271 |
+
padding=0,
|
272 |
+
),
|
273 |
+
)
|
274 |
+
|
275 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
276 |
+
readout_oper[3],
|
277 |
+
Transpose(1, 2),
|
278 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
279 |
+
nn.Conv2d(
|
280 |
+
in_channels=vit_features,
|
281 |
+
out_channels=features[3],
|
282 |
+
kernel_size=1,
|
283 |
+
stride=1,
|
284 |
+
padding=0,
|
285 |
+
),
|
286 |
+
nn.Conv2d(
|
287 |
+
in_channels=features[3],
|
288 |
+
out_channels=features[3],
|
289 |
+
kernel_size=3,
|
290 |
+
stride=2,
|
291 |
+
padding=1,
|
292 |
+
),
|
293 |
+
)
|
294 |
+
|
295 |
+
pretrained.model.start_index = start_index
|
296 |
+
pretrained.model.patch_size = [16, 16]
|
297 |
+
|
298 |
+
# We inject this function into the VisionTransformer instances so that
|
299 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
300 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
301 |
+
pretrained.model)
|
302 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
303 |
+
_resize_pos_embed, pretrained.model)
|
304 |
+
|
305 |
+
return pretrained
|
306 |
+
|
307 |
+
|
308 |
+
def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
|
309 |
+
model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)
|
310 |
+
|
311 |
+
hooks = [5, 11, 17, 23] if hooks is None else hooks
|
312 |
+
return _make_vit_b16_backbone(
|
313 |
+
model,
|
314 |
+
features=[256, 512, 1024, 1024],
|
315 |
+
hooks=hooks,
|
316 |
+
vit_features=1024,
|
317 |
+
use_readout=use_readout,
|
318 |
+
)
|
319 |
+
|
320 |
+
|
321 |
+
def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
|
322 |
+
model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)
|
323 |
+
|
324 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
325 |
+
return _make_vit_b16_backbone(model,
|
326 |
+
features=[96, 192, 384, 768],
|
327 |
+
hooks=hooks,
|
328 |
+
use_readout=use_readout)
|
329 |
+
|
330 |
+
|
331 |
+
def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
|
332 |
+
model = timm.create_model('vit_deit_base_patch16_384',
|
333 |
+
pretrained=pretrained)
|
334 |
+
|
335 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
336 |
+
return _make_vit_b16_backbone(model,
|
337 |
+
features=[96, 192, 384, 768],
|
338 |
+
hooks=hooks,
|
339 |
+
use_readout=use_readout)
|
340 |
+
|
341 |
+
|
342 |
+
def _make_pretrained_deitb16_distil_384(pretrained,
|
343 |
+
use_readout='ignore',
|
344 |
+
hooks=None):
|
345 |
+
model = timm.create_model('vit_deit_base_distilled_patch16_384',
|
346 |
+
pretrained=pretrained)
|
347 |
+
|
348 |
+
hooks = [2, 5, 8, 11] if hooks is None else hooks
|
349 |
+
return _make_vit_b16_backbone(
|
350 |
+
model,
|
351 |
+
features=[96, 192, 384, 768],
|
352 |
+
hooks=hooks,
|
353 |
+
use_readout=use_readout,
|
354 |
+
start_index=2,
|
355 |
+
)
|
356 |
+
|
357 |
+
|
358 |
+
def _make_vit_b_rn50_backbone(
|
359 |
+
model,
|
360 |
+
features=[256, 512, 768, 768],
|
361 |
+
size=[384, 384],
|
362 |
+
hooks=[0, 1, 8, 11],
|
363 |
+
vit_features=768,
|
364 |
+
use_vit_only=False,
|
365 |
+
use_readout='ignore',
|
366 |
+
start_index=1,
|
367 |
+
):
|
368 |
+
pretrained = nn.Module()
|
369 |
+
|
370 |
+
pretrained.model = model
|
371 |
+
|
372 |
+
if use_vit_only is True:
|
373 |
+
pretrained.model.blocks[hooks[0]].register_forward_hook(
|
374 |
+
get_activation('1'))
|
375 |
+
pretrained.model.blocks[hooks[1]].register_forward_hook(
|
376 |
+
get_activation('2'))
|
377 |
+
else:
|
378 |
+
pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
|
379 |
+
get_activation('1'))
|
380 |
+
pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
|
381 |
+
get_activation('2'))
|
382 |
+
|
383 |
+
pretrained.model.blocks[hooks[2]].register_forward_hook(
|
384 |
+
get_activation('3'))
|
385 |
+
pretrained.model.blocks[hooks[3]].register_forward_hook(
|
386 |
+
get_activation('4'))
|
387 |
+
|
388 |
+
pretrained.activations = activations
|
389 |
+
|
390 |
+
readout_oper = get_readout_oper(vit_features, features, use_readout,
|
391 |
+
start_index)
|
392 |
+
|
393 |
+
if use_vit_only is True:
|
394 |
+
pretrained.act_postprocess1 = nn.Sequential(
|
395 |
+
readout_oper[0],
|
396 |
+
Transpose(1, 2),
|
397 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
398 |
+
nn.Conv2d(
|
399 |
+
in_channels=vit_features,
|
400 |
+
out_channels=features[0],
|
401 |
+
kernel_size=1,
|
402 |
+
stride=1,
|
403 |
+
padding=0,
|
404 |
+
),
|
405 |
+
nn.ConvTranspose2d(
|
406 |
+
in_channels=features[0],
|
407 |
+
out_channels=features[0],
|
408 |
+
kernel_size=4,
|
409 |
+
stride=4,
|
410 |
+
padding=0,
|
411 |
+
bias=True,
|
412 |
+
dilation=1,
|
413 |
+
groups=1,
|
414 |
+
),
|
415 |
+
)
|
416 |
+
|
417 |
+
pretrained.act_postprocess2 = nn.Sequential(
|
418 |
+
readout_oper[1],
|
419 |
+
Transpose(1, 2),
|
420 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
421 |
+
nn.Conv2d(
|
422 |
+
in_channels=vit_features,
|
423 |
+
out_channels=features[1],
|
424 |
+
kernel_size=1,
|
425 |
+
stride=1,
|
426 |
+
padding=0,
|
427 |
+
),
|
428 |
+
nn.ConvTranspose2d(
|
429 |
+
in_channels=features[1],
|
430 |
+
out_channels=features[1],
|
431 |
+
kernel_size=2,
|
432 |
+
stride=2,
|
433 |
+
padding=0,
|
434 |
+
bias=True,
|
435 |
+
dilation=1,
|
436 |
+
groups=1,
|
437 |
+
),
|
438 |
+
)
|
439 |
+
else:
|
440 |
+
pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
|
441 |
+
nn.Identity(),
|
442 |
+
nn.Identity())
|
443 |
+
pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
|
444 |
+
nn.Identity(),
|
445 |
+
nn.Identity())
|
446 |
+
|
447 |
+
pretrained.act_postprocess3 = nn.Sequential(
|
448 |
+
readout_oper[2],
|
449 |
+
Transpose(1, 2),
|
450 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
451 |
+
nn.Conv2d(
|
452 |
+
in_channels=vit_features,
|
453 |
+
out_channels=features[2],
|
454 |
+
kernel_size=1,
|
455 |
+
stride=1,
|
456 |
+
padding=0,
|
457 |
+
),
|
458 |
+
)
|
459 |
+
|
460 |
+
pretrained.act_postprocess4 = nn.Sequential(
|
461 |
+
readout_oper[3],
|
462 |
+
Transpose(1, 2),
|
463 |
+
nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
|
464 |
+
nn.Conv2d(
|
465 |
+
in_channels=vit_features,
|
466 |
+
out_channels=features[3],
|
467 |
+
kernel_size=1,
|
468 |
+
stride=1,
|
469 |
+
padding=0,
|
470 |
+
),
|
471 |
+
nn.Conv2d(
|
472 |
+
in_channels=features[3],
|
473 |
+
out_channels=features[3],
|
474 |
+
kernel_size=3,
|
475 |
+
stride=2,
|
476 |
+
padding=1,
|
477 |
+
),
|
478 |
+
)
|
479 |
+
|
480 |
+
pretrained.model.start_index = start_index
|
481 |
+
pretrained.model.patch_size = [16, 16]
|
482 |
+
|
483 |
+
# We inject this function into the VisionTransformer instances so that
|
484 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
485 |
+
pretrained.model.forward_flex = types.MethodType(forward_flex,
|
486 |
+
pretrained.model)
|
487 |
+
|
488 |
+
# We inject this function into the VisionTransformer instances so that
|
489 |
+
# we can use it with interpolated position embeddings without modifying the library source.
|
490 |
+
pretrained.model._resize_pos_embed = types.MethodType(
|
491 |
+
_resize_pos_embed, pretrained.model)
|
492 |
+
|
493 |
+
return pretrained
|
494 |
+
|
495 |
+
|
496 |
+
def _make_pretrained_vitb_rn50_384(pretrained,
|
497 |
+
use_readout='ignore',
|
498 |
+
hooks=None,
|
499 |
+
use_vit_only=False):
|
500 |
+
model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)
|
501 |
+
|
502 |
+
hooks = [0, 1, 8, 11] if hooks is None else hooks
|
503 |
+
return _make_vit_b_rn50_backbone(
|
504 |
+
model,
|
505 |
+
features=[256, 512, 768, 768],
|
506 |
+
size=[384, 384],
|
507 |
+
hooks=hooks,
|
508 |
+
use_vit_only=use_vit_only,
|
509 |
+
use_readout=use_readout,
|
510 |
+
)
|
wan/text2video.py
CHANGED
@@ -207,18 +207,19 @@ class WanT2V:
|
|
207 |
def vace_latent(self, z, m):
|
208 |
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
209 |
|
210 |
-
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
|
211 |
image_sizes = []
|
212 |
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
213 |
if sub_src_mask is not None and sub_src_video is not None:
|
214 |
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
|
|
|
|
|
215 |
src_video[i] = src_video[i].to(device)
|
216 |
src_mask[i] = src_mask[i].to(device)
|
217 |
src_video_shape = src_video[i].shape
|
218 |
if src_video_shape[1] != num_frames:
|
219 |
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
220 |
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
221 |
-
|
222 |
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
223 |
image_sizes.append(src_video[i].shape[2:])
|
224 |
elif sub_src_video is None:
|
@@ -228,10 +229,11 @@ class WanT2V:
|
|
228 |
else:
|
229 |
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
|
230 |
src_video[i] = src_video[i].to(device)
|
|
|
231 |
src_video_shape = src_video[i].shape
|
232 |
if src_video_shape[1] != num_frames:
|
233 |
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
234 |
-
|
235 |
image_sizes.append(src_video[i].shape[2:])
|
236 |
|
237 |
for i, ref_images in enumerate(src_ref_images):
|
|
|
207 |
def vace_latent(self, z, m):
|
208 |
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
|
209 |
|
210 |
+
def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
|
211 |
image_sizes = []
|
212 |
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
|
213 |
if sub_src_mask is not None and sub_src_video is not None:
|
214 |
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
|
215 |
+
# src_video is [-1, 1], 0 = inpainting area (in fact 127 in [0, 255])
|
216 |
+
# src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
|
217 |
src_video[i] = src_video[i].to(device)
|
218 |
src_mask[i] = src_mask[i].to(device)
|
219 |
src_video_shape = src_video[i].shape
|
220 |
if src_video_shape[1] != num_frames:
|
221 |
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
222 |
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
|
|
223 |
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
|
224 |
image_sizes.append(src_video[i].shape[2:])
|
225 |
elif sub_src_video is None:
|
|
|
229 |
else:
|
230 |
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
|
231 |
src_video[i] = src_video[i].to(device)
|
232 |
+
src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
|
233 |
src_video_shape = src_video[i].shape
|
234 |
if src_video_shape[1] != num_frames:
|
235 |
src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
236 |
+
src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
|
237 |
image_sizes.append(src_video[i].shape[2:])
|
238 |
|
239 |
for i, ref_images in enumerate(src_ref_images):
|
wan/utils/utils.py
CHANGED
@@ -21,6 +21,30 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']
|
|
21 |
|
22 |
from PIL import Image
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def get_video_frame(file_name, frame_no):
|
25 |
decord.bridge.set_bridge('torch')
|
26 |
reader = decord.VideoReader(file_name)
|
|
|
21 |
|
22 |
from PIL import Image
|
23 |
|
24 |
+
|
25 |
+
def resample(video_fps, video_frames_count, max_frames, target_fps):
|
26 |
+
import math
|
27 |
+
|
28 |
+
video_frame_duration = 1 /video_fps
|
29 |
+
target_frame_duration = 1 / target_fps
|
30 |
+
|
31 |
+
cur_time = 0
|
32 |
+
target_time = 0
|
33 |
+
frame_no = 0
|
34 |
+
frame_ids =[]
|
35 |
+
while True:
|
36 |
+
if max_frames != 0 and len(frame_ids) >= max_frames:
|
37 |
+
break
|
38 |
+
add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
|
39 |
+
frame_no += add_frames_count
|
40 |
+
frame_ids.append(frame_no)
|
41 |
+
cur_time += add_frames_count * video_frame_duration
|
42 |
+
target_time += target_frame_duration
|
43 |
+
if frame_no >= video_frames_count -1:
|
44 |
+
break
|
45 |
+
frame_ids = frame_ids[:video_frames_count]
|
46 |
+
return frame_ids
|
47 |
+
|
48 |
def get_video_frame(file_name, frame_no):
|
49 |
decord.bridge.set_bridge('torch')
|
50 |
reader = decord.VideoReader(file_name)
|
wan/utils/vace_preprocessor.py
CHANGED
@@ -180,26 +180,17 @@ class VaceVideoProcessor(object):
|
|
180 |
), axis=1).tolist()
|
181 |
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
182 |
|
183 |
-
|
184 |
-
import math
|
185 |
-
target_fps = self.max_fps
|
186 |
-
video_frames_count = len(frame_timestamps)
|
187 |
-
video_frame_duration = 1 /fps
|
188 |
-
target_frame_duration = 1 / target_fps
|
189 |
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
target_time += target_frame_duration
|
200 |
-
if frame_no >= video_frames_count -1:
|
201 |
-
break
|
202 |
-
frame_ids = frame_ids[:video_frames_count]
|
203 |
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
204 |
h, w = y2 - y1, x2 - x1
|
205 |
ratio = h / w
|
@@ -235,11 +226,11 @@ class VaceVideoProcessor(object):
|
|
235 |
|
236 |
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
237 |
|
238 |
-
def _get_frameid_bbox(self, fps,
|
239 |
if self.keep_last:
|
240 |
-
return self._get_frameid_bbox_adjust_last(fps,
|
241 |
else:
|
242 |
-
return self._get_frameid_bbox_default(fps,
|
243 |
|
244 |
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
245 |
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
@@ -253,23 +244,37 @@ class VaceVideoProcessor(object):
|
|
253 |
import decord
|
254 |
decord.bridge.set_bridge('torch')
|
255 |
readers = []
|
|
|
256 |
for data_k in data_key_batch:
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
|
268 |
-
|
269 |
-
|
|
|
|
|
|
|
|
|
270 |
|
271 |
# preprocess video
|
272 |
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
|
|
|
|
273 |
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
274 |
return *videos, frame_ids, (oh, ow), fps
|
275 |
# return videos if len(videos) > 1 else videos[0]
|
|
|
180 |
), axis=1).tolist()
|
181 |
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
182 |
|
183 |
+
|
|
|
|
|
|
|
|
|
|
|
184 |
|
185 |
+
def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
|
186 |
+
from wan.utils.utils import resample
|
187 |
+
|
188 |
+
target_fps = self.max_fps
|
189 |
+
|
190 |
+
# video_frames_count = len(frame_timestamps)
|
191 |
+
|
192 |
+
frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
|
193 |
+
|
|
|
|
|
|
|
|
|
194 |
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
|
195 |
h, w = y2 - y1, x2 - x1
|
196 |
ratio = h / w
|
|
|
226 |
|
227 |
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
|
228 |
|
229 |
+
def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
|
230 |
if self.keep_last:
|
231 |
+
return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
|
232 |
else:
|
233 |
+
return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
|
234 |
|
235 |
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
|
236 |
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
|
|
|
244 |
import decord
|
245 |
decord.bridge.set_bridge('torch')
|
246 |
readers = []
|
247 |
+
src_video = None
|
248 |
for data_k in data_key_batch:
|
249 |
+
if torch.is_tensor(data_k):
|
250 |
+
src_video = data_k
|
251 |
+
else:
|
252 |
+
reader = decord.VideoReader(data_k)
|
253 |
+
readers.append(reader)
|
254 |
+
|
255 |
+
if src_video != None:
|
256 |
+
fps = 16
|
257 |
+
length = src_video.shape[1]
|
258 |
+
if len(readers) > 0:
|
259 |
+
min_readers = min([len(r) for r in readers])
|
260 |
+
length = min(length, min_readers )
|
261 |
+
else:
|
262 |
+
fps = readers[0].get_avg_fps()
|
263 |
+
length = min([len(r) for r in readers])
|
264 |
+
# frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
|
265 |
+
# frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
|
266 |
max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
|
267 |
+
if src_video != None:
|
268 |
+
src_video = src_video[:max_frames]
|
269 |
+
h, w = src_video.shape[1:3]
|
270 |
+
else:
|
271 |
+
h, w = readers[0].next().shape[:2]
|
272 |
+
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
|
273 |
|
274 |
# preprocess video
|
275 |
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
|
276 |
+
if src_video != None:
|
277 |
+
videos = [src_video] + videos
|
278 |
videos = [self._video_preprocess(video, oh, ow) for video in videos]
|
279 |
return *videos, frame_ids, (oh, ow), fps
|
280 |
# return videos if len(videos) > 1 else videos[0]
|
wgp.py
CHANGED
@@ -141,12 +141,27 @@ def process_prompt_and_add_tasks(state, model_choice):
|
|
141 |
res = VACE_SIZE_CONFIGS.keys().join(" and ")
|
142 |
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
|
143 |
return
|
144 |
-
if
|
|
|
|
|
|
|
|
|
145 |
image_refs = None
|
146 |
-
if
|
|
|
|
|
|
|
|
|
147 |
video_guide = None
|
148 |
-
if
|
|
|
|
|
|
|
|
|
149 |
video_mask = None
|
|
|
|
|
|
|
150 |
|
151 |
if isinstance(image_refs, list):
|
152 |
image_refs = [ convert_image(tup[0]) for tup in image_refs ]
|
@@ -260,7 +275,7 @@ def add_video_task(**inputs):
|
|
260 |
queue = gen["queue"]
|
261 |
task_id += 1
|
262 |
current_task_id = task_id
|
263 |
-
inputs_to_query = ["image_start", "image_end", "
|
264 |
start_image_data = None
|
265 |
end_image_data = None
|
266 |
for name in inputs_to_query:
|
@@ -718,7 +733,7 @@ if not Path(server_config_filename).is_file():
|
|
718 |
"transformer_types": [],
|
719 |
"transformer_quantization": "int8",
|
720 |
"text_encoder_filename" : text_encoder_choices[1],
|
721 |
-
"save_path": os.path.join(os.getcwd(),
|
722 |
"compile" : "",
|
723 |
"metadata_type": "metadata",
|
724 |
"default_ui": "t2v",
|
@@ -726,7 +741,7 @@ if not Path(server_config_filename).is_file():
|
|
726 |
"clear_file_list" : 0,
|
727 |
"vae_config": 0,
|
728 |
"profile" : profile_type.LowRAM_LowVRAM,
|
729 |
-
"
|
730 |
|
731 |
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
732 |
writer.write(json.dumps(server_config))
|
@@ -860,7 +875,7 @@ if len(args.vae_config) > 0:
|
|
860 |
reload_needed = False
|
861 |
default_ui = server_config.get("default_ui", "t2v")
|
862 |
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
|
863 |
-
|
864 |
|
865 |
|
866 |
if args.t2v_14B or args.t2v:
|
@@ -962,8 +977,8 @@ def download_models(transformer_filename, text_encoder_filename):
|
|
962 |
|
963 |
from huggingface_hub import hf_hub_download, snapshot_download
|
964 |
repoId = "DeepBeepMeep/Wan2.1"
|
965 |
-
sourceFolderList = ["xlm-roberta-large", "", ]
|
966 |
-
fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
|
967 |
targetRoot = "ckpts/"
|
968 |
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
969 |
if len(files)==0:
|
@@ -1166,7 +1181,7 @@ def load_models(model_filename):
|
|
1166 |
|
1167 |
return wan_model, offloadobj, pipe["transformer"]
|
1168 |
|
1169 |
-
if
|
1170 |
wan_model, offloadobj, transformer = None, None, None
|
1171 |
reload_needed = True
|
1172 |
else:
|
@@ -1254,7 +1269,7 @@ def apply_changes( state,
|
|
1254 |
quantization_choice,
|
1255 |
boost_choice = 1,
|
1256 |
clear_file_list = 0,
|
1257 |
-
|
1258 |
):
|
1259 |
if args.lock_config:
|
1260 |
return
|
@@ -1272,7 +1287,7 @@ def apply_changes( state,
|
|
1272 |
"transformer_quantization" : quantization_choice,
|
1273 |
"boost" : boost_choice,
|
1274 |
"clear_file_list" : clear_file_list,
|
1275 |
-
"
|
1276 |
}
|
1277 |
|
1278 |
if Path(server_config_filename).is_file():
|
@@ -1295,14 +1310,14 @@ def apply_changes( state,
|
|
1295 |
if v != v_old:
|
1296 |
changes.append(k)
|
1297 |
|
1298 |
-
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed,
|
1299 |
attention_mode = server_config["attention_mode"]
|
1300 |
profile = server_config["profile"]
|
1301 |
compile = server_config["compile"]
|
1302 |
text_encoder_filename = server_config["text_encoder_filename"]
|
1303 |
vae_config = server_config["vae_config"]
|
1304 |
boost = server_config["boost"]
|
1305 |
-
|
1306 |
transformer_quantization = server_config["transformer_quantization"]
|
1307 |
transformer_types = server_config["transformer_types"]
|
1308 |
transformer_type = get_model_type(transformer_filename)
|
@@ -1381,7 +1396,8 @@ def abort_generation(state):
|
|
1381 |
|
1382 |
gen["abort"] = True
|
1383 |
gen["extra_orders"] = 0
|
1384 |
-
wan_model
|
|
|
1385 |
msg = "Processing Request to abort Current Generation"
|
1386 |
gr.Info(msg)
|
1387 |
return msg, gr.Button(interactive= False)
|
@@ -1480,24 +1496,68 @@ def expand_slist(slist, num_inference_steps ):
|
|
1480 |
return new_slist
|
1481 |
def convert_image(image):
|
1482 |
|
1483 |
-
from PIL import
|
1484 |
from typing import cast
|
1485 |
|
1486 |
return cast(Image, ImageOps.exif_transpose(image))
|
1487 |
-
|
1488 |
-
|
1489 |
-
|
1490 |
-
|
1491 |
-
|
1492 |
-
|
1493 |
-
|
1494 |
-
|
1495 |
-
|
1496 |
-
|
1497 |
-
|
1498 |
-
|
1499 |
-
|
1500 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1501 |
|
1502 |
def generate_video(
|
1503 |
task_id,
|
@@ -1551,7 +1611,7 @@ def generate_video(
|
|
1551 |
# gr.Info("Unable to generate a Video while a new configuration is being applied.")
|
1552 |
# return
|
1553 |
|
1554 |
-
if
|
1555 |
while wan_model == None:
|
1556 |
time.sleep(1)
|
1557 |
|
@@ -1681,10 +1741,32 @@ def generate_video(
|
|
1681 |
raise gr.Error("Teacache not supported for this model")
|
1682 |
|
1683 |
if "Vace" in model_filename:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1684 |
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
|
1685 |
[video_mask],
|
1686 |
[image_refs],
|
1687 |
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
|
|
|
1688 |
trim_video=max_frames)
|
1689 |
else:
|
1690 |
src_video, src_mask, src_ref_images = None, None, None
|
@@ -2539,9 +2621,9 @@ def fill_inputs(state):
|
|
2539 |
|
2540 |
return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
|
2541 |
|
2542 |
-
def
|
2543 |
global reload_needed, wan_model, offloadobj
|
2544 |
-
if
|
2545 |
model_filename = state["model_filename"]
|
2546 |
if state["model_filename"] != transformer_filename:
|
2547 |
wan_model = None
|
@@ -2558,7 +2640,7 @@ def preload_model(state):
|
|
2558 |
|
2559 |
def unload_model_if_needed(state):
|
2560 |
global reload_needed, wan_model, offloadobj
|
2561 |
-
if
|
2562 |
if wan_model != None:
|
2563 |
wan_model = None
|
2564 |
if offloadobj is not None:
|
@@ -2567,7 +2649,39 @@ def unload_model_if_needed(state):
|
|
2567 |
gc.collect()
|
2568 |
reload_needed= True
|
2569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2570 |
|
|
|
2571 |
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
|
2572 |
global inputs_names #, advanced
|
2573 |
|
@@ -2676,19 +2790,36 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
2676 |
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
|
2677 |
|
2678 |
with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
|
2679 |
-
|
2680 |
-
|
2681 |
-
|
2682 |
-
|
2683 |
-
|
2684 |
-
|
2685 |
-
|
2686 |
-
|
2687 |
-
|
2688 |
-
|
2689 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2690 |
|
2691 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2692 |
|
2693 |
|
2694 |
advanced_prompt = advanced_ui
|
@@ -2923,7 +3054,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
2923 |
target_settings = gr.Text(value = "settings", interactive= False, visible= False)
|
2924 |
|
2925 |
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
|
2926 |
-
video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
|
|
|
|
|
|
|
2927 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
2928 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
2929 |
queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
|
@@ -2967,7 +3101,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
|
|
2967 |
).then(fn= fill_inputs,
|
2968 |
inputs=[state],
|
2969 |
outputs=gen_inputs + extra_inputs
|
2970 |
-
).then(fn=
|
2971 |
inputs=[state],
|
2972 |
outputs=[gen_status])
|
2973 |
|
@@ -3149,14 +3283,8 @@ def generate_configuration_tab(header, model_choice):
|
|
3149 |
value=server_config.get("metadata_type", "metadata"),
|
3150 |
label="Metadata Handling"
|
3151 |
)
|
3152 |
-
|
3153 |
-
|
3154 |
-
("Load Model When Starting the App and Changing Model if Model Changed", 1),
|
3155 |
-
("Load Model When Starting the App and Pressing Generate if Model Changed", 2),
|
3156 |
-
("Load Model When Pressing Generate if Model Changed", 3),
|
3157 |
-
("Load Model When Pressing Generate and Unload Model when Finished", 4),
|
3158 |
-
],
|
3159 |
-
value=server_config.get("reload_model",2),
|
3160 |
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
|
3161 |
)
|
3162 |
|
@@ -3191,7 +3319,7 @@ def generate_configuration_tab(header, model_choice):
|
|
3191 |
quantization_choice,
|
3192 |
boost_choice,
|
3193 |
clear_file_list_choice,
|
3194 |
-
|
3195 |
],
|
3196 |
outputs= [msg , header, model_choice]
|
3197 |
)
|
@@ -3201,10 +3329,16 @@ def generate_about_tab():
|
|
3201 |
gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
|
3202 |
gr.Markdown("Many thanks to:")
|
3203 |
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
|
|
|
3204 |
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
|
3205 |
gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
|
3206 |
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
|
3207 |
-
gr.Markdown("- <B>Remade_AI</B> : for
|
|
|
|
|
|
|
|
|
|
|
3208 |
|
3209 |
def generate_info_tab():
|
3210 |
gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
|
@@ -3231,17 +3365,26 @@ def generate_dropdown_model_list():
|
|
3231 |
choices= dropdown_choices,
|
3232 |
value= current_model_type,
|
3233 |
show_label= False,
|
3234 |
-
scale= 2
|
|
|
|
|
3235 |
)
|
3236 |
|
3237 |
|
3238 |
|
3239 |
def create_demo():
|
3240 |
css = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3241 |
.title-with-lines {
|
3242 |
display: flex;
|
3243 |
align-items: center;
|
3244 |
-
margin:
|
3245 |
}
|
3246 |
.line {
|
3247 |
flex-grow: 1;
|
@@ -3462,7 +3605,7 @@ def create_demo():
|
|
3462 |
pointer-events: none;
|
3463 |
}
|
3464 |
"""
|
3465 |
-
with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
|
3466 |
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
3467 |
global model_list
|
3468 |
|
|
|
141 |
res = VACE_SIZE_CONFIGS.keys().join(" and ")
|
142 |
gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
|
143 |
return
|
144 |
+
if "I" in video_prompt_type:
|
145 |
+
if image_refs == None:
|
146 |
+
gr.Info("You must provide at one Refererence Image")
|
147 |
+
return
|
148 |
+
else:
|
149 |
image_refs = None
|
150 |
+
if "V" in video_prompt_type:
|
151 |
+
if video_guide == None:
|
152 |
+
gr.Info("You must provide a Control Video")
|
153 |
+
return
|
154 |
+
else:
|
155 |
video_guide = None
|
156 |
+
if "M" in video_prompt_type:
|
157 |
+
if video_mask == None:
|
158 |
+
gr.Info("You must provide a Video Mask ")
|
159 |
+
return
|
160 |
+
else:
|
161 |
video_mask = None
|
162 |
+
if "O" in video_prompt_type and inputs["max_frames"]==0:
|
163 |
+
gr.Info(f"In order to extend a video, you need to indicate how many frames you want to reuse in the source video.")
|
164 |
+
return
|
165 |
|
166 |
if isinstance(image_refs, list):
|
167 |
image_refs = [ convert_image(tup[0]) for tup in image_refs ]
|
|
|
275 |
queue = gen["queue"]
|
276 |
task_id += 1
|
277 |
current_task_id = task_id
|
278 |
+
inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
|
279 |
start_image_data = None
|
280 |
end_image_data = None
|
281 |
for name in inputs_to_query:
|
|
|
733 |
"transformer_types": [],
|
734 |
"transformer_quantization": "int8",
|
735 |
"text_encoder_filename" : text_encoder_choices[1],
|
736 |
+
"save_path": "outputs", #os.path.join(os.getcwd(),
|
737 |
"compile" : "",
|
738 |
"metadata_type": "metadata",
|
739 |
"default_ui": "t2v",
|
|
|
741 |
"clear_file_list" : 0,
|
742 |
"vae_config": 0,
|
743 |
"profile" : profile_type.LowRAM_LowVRAM,
|
744 |
+
"preload_model_policy": [] }
|
745 |
|
746 |
with open(server_config_filename, "w", encoding="utf-8") as writer:
|
747 |
writer.write(json.dumps(server_config))
|
|
|
875 |
reload_needed = False
|
876 |
default_ui = server_config.get("default_ui", "t2v")
|
877 |
save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
|
878 |
+
preload_model_policy = server_config.get("preload_model_policy", [])
|
879 |
|
880 |
|
881 |
if args.t2v_14B or args.t2v:
|
|
|
977 |
|
978 |
from huggingface_hub import hf_hub_download, snapshot_download
|
979 |
repoId = "DeepBeepMeep/Wan2.1"
|
980 |
+
sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
|
981 |
+
fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
|
982 |
targetRoot = "ckpts/"
|
983 |
for sourceFolder, files in zip(sourceFolderList,fileList ):
|
984 |
if len(files)==0:
|
|
|
1181 |
|
1182 |
return wan_model, offloadobj, pipe["transformer"]
|
1183 |
|
1184 |
+
if not "P" in preload_model_policy:
|
1185 |
wan_model, offloadobj, transformer = None, None, None
|
1186 |
reload_needed = True
|
1187 |
else:
|
|
|
1269 |
quantization_choice,
|
1270 |
boost_choice = 1,
|
1271 |
clear_file_list = 0,
|
1272 |
+
preload_model_policy_choice = 1,
|
1273 |
):
|
1274 |
if args.lock_config:
|
1275 |
return
|
|
|
1287 |
"transformer_quantization" : quantization_choice,
|
1288 |
"boost" : boost_choice,
|
1289 |
"clear_file_list" : clear_file_list,
|
1290 |
+
"preload_model_policy" : preload_model_policy_choice,
|
1291 |
}
|
1292 |
|
1293 |
if Path(server_config_filename).is_file():
|
|
|
1310 |
if v != v_old:
|
1311 |
changes.append(k)
|
1312 |
|
1313 |
+
global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
|
1314 |
attention_mode = server_config["attention_mode"]
|
1315 |
profile = server_config["profile"]
|
1316 |
compile = server_config["compile"]
|
1317 |
text_encoder_filename = server_config["text_encoder_filename"]
|
1318 |
vae_config = server_config["vae_config"]
|
1319 |
boost = server_config["boost"]
|
1320 |
+
preload_model_policy = server_config["preload_model_policy"]
|
1321 |
transformer_quantization = server_config["transformer_quantization"]
|
1322 |
transformer_types = server_config["transformer_types"]
|
1323 |
transformer_type = get_model_type(transformer_filename)
|
|
|
1396 |
|
1397 |
gen["abort"] = True
|
1398 |
gen["extra_orders"] = 0
|
1399 |
+
if wan_model != None:
|
1400 |
+
wan_model._interrupt= True
|
1401 |
msg = "Processing Request to abort Current Generation"
|
1402 |
gr.Info(msg)
|
1403 |
return msg, gr.Button(interactive= False)
|
|
|
1496 |
return new_slist
|
1497 |
def convert_image(image):
|
1498 |
|
1499 |
+
from PIL import ImageOps
|
1500 |
from typing import cast
|
1501 |
|
1502 |
return cast(Image, ImageOps.exif_transpose(image))
|
1503 |
+
|
1504 |
+
|
1505 |
+
def preprocess_video(process_type, height, width, video_in, max_frames):
|
1506 |
+
|
1507 |
+
from wan.utils.utils import resample
|
1508 |
+
|
1509 |
+
import decord
|
1510 |
+
decord.bridge.set_bridge('torch')
|
1511 |
+
reader = decord.VideoReader(video_in)
|
1512 |
+
|
1513 |
+
fps = reader.get_avg_fps()
|
1514 |
+
|
1515 |
+
frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
|
1516 |
+
frames_list = reader.get_batch(frame_nos)
|
1517 |
+
frame_height, frame_width, _ = frames_list[0].shape
|
1518 |
+
|
1519 |
+
scale = ((height * width ) / (frame_height * frame_width))**(1/2)
|
1520 |
+
# scale = min(height / frame_height, width / frame_width)
|
1521 |
+
|
1522 |
+
new_height = (int(frame_height * scale) // 16) * 16
|
1523 |
+
new_width = (int(frame_width * scale) // 16) * 16
|
1524 |
+
|
1525 |
+
processed_frames_list = []
|
1526 |
+
for frame in frames_list:
|
1527 |
+
frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
|
1528 |
+
frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
|
1529 |
+
processed_frames_list.append(frame)
|
1530 |
+
|
1531 |
+
if process_type=="pose":
|
1532 |
+
from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
|
1533 |
+
cfg_dict = {
|
1534 |
+
"DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
|
1535 |
+
"POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
|
1536 |
+
"RESIZE_SIZE": 1024
|
1537 |
+
}
|
1538 |
+
anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
|
1539 |
+
elif process_type=="depth":
|
1540 |
+
from preprocessing.midas.depth import DepthVideoAnnotator
|
1541 |
+
cfg_dict = {
|
1542 |
+
"PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
|
1543 |
+
}
|
1544 |
+
anno_ins = DepthVideoAnnotator(cfg_dict)
|
1545 |
+
else:
|
1546 |
+
from preprocessing.gray import GrayVideoAnnotator
|
1547 |
+
cfg_dict = {}
|
1548 |
+
anno_ins = GrayVideoAnnotator(cfg_dict)
|
1549 |
+
|
1550 |
+
np_frames = anno_ins.forward(processed_frames_list)
|
1551 |
+
|
1552 |
+
# from preprocessing.dwpose.pose import save_one_video
|
1553 |
+
# save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
|
1554 |
+
|
1555 |
+
torch_frames = []
|
1556 |
+
for np_frame in np_frames:
|
1557 |
+
torch_frame = torch.from_numpy(np_frame)
|
1558 |
+
torch_frames.append(torch_frame)
|
1559 |
+
|
1560 |
+
return torch.stack(torch_frames)
|
1561 |
|
1562 |
def generate_video(
|
1563 |
task_id,
|
|
|
1611 |
# gr.Info("Unable to generate a Video while a new configuration is being applied.")
|
1612 |
# return
|
1613 |
|
1614 |
+
if "P" in preload_model_policy:
|
1615 |
while wan_model == None:
|
1616 |
time.sleep(1)
|
1617 |
|
|
|
1741 |
raise gr.Error("Teacache not supported for this model")
|
1742 |
|
1743 |
if "Vace" in model_filename:
|
1744 |
+
# video_prompt_type = video_prompt_type +"G"
|
1745 |
+
if any(process in video_prompt_type for process in ("P", "D", "G")) :
|
1746 |
+
prompts_max = gen["prompts_max"]
|
1747 |
+
|
1748 |
+
status = get_generation_status(prompt_no, prompts_max, 1, 1)
|
1749 |
+
preprocess_type = None
|
1750 |
+
if "P" in video_prompt_type :
|
1751 |
+
progress_args = [0, status + " - Extracting Open Pose Information"]
|
1752 |
+
preprocess_type = "pose"
|
1753 |
+
elif "D" in video_prompt_type :
|
1754 |
+
progress_args = [0, status + " - Extracting Depth Information"]
|
1755 |
+
preprocess_type = "depth"
|
1756 |
+
elif "G" in video_prompt_type :
|
1757 |
+
progress_args = [0, status + " - Extracting Gray Level Information"]
|
1758 |
+
preprocess_type = "gray"
|
1759 |
+
|
1760 |
+
if preprocess_type != None :
|
1761 |
+
progress(*progress_args )
|
1762 |
+
gen["progress_args"] = progress_args
|
1763 |
+
video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
|
1764 |
+
|
1765 |
src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
|
1766 |
[video_mask],
|
1767 |
[image_refs],
|
1768 |
video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
|
1769 |
+
original_video= "O" in video_prompt_type,
|
1770 |
trim_video=max_frames)
|
1771 |
else:
|
1772 |
src_video, src_mask, src_ref_images = None, None, None
|
|
|
2621 |
|
2622 |
return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
|
2623 |
|
2624 |
+
def preload_model_when_switching(state):
|
2625 |
global reload_needed, wan_model, offloadobj
|
2626 |
+
if "S" in preload_model_policy:
|
2627 |
model_filename = state["model_filename"]
|
2628 |
if state["model_filename"] != transformer_filename:
|
2629 |
wan_model = None
|
|
|
2640 |
|
2641 |
def unload_model_if_needed(state):
|
2642 |
global reload_needed, wan_model, offloadobj
|
2643 |
+
if "U" in preload_model_policy:
|
2644 |
if wan_model != None:
|
2645 |
wan_model = None
|
2646 |
if offloadobj is not None:
|
|
|
2649 |
gc.collect()
|
2650 |
reload_needed= True
|
2651 |
|
2652 |
+
def filter_letters(source_str, letters):
|
2653 |
+
ret = ""
|
2654 |
+
for letter in letters:
|
2655 |
+
if letter in source_str:
|
2656 |
+
ret += letter
|
2657 |
+
return ret
|
2658 |
+
|
2659 |
+
def add_to_sequence(source_str, letters):
|
2660 |
+
ret = source_str
|
2661 |
+
for letter in letters:
|
2662 |
+
if not letter in source_str:
|
2663 |
+
ret += letter
|
2664 |
+
return ret
|
2665 |
+
|
2666 |
+
def del_in_sequence(source_str, letters):
|
2667 |
+
ret = source_str
|
2668 |
+
for letter in letters:
|
2669 |
+
if letter in source_str:
|
2670 |
+
ret = ret.replace(letter, "")
|
2671 |
+
return ret
|
2672 |
+
|
2673 |
+
|
2674 |
+
def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
|
2675 |
+
video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I")
|
2676 |
+
return video_prompt_type, gr.update(visible = video_prompt_type_image_refs),gr.update(visible = video_prompt_type_image_refs)
|
2677 |
+
|
2678 |
+
def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_video_guide):
|
2679 |
+
video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
|
2680 |
+
video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
|
2681 |
+
visible = "V" in video_prompt_type
|
2682 |
+
return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
|
2683 |
|
2684 |
+
|
2685 |
def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
|
2686 |
global inputs_names #, advanced
|
2687 |
|
|
|
2790 |
image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
|
2791 |
|
2792 |
with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
|
2793 |
+
video_prompt_type_value= ui_defaults.get("video_prompt_type","")
|
2794 |
+
video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
|
2795 |
+
video_prompt_type_video_guide = gr.Dropdown(
|
2796 |
+
choices=[
|
2797 |
+
("None, use only the Text Prompt", ""),
|
2798 |
+
("Extend the Control Video", "OV"),
|
2799 |
+
("Transfer Human Motion from the Control Video", "PV"),
|
2800 |
+
("Transfer Depth from the Control Video", "DV"),
|
2801 |
+
("Recolorize the Control Video", "CV"),
|
2802 |
+
("Control Video contains Open Pose, Depth or Black & White ", "V"),
|
2803 |
+
("Inpainting of Control Video using Mask Video ", "MV"),
|
2804 |
+
],
|
2805 |
+
value=filter_letters(video_prompt_type_value, "ODPCMV"),
|
2806 |
+
label="Video to Video"
|
2807 |
+
)
|
2808 |
+
video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use References Images (Faces, Objects) to customize New Video", scale =1 )
|
2809 |
+
|
2810 |
+
video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
|
2811 |
+
max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
|
2812 |
|
2813 |
+
image_refs = gr.Gallery( label ="Reference Images",
|
2814 |
+
type ="pil", show_label= True,
|
2815 |
+
columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
|
2816 |
+
value= ui_defaults.get("image_refs", None) )
|
2817 |
+
|
2818 |
+
# with gr.Row():
|
2819 |
+
remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Images References", visible= "I" in video_prompt_type_value, scale =1 )
|
2820 |
+
|
2821 |
+
|
2822 |
+
video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
|
2823 |
|
2824 |
|
2825 |
advanced_prompt = advanced_ui
|
|
|
3054 |
target_settings = gr.Text(value = "settings", interactive= False, visible= False)
|
3055 |
|
3056 |
image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
|
3057 |
+
# video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
|
3058 |
+
video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
|
3059 |
+
video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])
|
3060 |
+
|
3061 |
show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
|
3062 |
fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
|
3063 |
queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
|
|
|
3101 |
).then(fn= fill_inputs,
|
3102 |
inputs=[state],
|
3103 |
outputs=gen_inputs + extra_inputs
|
3104 |
+
).then(fn= preload_model_when_switching,
|
3105 |
inputs=[state],
|
3106 |
outputs=[gen_status])
|
3107 |
|
|
|
3283 |
value=server_config.get("metadata_type", "metadata"),
|
3284 |
label="Metadata Handling"
|
3285 |
)
|
3286 |
+
preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
|
3287 |
+
value=server_config.get("preload_model_policy",[]),
|
|
|
|
|
|
|
|
|
|
|
|
|
3288 |
label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
|
3289 |
)
|
3290 |
|
|
|
3319 |
quantization_choice,
|
3320 |
boost_choice,
|
3321 |
clear_file_list_choice,
|
3322 |
+
preload_model_policy_choice,
|
3323 |
],
|
3324 |
outputs= [msg , header, model_choice]
|
3325 |
)
|
|
|
3329 |
gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
|
3330 |
gr.Markdown("Many thanks to:")
|
3331 |
gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
|
3332 |
+
gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
|
3333 |
gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
|
3334 |
gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
|
3335 |
gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
|
3336 |
+
gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
|
3337 |
+
gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
|
3338 |
+
gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
|
3339 |
+
gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
|
3340 |
+
gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
|
3341 |
+
|
3342 |
|
3343 |
def generate_info_tab():
|
3344 |
gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
|
|
|
3365 |
choices= dropdown_choices,
|
3366 |
value= current_model_type,
|
3367 |
show_label= False,
|
3368 |
+
scale= 2,
|
3369 |
+
elem_id="model_list",
|
3370 |
+
elem_classes="model_list_class",
|
3371 |
)
|
3372 |
|
3373 |
|
3374 |
|
3375 |
def create_demo():
|
3376 |
css = """
|
3377 |
+
#model_list{
|
3378 |
+
background-color:black;
|
3379 |
+
padding:1px}
|
3380 |
+
|
3381 |
+
#model_list input {
|
3382 |
+
font-size:25px}
|
3383 |
+
|
3384 |
.title-with-lines {
|
3385 |
display: flex;
|
3386 |
align-items: center;
|
3387 |
+
margin: 25px 0;
|
3388 |
}
|
3389 |
.line {
|
3390 |
flex-grow: 1;
|
|
|
3605 |
pointer-events: none;
|
3606 |
}
|
3607 |
"""
|
3608 |
+
with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
|
3609 |
gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
|
3610 |
global model_list
|
3611 |
|