DeepBeepMeep committed
Commit
5747c0d
·
1 Parent(s): 12652e0

Simplified Vace, added auto Open Pose and depth extractors

README.md CHANGED
@@ -14,7 +14,7 @@
14
 
15
 
16
  ## 🔥 Latest News!!
17
- * April 4 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you !
18
 - A new queuing system that lets you stack in a queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...).
19
 - Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
20
 - Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
@@ -272,23 +272,24 @@ You can define multiple lines of macros. If there is only one macro line, the ap
272
 
273
  ### VACE ControlNet introduction
274
 
275
- Vace is a ControlNet 1.3B text2video model that allows you on top of a text prompt to provide visual hints to guide the generation. It can do more things than image2video although it is not as good for just starting a video with an image because it only a 1.3B model (in fact 3B) versus 14B and (it is not specialized for start frames). However, with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ...
276
 
277
 First you need to select the Vace 1.3B model in the drop-down box at the top. Please note that, for the moment, Vace works well only with videos up to 5s (81 frames).
278
 
279
 Besides the usual Text Prompt, three new types of visual hints can be provided (and combined!):
280
- - reference Images: use this to inject people or objects in the video. You can select multiple reference Images. The integration of the image is more efficient if the background is replaced by the full white color. You can do that with your preferred background remover or use the built in background remover by checking the box *Remove background*
281
 
282
- - a Video: this can be a video that contains a body pose (an animated wireframe that indicates the positions of limbs of a person), a greyed depth map video, a normal video combined with a masked video (see below),... The Vace model will detect automatically what to do depending on the video content. You can tell WanGP to use only the n first frames of this Video. All the frames beyond and up the number of requested frames will be generated by following the Text prompt and the other visual hints (for instance reference images). If the video contains area of grey color 127, they will be considered as masks and will be filled based on the Text prompt of the reference Images. There
283
 
284
  - a Video Mask
285
- This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). You can do as well inpainting / outpainting, fill the missing part of a video more efficientlty with just the video hint.
286
 
287
 
288
  Examples:
289
- - Inject people and / objects into a scene describe by a text promtp: Ref. Images + text Prompt
290
- - Animate a character described in a text prompt: Body Pose Video + text Prompt
291
- - Animate a character of your choice : Ref Images + Body Pose Video + text Prompt
 
292
 
293
 
294
 There are lots of possible combinations. Some of them require preparing some materials (masks on top of a video, full masks, etc.).
 
14
 
15
 
16
  ## 🔥 Latest News!!
17
+ * April 9 2025: 👋 Wan 2.1GP v4.0: lots of goodies for you!
18
 - A new queuing system that lets you stack in a queue as many text2video and image2video tasks as you want. Each task can rely on completely different generation parameters (different number of frames, steps, loras, ...).
19
 - Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
20
 - Wan Vace ControlNet support: with Vace you can inject people or objects into the scene, animate a person, perform inpainting or outpainting, continue a video, ... I have provided an introduction guide below.
 
272
 
273
  ### VACE ControlNet introduction
274
 
275
+ Vace is a 1.3B ControlNet text2video model that supports Video-to-Video and Reference-to-Video generation (it can inject your own images into the output video). With Vace you can insert people or objects of your choice into a scene, animate a person, perform inpainting or outpainting, continue an existing video, ...
276
 
277
 First you need to select the Vace 1.3B model in the drop-down box at the top. Please note that, for the moment, Vace works well only with videos up to 5s (81 frames).
278
 
279
 Besides the usual Text Prompt, three new types of visual hints can be provided (and combined!):
280
+ - a Control Video: depending on your choice, you can transfer the motion or the depth of this video into a new video. You can tell WanGP to use only the first n frames of the Control Video and to extrapolate the rest. You can also do inpainting: if the video contains areas of grey color 127, they will be treated as masks and filled based on the Text Prompt and the Reference Images (see the sketch after the examples below).
281
 
282
+ - Reference Images: use these to inject people or objects of your choice into the video. You can select multiple Reference Images. The integration of an image works better if its background is replaced by plain white. You can do that with your preferred background remover, or use the built-in one by checking the *Remove background* box.
283
 
284
  - a Video Mask
285
+ This offers a stronger mechanism to tell Vace which parts should be kept (black) or replaced (white). It also lets you do inpainting / outpainting and fill the missing parts of a video more efficiently than with the video hint alone. White areas of the mask are regenerated, so with a mask video that is black at the beginning and at the end and white in between, you can generate the missing frames in the middle.
286
 
287
 
288
  Examples:
289
+ - Inject people and/or objects into a scene described by a text prompt: Ref. Images + Text Prompt
290
+ - Animate a character described in a text prompt: a Video of a person moving + Text Prompt
291
+ - Animate a character of your choice (pose transfer): Ref. Images + a Video of a person moving + Text Prompt
292
+ - Change the style of a scene (depth transfer): a Video that contains objects / people at different depths + Text Prompt
293
 
294
 
295
 There are lots of possible combinations. Some of them require preparing some materials (masks on top of a video, full masks, etc.).
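For illustration only (not part of this commit), here is a minimal sketch of how one might paint the grey-127 mask colour into control-video frames before feeding them to Vace; the function name, region coordinates and frame format are assumptions:

```python
# Hypothetical helper: mark a rectangle of each frame with grey 127 so Vace
# treats it as a masked area to be regenerated from the Text Prompt / Reference Images.
import numpy as np

def add_inpaint_region(frames, x0, y0, x1, y1):
    """frames: list of HxWx3 uint8 arrays; (x0, y0, x1, y1) is the area to regenerate."""
    out = []
    for frame in frames:
        frame = np.array(frame, copy=True)
        frame[y0:y1, x0:x1] = 127  # grey 127 = mask colour recognised by Vace
        out.append(frame)
    return out
```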
preprocessing/dwpose/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
preprocessing/dwpose/onnxdet.py ADDED
@@ -0,0 +1,127 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import numpy as np
5
+
6
+ import onnxruntime
7
+
8
+ def nms(boxes, scores, nms_thr):
9
+ """Single class NMS implemented in Numpy."""
10
+ x1 = boxes[:, 0]
11
+ y1 = boxes[:, 1]
12
+ x2 = boxes[:, 2]
13
+ y2 = boxes[:, 3]
14
+
15
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
16
+ order = scores.argsort()[::-1]
17
+
18
+ keep = []
19
+ while order.size > 0:
20
+ i = order[0]
21
+ keep.append(i)
22
+ xx1 = np.maximum(x1[i], x1[order[1:]])
23
+ yy1 = np.maximum(y1[i], y1[order[1:]])
24
+ xx2 = np.minimum(x2[i], x2[order[1:]])
25
+ yy2 = np.minimum(y2[i], y2[order[1:]])
26
+
27
+ w = np.maximum(0.0, xx2 - xx1 + 1)
28
+ h = np.maximum(0.0, yy2 - yy1 + 1)
29
+ inter = w * h
30
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
31
+
32
+ inds = np.where(ovr <= nms_thr)[0]
33
+ order = order[inds + 1]
34
+
35
+ return keep
36
+
37
+ def multiclass_nms(boxes, scores, nms_thr, score_thr):
38
+ """Multiclass NMS implemented in Numpy. Class-aware version."""
39
+ final_dets = []
40
+ num_classes = scores.shape[1]
41
+ for cls_ind in range(num_classes):
42
+ cls_scores = scores[:, cls_ind]
43
+ valid_score_mask = cls_scores > score_thr
44
+ if valid_score_mask.sum() == 0:
45
+ continue
46
+ else:
47
+ valid_scores = cls_scores[valid_score_mask]
48
+ valid_boxes = boxes[valid_score_mask]
49
+ keep = nms(valid_boxes, valid_scores, nms_thr)
50
+ if len(keep) > 0:
51
+ cls_inds = np.ones((len(keep), 1)) * cls_ind
52
+ dets = np.concatenate(
53
+ [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
54
+ )
55
+ final_dets.append(dets)
56
+ if len(final_dets) == 0:
57
+ return None
58
+ return np.concatenate(final_dets, 0)
59
+
60
+ def demo_postprocess(outputs, img_size, p6=False):
61
+ grids = []
62
+ expanded_strides = []
63
+ strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
64
+
65
+ hsizes = [img_size[0] // stride for stride in strides]
66
+ wsizes = [img_size[1] // stride for stride in strides]
67
+
68
+ for hsize, wsize, stride in zip(hsizes, wsizes, strides):
69
+ xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
70
+ grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
71
+ grids.append(grid)
72
+ shape = grid.shape[:2]
73
+ expanded_strides.append(np.full((*shape, 1), stride))
74
+
75
+ grids = np.concatenate(grids, 1)
76
+ expanded_strides = np.concatenate(expanded_strides, 1)
77
+ outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
78
+ outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
79
+
80
+ return outputs
81
+
82
+ def preprocess(img, input_size, swap=(2, 0, 1)):
83
+ if len(img.shape) == 3:
84
+ padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
85
+ else:
86
+ padded_img = np.ones(input_size, dtype=np.uint8) * 114
87
+
88
+ r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
89
+ resized_img = cv2.resize(
90
+ img,
91
+ (int(img.shape[1] * r), int(img.shape[0] * r)),
92
+ interpolation=cv2.INTER_LINEAR,
93
+ ).astype(np.uint8)
94
+ padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
95
+
96
+ padded_img = padded_img.transpose(swap)
97
+ padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
98
+ return padded_img, r
99
+
100
+ def inference_detector(session, oriImg):
101
+ input_shape = (640,640)
102
+ img, ratio = preprocess(oriImg, input_shape)
103
+
104
+ ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
105
+ output = session.run(None, ort_inputs)
106
+ predictions = demo_postprocess(output[0], input_shape)[0]
107
+
108
+ boxes = predictions[:, :4]
109
+ scores = predictions[:, 4:5] * predictions[:, 5:]
110
+
111
+ boxes_xyxy = np.ones_like(boxes)
112
+ boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
113
+ boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
114
+ boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
115
+ boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
116
+ boxes_xyxy /= ratio
117
+ dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
118
+ if dets is not None:
119
+ final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
120
+ isscore = final_scores>0.3
121
+ iscat = final_cls_inds == 0
122
+ isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
123
+ final_boxes = final_boxes[isbbox]
124
+ else:
125
+ final_boxes = np.array([])
126
+
127
+ return final_boxes
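A hedged usage sketch for the detector above (not part of the commit); the ONNX file path and test image are placeholders:

```python
# Minimal sketch: run the YOLOX person detector defined in onnxdet.py on one image.
import cv2
import onnxruntime as ort
from preprocessing.dwpose.onnxdet import inference_detector

session = ort.InferenceSession("ckpts/yolox_l.onnx", providers=["CPUExecutionProvider"])
img = cv2.imread("person.jpg")            # HxWx3 uint8 BGR image
boxes = inference_detector(session, img)  # (N, 4) person boxes in xyxy, original-image scale
print(boxes)
```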
preprocessing/dwpose/onnxpose.py ADDED
@@ -0,0 +1,362 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ from typing import List, Tuple
4
+
5
+ import cv2
6
+ import numpy as np
7
+ import onnxruntime as ort
8
+
9
+ def preprocess(
10
+ img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256)
11
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
12
+ """Do preprocessing for RTMPose model inference.
13
+
14
+ Args:
15
+ img (np.ndarray): Input image in shape.
16
+ input_size (tuple): Input image size in shape (w, h).
17
+
18
+ Returns:
19
+ tuple:
20
+ - resized_img (np.ndarray): Preprocessed image.
21
+ - center (np.ndarray): Center of image.
22
+ - scale (np.ndarray): Scale of image.
23
+ """
24
+ # get shape of image
25
+ img_shape = img.shape[:2]
26
+ out_img, out_center, out_scale = [], [], []
27
+ if len(out_bbox) == 0:
28
+ out_bbox = [[0, 0, img_shape[1], img_shape[0]]]
29
+ for i in range(len(out_bbox)):
30
+ x0 = out_bbox[i][0]
31
+ y0 = out_bbox[i][1]
32
+ x1 = out_bbox[i][2]
33
+ y1 = out_bbox[i][3]
34
+ bbox = np.array([x0, y0, x1, y1])
35
+
36
+ # get center and scale
37
+ center, scale = bbox_xyxy2cs(bbox, padding=1.25)
38
+
39
+ # do affine transformation
40
+ resized_img, scale = top_down_affine(input_size, scale, center, img)
41
+
42
+ # normalize image
43
+ mean = np.array([123.675, 116.28, 103.53])
44
+ std = np.array([58.395, 57.12, 57.375])
45
+ resized_img = (resized_img - mean) / std
46
+
47
+ out_img.append(resized_img)
48
+ out_center.append(center)
49
+ out_scale.append(scale)
50
+
51
+ return out_img, out_center, out_scale
52
+
53
+
54
+ def inference(sess: ort.InferenceSession, img: np.ndarray) -> np.ndarray:
55
+ """Inference RTMPose model.
56
+
57
+ Args:
58
+ sess (ort.InferenceSession): ONNXRuntime session.
59
+ img (np.ndarray): Input image in shape.
60
+
61
+ Returns:
62
+ outputs (np.ndarray): Output of RTMPose model.
63
+ """
64
+ all_out = []
65
+ # build input
66
+ for i in range(len(img)):
67
+ input = [img[i].transpose(2, 0, 1)]
68
+
69
+ # build output
70
+ sess_input = {sess.get_inputs()[0].name: input}
71
+ sess_output = []
72
+ for out in sess.get_outputs():
73
+ sess_output.append(out.name)
74
+
75
+ # run model
76
+ outputs = sess.run(sess_output, sess_input)
77
+ all_out.append(outputs)
78
+
79
+ return all_out
80
+
81
+
82
+ def postprocess(outputs: List[np.ndarray],
83
+ model_input_size: Tuple[int, int],
84
+ center: Tuple[int, int],
85
+ scale: Tuple[int, int],
86
+ simcc_split_ratio: float = 2.0
87
+ ) -> Tuple[np.ndarray, np.ndarray]:
88
+ """Postprocess for RTMPose model output.
89
+
90
+ Args:
91
+ outputs (np.ndarray): Output of RTMPose model.
92
+ model_input_size (tuple): RTMPose model Input image size.
93
+ center (tuple): Center of bbox in shape (x, y).
94
+ scale (tuple): Scale of bbox in shape (w, h).
95
+ simcc_split_ratio (float): Split ratio of simcc.
96
+
97
+ Returns:
98
+ tuple:
99
+ - keypoints (np.ndarray): Rescaled keypoints.
100
+ - scores (np.ndarray): Model predict scores.
101
+ """
102
+ all_key = []
103
+ all_score = []
104
+ for i in range(len(outputs)):
105
+ # use simcc to decode
106
+ simcc_x, simcc_y = outputs[i]
107
+ keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio)
108
+
109
+ # rescale keypoints
110
+ keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2
111
+ all_key.append(keypoints[0])
112
+ all_score.append(scores[0])
113
+
114
+ return np.array(all_key), np.array(all_score)
115
+
116
+
117
+ def bbox_xyxy2cs(bbox: np.ndarray,
118
+ padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]:
119
+ """Transform the bbox format from (x,y,w,h) into (center, scale)
120
+
121
+ Args:
122
+ bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted
123
+ as (left, top, right, bottom)
124
+ padding (float): BBox padding factor that will be multilied to scale.
125
+ Default: 1.0
126
+
127
+ Returns:
128
+ tuple: A tuple containing center and scale.
129
+ - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or
130
+ (n, 2)
131
+ - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or
132
+ (n, 2)
133
+ """
134
+ # convert single bbox from (4, ) to (1, 4)
135
+ dim = bbox.ndim
136
+ if dim == 1:
137
+ bbox = bbox[None, :]
138
+
139
+ # get bbox center and scale
140
+ x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3])
141
+ center = np.hstack([x1 + x2, y1 + y2]) * 0.5
142
+ scale = np.hstack([x2 - x1, y2 - y1]) * padding
143
+
144
+ if dim == 1:
145
+ center = center[0]
146
+ scale = scale[0]
147
+
148
+ return center, scale
149
+
150
+
151
+ def _fix_aspect_ratio(bbox_scale: np.ndarray,
152
+ aspect_ratio: float) -> np.ndarray:
153
+ """Extend the scale to match the given aspect ratio.
154
+
155
+ Args:
156
+ scale (np.ndarray): The image scale (w, h) in shape (2, )
157
+ aspect_ratio (float): The ratio of ``w/h``
158
+
159
+ Returns:
160
+ np.ndarray: The reshaped image scale in (2, )
161
+ """
162
+ w, h = np.hsplit(bbox_scale, [1])
163
+ bbox_scale = np.where(w > h * aspect_ratio,
164
+ np.hstack([w, w / aspect_ratio]),
165
+ np.hstack([h * aspect_ratio, h]))
166
+ return bbox_scale
167
+
168
+
169
+ def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray:
170
+ """Rotate a point by an angle.
171
+
172
+ Args:
173
+ pt (np.ndarray): 2D point coordinates (x, y) in shape (2, )
174
+ angle_rad (float): rotation angle in radian
175
+
176
+ Returns:
177
+ np.ndarray: Rotated point in shape (2, )
178
+ """
179
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
180
+ rot_mat = np.array([[cs, -sn], [sn, cs]])
181
+ return rot_mat @ pt
182
+
183
+
184
+ def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray:
185
+ """To calculate the affine matrix, three pairs of points are required. This
186
+ function is used to get the 3rd point, given 2D points a & b.
187
+
188
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
189
+ anticlockwise, using b as the rotation center.
190
+
191
+ Args:
192
+ a (np.ndarray): The 1st point (x,y) in shape (2, )
193
+ b (np.ndarray): The 2nd point (x,y) in shape (2, )
194
+
195
+ Returns:
196
+ np.ndarray: The 3rd point.
197
+ """
198
+ direction = a - b
199
+ c = b + np.r_[-direction[1], direction[0]]
200
+ return c
201
+
202
+
203
+ def get_warp_matrix(center: np.ndarray,
204
+ scale: np.ndarray,
205
+ rot: float,
206
+ output_size: Tuple[int, int],
207
+ shift: Tuple[float, float] = (0., 0.),
208
+ inv: bool = False) -> np.ndarray:
209
+ """Calculate the affine transformation matrix that can warp the bbox area
210
+ in the input image to the output size.
211
+
212
+ Args:
213
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
214
+ scale (np.ndarray[2, ]): Scale of the bounding box
215
+ wrt [width, height].
216
+ rot (float): Rotation angle (degree).
217
+ output_size (np.ndarray[2, ] | list(2,)): Size of the
218
+ destination heatmaps.
219
+ shift (0-100%): Shift translation ratio wrt the width/height.
220
+ Default (0., 0.).
221
+ inv (bool): Option to inverse the affine transform direction.
222
+ (inv=False: src->dst or inv=True: dst->src)
223
+
224
+ Returns:
225
+ np.ndarray: A 2x3 transformation matrix
226
+ """
227
+ shift = np.array(shift)
228
+ src_w = scale[0]
229
+ dst_w = output_size[0]
230
+ dst_h = output_size[1]
231
+
232
+ # compute transformation matrix
233
+ rot_rad = np.deg2rad(rot)
234
+ src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad)
235
+ dst_dir = np.array([0., dst_w * -0.5])
236
+
237
+ # get four corners of the src rectangle in the original image
238
+ src = np.zeros((3, 2), dtype=np.float32)
239
+ src[0, :] = center + scale * shift
240
+ src[1, :] = center + src_dir + scale * shift
241
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
242
+
243
+ # get four corners of the dst rectangle in the input image
244
+ dst = np.zeros((3, 2), dtype=np.float32)
245
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
246
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
247
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
248
+
249
+ if inv:
250
+ warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src))
251
+ else:
252
+ warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst))
253
+
254
+ return warp_mat
255
+
256
+
257
+ def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict,
258
+ img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
259
+ """Get the bbox image as the model input by affine transform.
260
+
261
+ Args:
262
+ input_size (dict): The input size of the model.
263
+ bbox_scale (dict): The bbox scale of the img.
264
+ bbox_center (dict): The bbox center of the img.
265
+ img (np.ndarray): The original image.
266
+
267
+ Returns:
268
+ tuple: A tuple containing center and scale.
269
+ - np.ndarray[float32]: img after affine transform.
270
+ - np.ndarray[float32]: bbox scale after affine transform.
271
+ """
272
+ w, h = input_size
273
+ warp_size = (int(w), int(h))
274
+
275
+ # reshape bbox to fixed aspect ratio
276
+ bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h)
277
+
278
+ # get the affine matrix
279
+ center = bbox_center
280
+ scale = bbox_scale
281
+ rot = 0
282
+ warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h))
283
+
284
+ # do affine transform
285
+ img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR)
286
+
287
+ return img, bbox_scale
288
+
289
+
290
+ def get_simcc_maximum(simcc_x: np.ndarray,
291
+ simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
292
+ """Get maximum response location and value from simcc representations.
293
+
294
+ Note:
295
+ instance number: N
296
+ num_keypoints: K
297
+ heatmap height: H
298
+ heatmap width: W
299
+
300
+ Args:
301
+ simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx)
302
+ simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy)
303
+
304
+ Returns:
305
+ tuple:
306
+ - locs (np.ndarray): locations of maximum heatmap responses in shape
307
+ (K, 2) or (N, K, 2)
308
+ - vals (np.ndarray): values of maximum heatmap responses in shape
309
+ (K,) or (N, K)
310
+ """
311
+ N, K, Wx = simcc_x.shape
312
+ simcc_x = simcc_x.reshape(N * K, -1)
313
+ simcc_y = simcc_y.reshape(N * K, -1)
314
+
315
+ # get maximum value locations
316
+ x_locs = np.argmax(simcc_x, axis=1)
317
+ y_locs = np.argmax(simcc_y, axis=1)
318
+ locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32)
319
+ max_val_x = np.amax(simcc_x, axis=1)
320
+ max_val_y = np.amax(simcc_y, axis=1)
321
+
322
+ # get maximum value across x and y axis
323
+ mask = max_val_x > max_val_y
324
+ max_val_x[mask] = max_val_y[mask]
325
+ vals = max_val_x
326
+ locs[vals <= 0.] = -1
327
+
328
+ # reshape
329
+ locs = locs.reshape(N, K, 2)
330
+ vals = vals.reshape(N, K)
331
+
332
+ return locs, vals
333
+
334
+
335
+ def decode(simcc_x: np.ndarray, simcc_y: np.ndarray,
336
+ simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]:
337
+ """Modulate simcc distribution with Gaussian.
338
+
339
+ Args:
340
+ simcc_x (np.ndarray[K, Wx]): model predicted simcc in x.
341
+ simcc_y (np.ndarray[K, Wy]): model predicted simcc in y.
342
+ simcc_split_ratio (int): The split ratio of simcc.
343
+
344
+ Returns:
345
+ tuple: A tuple containing center and scale.
346
+ - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2)
347
+ - np.ndarray[float32]: scores in shape (K,) or (n, K)
348
+ """
349
+ keypoints, scores = get_simcc_maximum(simcc_x, simcc_y)
350
+ keypoints /= simcc_split_ratio
351
+
352
+ return keypoints, scores
353
+
354
+
355
+ def inference_pose(session, out_bbox, oriImg):
356
+ h, w = session.get_inputs()[0].shape[2:]
357
+ model_input_size = (w, h)
358
+ resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size)
359
+ outputs = inference(session, resized_img)
360
+ keypoints, scores = postprocess(outputs, model_input_size, center, scale)
361
+
362
+ return keypoints, scores
preprocessing/dwpose/pose.py ADDED
@@ -0,0 +1,183 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import os
5
+
6
+ import cv2
7
+ import torch
8
+ import numpy as np
9
+ from . import util
10
+ from .wholebody import Wholebody, HWC3, resize_image
11
+ from PIL import Image
12
+
13
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
14
+
15
+ def convert_to_numpy(image):
16
+ if isinstance(image, Image.Image):
17
+ image = np.array(image)
18
+ elif isinstance(image, torch.Tensor):
19
+ image = image.detach().cpu().numpy()
20
+ elif isinstance(image, np.ndarray):
21
+ image = image.copy()
22
+ else:
23
+ raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
24
+ return image
25
+
26
+
27
+
28
+ def draw_pose(pose, H, W, use_hand=False, use_body=False, use_face=False):
29
+ bodies = pose['bodies']
30
+ faces = pose['faces']
31
+ hands = pose['hands']
32
+ candidate = bodies['candidate']
33
+ subset = bodies['subset']
34
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
35
+
36
+ if use_body:
37
+ canvas = util.draw_bodypose(canvas, candidate, subset)
38
+ if use_hand:
39
+ canvas = util.draw_handpose(canvas, hands)
40
+ if use_face:
41
+ canvas = util.draw_facepose(canvas, faces)
42
+
43
+ return canvas
44
+
45
+
46
+ class PoseAnnotator:
47
+ def __init__(self, cfg, device=None):
48
+ onnx_det = cfg['DETECTION_MODEL']
49
+ onnx_pose = cfg['POSE_MODEL']
50
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
51
+ self.pose_estimation = Wholebody(onnx_det, onnx_pose, device=self.device)
52
+ self.resize_size = cfg.get("RESIZE_SIZE", 1024)
53
+ self.use_body = cfg.get('USE_BODY', True)
54
+ self.use_face = cfg.get('USE_FACE', True)
55
+ self.use_hand = cfg.get('USE_HAND', True)
56
+
57
+ @torch.no_grad()
58
+ @torch.inference_mode
59
+ def forward(self, image):
60
+ image = convert_to_numpy(image)
61
+ input_image = HWC3(image[..., ::-1])
62
+ return self.process(resize_image(input_image, self.resize_size), image.shape[:2])
63
+
64
+ def process(self, ori_img, ori_shape):
65
+ ori_h, ori_w = ori_shape
66
+ ori_img = ori_img.copy()
67
+ H, W, C = ori_img.shape
68
+ with torch.no_grad():
69
+ candidate, subset, det_result = self.pose_estimation(ori_img)
70
+ nums, keys, locs = candidate.shape
71
+ candidate[..., 0] /= float(W)
72
+ candidate[..., 1] /= float(H)
73
+ body = candidate[:, :18].copy()
74
+ body = body.reshape(nums * 18, locs)
75
+ score = subset[:, :18]
76
+ for i in range(len(score)):
77
+ for j in range(len(score[i])):
78
+ if score[i][j] > 0.3:
79
+ score[i][j] = int(18 * i + j)
80
+ else:
81
+ score[i][j] = -1
82
+
83
+ un_visible = subset < 0.3
84
+ candidate[un_visible] = -1
85
+
86
+ foot = candidate[:, 18:24]
87
+
88
+ faces = candidate[:, 24:92]
89
+
90
+ hands = candidate[:, 92:113]
91
+ hands = np.vstack([hands, candidate[:, 113:]])
92
+
93
+ bodies = dict(candidate=body, subset=score)
94
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
95
+
96
+ ret_data = {}
97
+ if self.use_body:
98
+ detected_map_body = draw_pose(pose, H, W, use_body=True)
99
+ detected_map_body = cv2.resize(detected_map_body[..., ::-1], (ori_w, ori_h),
100
+ interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
101
+ ret_data["detected_map_body"] = detected_map_body
102
+
103
+ if self.use_face:
104
+ detected_map_face = draw_pose(pose, H, W, use_face=True)
105
+ detected_map_face = cv2.resize(detected_map_face[..., ::-1], (ori_w, ori_h),
106
+ interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
107
+ ret_data["detected_map_face"] = detected_map_face
108
+
109
+ if self.use_body and self.use_face:
110
+ detected_map_bodyface = draw_pose(pose, H, W, use_body=True, use_face=True)
111
+ detected_map_bodyface = cv2.resize(detected_map_bodyface[..., ::-1], (ori_w, ori_h),
112
+ interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
113
+ ret_data["detected_map_bodyface"] = detected_map_bodyface
114
+
115
+ if self.use_hand and self.use_body and self.use_face:
116
+ detected_map_handbodyface = draw_pose(pose, H, W, use_hand=True, use_body=True, use_face=True)
117
+ detected_map_handbodyface = cv2.resize(detected_map_handbodyface[..., ::-1], (ori_w, ori_h),
118
+ interpolation=cv2.INTER_LANCZOS4 if ori_h * ori_w > H * W else cv2.INTER_AREA)
119
+ ret_data["detected_map_handbodyface"] = detected_map_handbodyface
120
+
121
+ # convert_size
122
+ if det_result.shape[0] > 0:
123
+ w_ratio, h_ratio = ori_w / W, ori_h / H
124
+ det_result[..., ::2] *= h_ratio
125
+ det_result[..., 1::2] *= w_ratio
126
+ det_result = det_result.astype(np.int32)
127
+ return ret_data, det_result
128
+
129
+
130
+ class PoseBodyFaceAnnotator(PoseAnnotator):
131
+ def __init__(self, cfg):
132
+ super().__init__(cfg)
133
+ self.use_body, self.use_face, self.use_hand = True, True, False
134
+ @torch.no_grad()
135
+ @torch.inference_mode
136
+ def forward(self, image):
137
+ ret_data, det_result = super().forward(image)
138
+ return ret_data['detected_map_bodyface']
139
+
140
+
141
+ class PoseBodyFaceVideoAnnotator(PoseBodyFaceAnnotator):
142
+ def forward(self, frames):
143
+ ret_frames = []
144
+ for frame in frames:
145
+ anno_frame = super().forward(np.array(frame))
146
+ ret_frames.append(anno_frame)
147
+ return ret_frames
148
+
149
+ import imageio
150
+
151
+ def save_one_video(file_path, videos, fps=8, quality=8, macro_block_size=None):
152
+ try:
153
+ video_writer = imageio.get_writer(file_path, fps=fps, codec='libx264', quality=quality, macro_block_size=macro_block_size)
154
+ for frame in videos:
155
+ video_writer.append_data(frame)
156
+ video_writer.close()
157
+ return True
158
+ except Exception as e:
159
+ print(f"Video save error: {e}")
160
+ return False
161
+
162
+ def get_frames(video_path):
163
+ frames = []
164
+
165
+
166
+ # Opens the Video file with CV2
167
+ cap = cv2.VideoCapture(video_path)
168
+
169
+ fps = cap.get(cv2.CAP_PROP_FPS)
170
+ print("video fps: " + str(fps))
171
+ i = 0
172
+ while cap.isOpened():
173
+ ret, frame = cap.read()
174
+ if ret == False:
175
+ break
176
+ frames.append(frame)
177
+ i += 1
178
+
179
+ cap.release()
180
+ cv2.destroyAllWindows()
181
+
182
+ return frames, fps
183
+
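A sketch (assumptions: repository root on the Python path, placeholder ONNX checkpoint paths) of how the pieces in this file could be chained end to end, video in, pose video out:

```python
# Minimal sketch using the helpers defined in pose.py above.
from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator, get_frames, save_one_video

cfg = {
    "DETECTION_MODEL": "ckpts/yolox_l.onnx",     # placeholder path
    "POSE_MODEL": "ckpts/dw-ll_ucoco_384.onnx",  # placeholder path
    "RESIZE_SIZE": 1024,
}
annotator = PoseBodyFaceVideoAnnotator(cfg)

frames, fps = get_frames("input.mp4")            # BGR frames read with OpenCV
pose_frames = annotator.forward(frames)          # one body+face pose map per frame
save_one_video("pose.mp4", pose_frames, fps=int(fps))
```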
preprocessing/dwpose/util.py ADDED
@@ -0,0 +1,299 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import numpy as np
5
+ import matplotlib
6
+ import cv2
7
+
8
+
9
+ eps = 0.01
10
+
11
+
12
+ def smart_resize(x, s):
13
+ Ht, Wt = s
14
+ if x.ndim == 2:
15
+ Ho, Wo = x.shape
16
+ Co = 1
17
+ else:
18
+ Ho, Wo, Co = x.shape
19
+ if Co == 3 or Co == 1:
20
+ k = float(Ht + Wt) / float(Ho + Wo)
21
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
22
+ else:
23
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
24
+
25
+
26
+ def smart_resize_k(x, fx, fy):
27
+ if x.ndim == 2:
28
+ Ho, Wo = x.shape
29
+ Co = 1
30
+ else:
31
+ Ho, Wo, Co = x.shape
32
+ Ht, Wt = Ho * fy, Wo * fx
33
+ if Co == 3 or Co == 1:
34
+ k = float(Ht + Wt) / float(Ho + Wo)
35
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
36
+ else:
37
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
38
+
39
+
40
+ def padRightDownCorner(img, stride, padValue):
41
+ h = img.shape[0]
42
+ w = img.shape[1]
43
+
44
+ pad = 4 * [None]
45
+ pad[0] = 0 # up
46
+ pad[1] = 0 # left
47
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
48
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
49
+
50
+ img_padded = img
51
+ pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
52
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
53
+ pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
54
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
55
+ pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
56
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
57
+ pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
58
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
59
+
60
+ return img_padded, pad
61
+
62
+
63
+ def transfer(model, model_weights):
64
+ transfered_model_weights = {}
65
+ for weights_name in model.state_dict().keys():
66
+ transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
67
+ return transfered_model_weights
68
+
69
+
70
+ def draw_bodypose(canvas, candidate, subset):
71
+ H, W, C = canvas.shape
72
+ candidate = np.array(candidate)
73
+ subset = np.array(subset)
74
+
75
+ stickwidth = 4
76
+
77
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
78
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
79
+ [1, 16], [16, 18], [3, 17], [6, 18]]
80
+
81
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
82
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
83
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
84
+
85
+ for i in range(17):
86
+ for n in range(len(subset)):
87
+ index = subset[n][np.array(limbSeq[i]) - 1]
88
+ if -1 in index:
89
+ continue
90
+ Y = candidate[index.astype(int), 0] * float(W)
91
+ X = candidate[index.astype(int), 1] * float(H)
92
+ mX = np.mean(X)
93
+ mY = np.mean(Y)
94
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
95
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
96
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
97
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
98
+
99
+ canvas = (canvas * 0.6).astype(np.uint8)
100
+
101
+ for i in range(18):
102
+ for n in range(len(subset)):
103
+ index = int(subset[n][i])
104
+ if index == -1:
105
+ continue
106
+ x, y = candidate[index][0:2]
107
+ x = int(x * W)
108
+ y = int(y * H)
109
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
110
+
111
+ return canvas
112
+
113
+
114
+ def draw_handpose(canvas, all_hand_peaks):
115
+ H, W, C = canvas.shape
116
+
117
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
118
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
119
+
120
+ for peaks in all_hand_peaks:
121
+ peaks = np.array(peaks)
122
+
123
+ for ie, e in enumerate(edges):
124
+ x1, y1 = peaks[e[0]]
125
+ x2, y2 = peaks[e[1]]
126
+ x1 = int(x1 * W)
127
+ y1 = int(y1 * H)
128
+ x2 = int(x2 * W)
129
+ y2 = int(y2 * H)
130
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
131
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
132
+
133
+ for i, keypoint in enumerate(peaks):
134
+ x, y = keypoint
135
+ x = int(x * W)
136
+ y = int(y * H)
137
+ if x > eps and y > eps:
138
+ cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
139
+ return canvas
140
+
141
+
142
+ def draw_facepose(canvas, all_lmks):
143
+ H, W, C = canvas.shape
144
+ for lmks in all_lmks:
145
+ lmks = np.array(lmks)
146
+ for lmk in lmks:
147
+ x, y = lmk
148
+ x = int(x * W)
149
+ y = int(y * H)
150
+ if x > eps and y > eps:
151
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
152
+ return canvas
153
+
154
+
155
+ # detect hand according to body pose keypoints
156
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
157
+ def handDetect(candidate, subset, oriImg):
158
+ # right hand: wrist 4, elbow 3, shoulder 2
159
+ # left hand: wrist 7, elbow 6, shoulder 5
160
+ ratioWristElbow = 0.33
161
+ detect_result = []
162
+ image_height, image_width = oriImg.shape[0:2]
163
+ for person in subset.astype(int):
164
+ # if any of three not detected
165
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
166
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
167
+ if not (has_left or has_right):
168
+ continue
169
+ hands = []
170
+ #left hand
171
+ if has_left:
172
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
173
+ x1, y1 = candidate[left_shoulder_index][:2]
174
+ x2, y2 = candidate[left_elbow_index][:2]
175
+ x3, y3 = candidate[left_wrist_index][:2]
176
+ hands.append([x1, y1, x2, y2, x3, y3, True])
177
+ # right hand
178
+ if has_right:
179
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
180
+ x1, y1 = candidate[right_shoulder_index][:2]
181
+ x2, y2 = candidate[right_elbow_index][:2]
182
+ x3, y3 = candidate[right_wrist_index][:2]
183
+ hands.append([x1, y1, x2, y2, x3, y3, False])
184
+
185
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
186
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
187
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
188
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
189
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
190
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
191
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
192
+ x = x3 + ratioWristElbow * (x3 - x2)
193
+ y = y3 + ratioWristElbow * (y3 - y2)
194
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
195
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
196
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
197
+ # x-y refers to the center --> offset to topLeft point
198
+ # handRectangle.x -= handRectangle.width / 2.f;
199
+ # handRectangle.y -= handRectangle.height / 2.f;
200
+ x -= width / 2
201
+ y -= width / 2 # width = height
202
+ # overflow the image
203
+ if x < 0: x = 0
204
+ if y < 0: y = 0
205
+ width1 = width
206
+ width2 = width
207
+ if x + width > image_width: width1 = image_width - x
208
+ if y + width > image_height: width2 = image_height - y
209
+ width = min(width1, width2)
210
+ # the max hand box value is 20 pixels
211
+ if width >= 20:
212
+ detect_result.append([int(x), int(y), int(width), is_left])
213
+
214
+ '''
215
+ return value: [[x, y, w, True if left hand else False]].
216
+ width=height since the network require squared input.
217
+ x, y is the coordinate of top left
218
+ '''
219
+ return detect_result
220
+
221
+
222
+ # Written by Lvmin
223
+ def faceDetect(candidate, subset, oriImg):
224
+ # left right eye ear 14 15 16 17
225
+ detect_result = []
226
+ image_height, image_width = oriImg.shape[0:2]
227
+ for person in subset.astype(int):
228
+ has_head = person[0] > -1
229
+ if not has_head:
230
+ continue
231
+
232
+ has_left_eye = person[14] > -1
233
+ has_right_eye = person[15] > -1
234
+ has_left_ear = person[16] > -1
235
+ has_right_ear = person[17] > -1
236
+
237
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
238
+ continue
239
+
240
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
241
+
242
+ width = 0.0
243
+ x0, y0 = candidate[head][:2]
244
+
245
+ if has_left_eye:
246
+ x1, y1 = candidate[left_eye][:2]
247
+ d = max(abs(x0 - x1), abs(y0 - y1))
248
+ width = max(width, d * 3.0)
249
+
250
+ if has_right_eye:
251
+ x1, y1 = candidate[right_eye][:2]
252
+ d = max(abs(x0 - x1), abs(y0 - y1))
253
+ width = max(width, d * 3.0)
254
+
255
+ if has_left_ear:
256
+ x1, y1 = candidate[left_ear][:2]
257
+ d = max(abs(x0 - x1), abs(y0 - y1))
258
+ width = max(width, d * 1.5)
259
+
260
+ if has_right_ear:
261
+ x1, y1 = candidate[right_ear][:2]
262
+ d = max(abs(x0 - x1), abs(y0 - y1))
263
+ width = max(width, d * 1.5)
264
+
265
+ x, y = x0, y0
266
+
267
+ x -= width
268
+ y -= width
269
+
270
+ if x < 0:
271
+ x = 0
272
+
273
+ if y < 0:
274
+ y = 0
275
+
276
+ width1 = width * 2
277
+ width2 = width * 2
278
+
279
+ if x + width > image_width:
280
+ width1 = image_width - x
281
+
282
+ if y + width > image_height:
283
+ width2 = image_height - y
284
+
285
+ width = min(width1, width2)
286
+
287
+ if width >= 20:
288
+ detect_result.append([int(x), int(y), int(width)])
289
+
290
+ return detect_result
291
+
292
+
293
+ # get max index of 2d array
294
+ def npmax(array):
295
+ arrayindex = array.argmax(1)
296
+ arrayvalue = array.max(1)
297
+ i = arrayvalue.argmax()
298
+ j = arrayindex[i]
299
+ return i, j
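A small sketch (assumption: landmarks already normalised to [0, 1], as the drawing helpers multiply by the canvas width/height internally) showing how the drawing functions above are meant to be fed:

```python
import numpy as np
from preprocessing.dwpose.util import draw_facepose

H, W = 512, 512
canvas = np.zeros((H, W, 3), dtype=np.uint8)
face_lmks = np.random.rand(68, 2)            # hypothetical normalised (x, y) face landmarks
canvas = draw_facepose(canvas, [face_lmks])  # draws one white dot per landmark
```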
preprocessing/dwpose/wholebody.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import cv2
4
+ import numpy as np
5
+ import onnxruntime as ort
6
+ from .onnxdet import inference_detector
7
+ from .onnxpose import inference_pose
8
+
9
+ def HWC3(x):
10
+ assert x.dtype == np.uint8
11
+ if x.ndim == 2:
12
+ x = x[:, :, None]
13
+ assert x.ndim == 3
14
+ H, W, C = x.shape
15
+ assert C == 1 or C == 3 or C == 4
16
+ if C == 3:
17
+ return x
18
+ if C == 1:
19
+ return np.concatenate([x, x, x], axis=2)
20
+ if C == 4:
21
+ color = x[:, :, 0:3].astype(np.float32)
22
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
23
+ y = color * alpha + 255.0 * (1.0 - alpha)
24
+ y = y.clip(0, 255).astype(np.uint8)
25
+ return y
26
+
27
+
28
+ def resize_image(input_image, resolution):
29
+ H, W, C = input_image.shape
30
+ H = float(H)
31
+ W = float(W)
32
+ k = float(resolution) / min(H, W)
33
+ H *= k
34
+ W *= k
35
+ H = int(np.round(H / 64.0)) * 64
36
+ W = int(np.round(W / 64.0)) * 64
37
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
38
+ return img
39
+
40
+ class Wholebody:
41
+ def __init__(self, onnx_det, onnx_pose, device = 'cuda:0'):
42
+
43
+ providers = ['CPUExecutionProvider'
44
+ ] if device == 'cpu' else ['CUDAExecutionProvider']
45
+ # onnx_det = 'annotator/ckpts/yolox_l.onnx'
46
+ # onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
47
+
48
+ self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
49
+ self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
50
+
51
+ def __call__(self, ori_img):
52
+ det_result = inference_detector(self.session_det, ori_img)
53
+ keypoints, scores = inference_pose(self.session_pose, det_result, ori_img)
54
+
55
+ keypoints_info = np.concatenate(
56
+ (keypoints, scores[..., None]), axis=-1)
57
+ # compute neck joint
58
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
59
+ # neck score when visualizing pred
60
+ neck[:, 2:4] = np.logical_and(
61
+ keypoints_info[:, 5, 2:4] > 0.3,
62
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
63
+ new_keypoints_info = np.insert(
64
+ keypoints_info, 17, neck, axis=1)
65
+ mmpose_idx = [
66
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
67
+ ]
68
+ openpose_idx = [
69
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
70
+ ]
71
+ new_keypoints_info[:, openpose_idx] = \
72
+ new_keypoints_info[:, mmpose_idx]
73
+ keypoints_info = new_keypoints_info
74
+
75
+ keypoints, scores = keypoints_info[
76
+ ..., :2], keypoints_info[..., 2]
77
+
78
+ return keypoints, scores, det_result
79
+
80
+
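Hedged sketch of running the Wholebody wrapper directly; the model paths are placeholders and the input is prepared with the helpers defined above:

```python
import cv2
from preprocessing.dwpose.wholebody import Wholebody, HWC3, resize_image

estimator = Wholebody("ckpts/yolox_l.onnx", "ckpts/dw-ll_ucoco_384.onnx", device="cpu")

img = cv2.imread("person.jpg")
img = resize_image(HWC3(img), 1024)             # rescale so both sides are multiples of 64
keypoints, scores, det_result = estimator(img)  # per-person keypoints, scores, detection boxes
```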
preprocessing/gray.py ADDED
@@ -0,0 +1,35 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+
9
+ def convert_to_numpy(image):
10
+ if isinstance(image, Image.Image):
11
+ image = np.array(image)
12
+ elif isinstance(image, torch.Tensor):
13
+ image = image.detach().cpu().numpy()
14
+ elif isinstance(image, np.ndarray):
15
+ image = image.copy()
16
+ else:
17
+ raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
18
+ return image
19
+
20
+ class GrayAnnotator:
21
+ def __init__(self, cfg):
22
+ pass
23
+ def forward(self, image):
24
+ image = convert_to_numpy(image)
25
+ gray_map = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
26
+ return gray_map[..., None].repeat(3, axis=2)
27
+
28
+
29
+ class GrayVideoAnnotator(GrayAnnotator):
30
+ def forward(self, frames):
31
+ ret_frames = []
32
+ for frame in frames:
33
+ anno_frame = super().forward(np.array(frame))
34
+ ret_frames.append(anno_frame)
35
+ return ret_frames
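Illustrative usage of the annotators above (frames are assumed to be BGR uint8 arrays; the file name is a placeholder):

```python
import cv2
from preprocessing.gray import GrayVideoAnnotator

annotator = GrayVideoAnnotator(cfg={})
frames = [cv2.imread("frame_000.png")]   # placeholder frame list
gray_frames = annotator.forward(frames)  # one 3-channel greyscale map per input frame
```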
preprocessing/midas/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
preprocessing/midas/api.py ADDED
@@ -0,0 +1,166 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ # based on https://github.com/isl-org/MiDaS
4
+
5
+ import cv2
6
+ import torch
7
+ import torch.nn as nn
8
+ from torchvision.transforms import Compose
9
+
10
+ from .dpt_depth import DPTDepthModel
11
+ from .midas_net import MidasNet
12
+ from .midas_net_custom import MidasNet_small
13
+ from .transforms import NormalizeImage, PrepareForNet, Resize
14
+
15
+ # ISL_PATHS = {
16
+ # "dpt_large": "dpt_large-midas-2f21e586.pt",
17
+ # "dpt_hybrid": "dpt_hybrid-midas-501f0c75.pt",
18
+ # "midas_v21": "",
19
+ # "midas_v21_small": "",
20
+ # }
21
+
22
+ # remote_model_path =
23
+ # "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt"
24
+
25
+
26
+ def disabled_train(self, mode=True):
27
+ """Overwrite model.train with this function to make sure train/eval mode
28
+ does not change anymore."""
29
+ return self
30
+
31
+
32
+ def load_midas_transform(model_type):
33
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
34
+ # load transform only
35
+ if model_type == 'dpt_large': # DPT-Large
36
+ net_w, net_h = 384, 384
37
+ resize_mode = 'minimal'
38
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
39
+ std=[0.5, 0.5, 0.5])
40
+
41
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
42
+ net_w, net_h = 384, 384
43
+ resize_mode = 'minimal'
44
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
45
+ std=[0.5, 0.5, 0.5])
46
+
47
+ elif model_type == 'midas_v21':
48
+ net_w, net_h = 384, 384
49
+ resize_mode = 'upper_bound'
50
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
51
+ std=[0.229, 0.224, 0.225])
52
+
53
+ elif model_type == 'midas_v21_small':
54
+ net_w, net_h = 256, 256
55
+ resize_mode = 'upper_bound'
56
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
57
+ std=[0.229, 0.224, 0.225])
58
+
59
+ else:
60
+ assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
61
+
62
+ transform = Compose([
63
+ Resize(
64
+ net_w,
65
+ net_h,
66
+ resize_target=None,
67
+ keep_aspect_ratio=True,
68
+ ensure_multiple_of=32,
69
+ resize_method=resize_mode,
70
+ image_interpolation_method=cv2.INTER_CUBIC,
71
+ ),
72
+ normalization,
73
+ PrepareForNet(),
74
+ ])
75
+
76
+ return transform
77
+
78
+
79
+ def load_model(model_type, model_path):
80
+ # https://github.com/isl-org/MiDaS/blob/master/run.py
81
+ # load network
82
+ # model_path = ISL_PATHS[model_type]
83
+ if model_type == 'dpt_large': # DPT-Large
84
+ model = DPTDepthModel(
85
+ path=model_path,
86
+ backbone='vitl16_384',
87
+ non_negative=True,
88
+ )
89
+ net_w, net_h = 384, 384
90
+ resize_mode = 'minimal'
91
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
92
+ std=[0.5, 0.5, 0.5])
93
+
94
+ elif model_type == 'dpt_hybrid': # DPT-Hybrid
95
+ model = DPTDepthModel(
96
+ path=model_path,
97
+ backbone='vitb_rn50_384',
98
+ non_negative=True,
99
+ )
100
+ net_w, net_h = 384, 384
101
+ resize_mode = 'minimal'
102
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5],
103
+ std=[0.5, 0.5, 0.5])
104
+
105
+ elif model_type == 'midas_v21':
106
+ model = MidasNet(model_path, non_negative=True)
107
+ net_w, net_h = 384, 384
108
+ resize_mode = 'upper_bound'
109
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
110
+ std=[0.229, 0.224, 0.225])
111
+
112
+ elif model_type == 'midas_v21_small':
113
+ model = MidasNet_small(model_path,
114
+ features=64,
115
+ backbone='efficientnet_lite3',
116
+ exportable=True,
117
+ non_negative=True,
118
+ blocks={'expand': True})
119
+ net_w, net_h = 256, 256
120
+ resize_mode = 'upper_bound'
121
+ normalization = NormalizeImage(mean=[0.485, 0.456, 0.406],
122
+ std=[0.229, 0.224, 0.225])
123
+
124
+ else:
125
+ print(
126
+ f"model_type '{model_type}' not implemented, use: --model_type large"
127
+ )
128
+ assert False
129
+
130
+ transform = Compose([
131
+ Resize(
132
+ net_w,
133
+ net_h,
134
+ resize_target=None,
135
+ keep_aspect_ratio=True,
136
+ ensure_multiple_of=32,
137
+ resize_method=resize_mode,
138
+ image_interpolation_method=cv2.INTER_CUBIC,
139
+ ),
140
+ normalization,
141
+ PrepareForNet(),
142
+ ])
143
+
144
+ return model.eval(), transform
145
+
146
+
147
+ class MiDaSInference(nn.Module):
148
+ MODEL_TYPES_TORCH_HUB = ['DPT_Large', 'DPT_Hybrid', 'MiDaS_small']
149
+ MODEL_TYPES_ISL = [
150
+ 'dpt_large',
151
+ 'dpt_hybrid',
152
+ 'midas_v21',
153
+ 'midas_v21_small',
154
+ ]
155
+
156
+ def __init__(self, model_type, model_path):
157
+ super().__init__()
158
+ assert (model_type in self.MODEL_TYPES_ISL)
159
+ model, _ = load_model(model_type, model_path)
160
+ self.model = model
161
+ self.model.train = disabled_train
162
+
163
+ def forward(self, x):
164
+ with torch.no_grad():
165
+ prediction = self.model(x)
166
+ return prediction
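A hedged sketch of using MiDaSInference together with load_midas_transform from this file; the checkpoint path is a placeholder, and the preprocessing follows the usual MiDaS convention of passing an RGB float image in [0, 1] through the transform as a {'image': ...} sample dict:

```python
import cv2
import torch
from preprocessing.midas.api import MiDaSInference, load_midas_transform

model = MiDaSInference("dpt_hybrid", "ckpts/dpt_hybrid-midas-501f0c75.pt")  # placeholder path
transform = load_midas_transform("dpt_hybrid")

img = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})["image"]           # resized, normalised, CHW float32
depth = model(torch.from_numpy(sample).unsqueeze(0))  # (1, H', W') relative depth map
```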
preprocessing/midas/base_model.py ADDED
@@ -0,0 +1,18 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+
5
+
6
+ class BaseModel(torch.nn.Module):
7
+ def load(self, path):
8
+ """Load model from file.
9
+
10
+ Args:
11
+ path (str): file path
12
+ """
13
+ parameters = torch.load(path, map_location=torch.device('cpu'), weights_only=True)
14
+
15
+ if 'optimizer' in parameters:
16
+ parameters = parameters['model']
17
+
18
+ self.load_state_dict(parameters)
preprocessing/midas/blocks.py ADDED
@@ -0,0 +1,391 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .vit import (_make_pretrained_vitb16_384, _make_pretrained_vitb_rn50_384,
7
+ _make_pretrained_vitl16_384)
8
+
9
+
10
+ def _make_encoder(
11
+ backbone,
12
+ features,
13
+ use_pretrained,
14
+ groups=1,
15
+ expand=False,
16
+ exportable=True,
17
+ hooks=None,
18
+ use_vit_only=False,
19
+ use_readout='ignore',
20
+ ):
21
+ if backbone == 'vitl16_384':
22
+ pretrained = _make_pretrained_vitl16_384(use_pretrained,
23
+ hooks=hooks,
24
+ use_readout=use_readout)
25
+ scratch = _make_scratch(
26
+ [256, 512, 1024, 1024], features, groups=groups,
27
+ expand=expand) # ViT-L/16 - 85.0% Top1 (backbone)
28
+ elif backbone == 'vitb_rn50_384':
29
+ pretrained = _make_pretrained_vitb_rn50_384(
30
+ use_pretrained,
31
+ hooks=hooks,
32
+ use_vit_only=use_vit_only,
33
+ use_readout=use_readout,
34
+ )
35
+ scratch = _make_scratch(
36
+ [256, 512, 768, 768], features, groups=groups,
37
+ expand=expand) # ViT-H/16 - 85.0% Top1 (backbone)
38
+ elif backbone == 'vitb16_384':
39
+ pretrained = _make_pretrained_vitb16_384(use_pretrained,
40
+ hooks=hooks,
41
+ use_readout=use_readout)
42
+ scratch = _make_scratch(
43
+ [96, 192, 384, 768], features, groups=groups,
44
+ expand=expand) # ViT-B/16 - 84.6% Top1 (backbone)
45
+ elif backbone == 'resnext101_wsl':
46
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
47
+ scratch = _make_scratch([256, 512, 1024, 2048],
48
+ features,
49
+ groups=groups,
50
+ expand=expand) # efficientnet_lite3
51
+ elif backbone == 'efficientnet_lite3':
52
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained,
53
+ exportable=exportable)
54
+ scratch = _make_scratch([32, 48, 136, 384],
55
+ features,
56
+ groups=groups,
57
+ expand=expand) # efficientnet_lite3
58
+ else:
59
+ print(f"Backbone '{backbone}' not implemented")
60
+ assert False
61
+
62
+ return pretrained, scratch
63
+
64
+
65
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
66
+ scratch = nn.Module()
67
+
68
+ out_shape1 = out_shape
69
+ out_shape2 = out_shape
70
+ out_shape3 = out_shape
71
+ out_shape4 = out_shape
72
+ if expand is True:
73
+ out_shape1 = out_shape
74
+ out_shape2 = out_shape * 2
75
+ out_shape3 = out_shape * 4
76
+ out_shape4 = out_shape * 8
77
+
78
+ scratch.layer1_rn = nn.Conv2d(in_shape[0],
79
+ out_shape1,
80
+ kernel_size=3,
81
+ stride=1,
82
+ padding=1,
83
+ bias=False,
84
+ groups=groups)
85
+ scratch.layer2_rn = nn.Conv2d(in_shape[1],
86
+ out_shape2,
87
+ kernel_size=3,
88
+ stride=1,
89
+ padding=1,
90
+ bias=False,
91
+ groups=groups)
92
+ scratch.layer3_rn = nn.Conv2d(in_shape[2],
93
+ out_shape3,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=False,
98
+ groups=groups)
99
+ scratch.layer4_rn = nn.Conv2d(in_shape[3],
100
+ out_shape4,
101
+ kernel_size=3,
102
+ stride=1,
103
+ padding=1,
104
+ bias=False,
105
+ groups=groups)
106
+
107
+ return scratch
108
+
109
+
110
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
111
+ efficientnet = torch.hub.load('rwightman/gen-efficientnet-pytorch',
112
+ 'tf_efficientnet_lite3',
113
+ pretrained=use_pretrained,
114
+ exportable=exportable)
115
+ return _make_efficientnet_backbone(efficientnet)
116
+
117
+
118
+ def _make_efficientnet_backbone(effnet):
119
+ pretrained = nn.Module()
120
+
121
+ pretrained.layer1 = nn.Sequential(effnet.conv_stem, effnet.bn1,
122
+ effnet.act1, *effnet.blocks[0:2])
123
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
124
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
125
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
126
+
127
+ return pretrained
128
+
129
+
130
+ def _make_resnet_backbone(resnet):
131
+ pretrained = nn.Module()
132
+ pretrained.layer1 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu,
133
+ resnet.maxpool, resnet.layer1)
134
+
135
+ pretrained.layer2 = resnet.layer2
136
+ pretrained.layer3 = resnet.layer3
137
+ pretrained.layer4 = resnet.layer4
138
+
139
+ return pretrained
140
+
141
+
142
+ def _make_pretrained_resnext101_wsl(use_pretrained):
143
+ resnet = torch.hub.load('facebookresearch/WSL-Images',
144
+ 'resnext101_32x8d_wsl')
145
+ return _make_resnet_backbone(resnet)
146
+
147
+
148
+ class Interpolate(nn.Module):
149
+ """Interpolation module.
150
+ """
151
+ def __init__(self, scale_factor, mode, align_corners=False):
152
+ """Init.
153
+
154
+ Args:
155
+ scale_factor (float): scaling
156
+ mode (str): interpolation mode
157
+ """
158
+ super(Interpolate, self).__init__()
159
+
160
+ self.interp = nn.functional.interpolate
161
+ self.scale_factor = scale_factor
162
+ self.mode = mode
163
+ self.align_corners = align_corners
164
+
165
+ def forward(self, x):
166
+ """Forward pass.
167
+
168
+ Args:
169
+ x (tensor): input
170
+
171
+ Returns:
172
+ tensor: interpolated data
173
+ """
174
+
175
+ x = self.interp(x,
176
+ scale_factor=self.scale_factor,
177
+ mode=self.mode,
178
+ align_corners=self.align_corners)
179
+
180
+ return x
181
+
182
+
183
+ class ResidualConvUnit(nn.Module):
184
+ """Residual convolution module.
185
+ """
186
+ def __init__(self, features):
187
+ """Init.
188
+
189
+ Args:
190
+ features (int): number of features
191
+ """
192
+ super().__init__()
193
+
194
+ self.conv1 = nn.Conv2d(features,
195
+ features,
196
+ kernel_size=3,
197
+ stride=1,
198
+ padding=1,
199
+ bias=True)
200
+
201
+ self.conv2 = nn.Conv2d(features,
202
+ features,
203
+ kernel_size=3,
204
+ stride=1,
205
+ padding=1,
206
+ bias=True)
207
+
208
+ self.relu = nn.ReLU(inplace=True)
209
+
210
+ def forward(self, x):
211
+ """Forward pass.
212
+
213
+ Args:
214
+ x (tensor): input
215
+
216
+ Returns:
217
+ tensor: output
218
+ """
219
+ out = self.relu(x)
220
+ out = self.conv1(out)
221
+ out = self.relu(out)
222
+ out = self.conv2(out)
223
+
224
+ return out + x
225
+
226
+
227
+ class FeatureFusionBlock(nn.Module):
228
+ """Feature fusion block.
229
+ """
230
+ def __init__(self, features):
231
+ """Init.
232
+
233
+ Args:
234
+ features (int): number of features
235
+ """
236
+ super(FeatureFusionBlock, self).__init__()
237
+
238
+ self.resConfUnit1 = ResidualConvUnit(features)
239
+ self.resConfUnit2 = ResidualConvUnit(features)
240
+
241
+ def forward(self, *xs):
242
+ """Forward pass.
243
+
244
+ Returns:
245
+ tensor: output
246
+ """
247
+ output = xs[0]
248
+
249
+ if len(xs) == 2:
250
+ output += self.resConfUnit1(xs[1])
251
+
252
+ output = self.resConfUnit2(output)
253
+
254
+ output = nn.functional.interpolate(output,
255
+ scale_factor=2,
256
+ mode='bilinear',
257
+ align_corners=True)
258
+
259
+ return output
260
+
261
+
262
+ class ResidualConvUnit_custom(nn.Module):
263
+ """Residual convolution module.
264
+ """
265
+ def __init__(self, features, activation, bn):
266
+ """Init.
267
+
268
+ Args:
269
+ features (int): number of features
270
+ """
271
+ super().__init__()
272
+
273
+ self.bn = bn
274
+
275
+ self.groups = 1
276
+
277
+ self.conv1 = nn.Conv2d(features,
278
+ features,
279
+ kernel_size=3,
280
+ stride=1,
281
+ padding=1,
282
+ bias=True,
283
+ groups=self.groups)
284
+
285
+ self.conv2 = nn.Conv2d(features,
286
+ features,
287
+ kernel_size=3,
288
+ stride=1,
289
+ padding=1,
290
+ bias=True,
291
+ groups=self.groups)
292
+
293
+ if self.bn is True:
294
+ self.bn1 = nn.BatchNorm2d(features)
295
+ self.bn2 = nn.BatchNorm2d(features)
296
+
297
+ self.activation = activation
298
+
299
+ self.skip_add = nn.quantized.FloatFunctional()
300
+
301
+ def forward(self, x):
302
+ """Forward pass.
303
+
304
+ Args:
305
+ x (tensor): input
306
+
307
+ Returns:
308
+ tensor: output
309
+ """
310
+
311
+ out = self.activation(x)
312
+ out = self.conv1(out)
313
+ if self.bn is True:
314
+ out = self.bn1(out)
315
+
316
+ out = self.activation(out)
317
+ out = self.conv2(out)
318
+ if self.bn is True:
319
+ out = self.bn2(out)
320
+
321
+ if self.groups > 1:
322
+ out = self.conv_merge(out)
323
+
324
+ return self.skip_add.add(out, x)
325
+
326
+ # return out + x
327
+
328
+
329
+ class FeatureFusionBlock_custom(nn.Module):
330
+ """Feature fusion block.
331
+ """
332
+ def __init__(self,
333
+ features,
334
+ activation,
335
+ deconv=False,
336
+ bn=False,
337
+ expand=False,
338
+ align_corners=True):
339
+ """Init.
340
+
341
+ Args:
342
+ features (int): number of features
343
+ """
344
+ super(FeatureFusionBlock_custom, self).__init__()
345
+
346
+ self.deconv = deconv
347
+ self.align_corners = align_corners
348
+
349
+ self.groups = 1
350
+
351
+ self.expand = expand
352
+ out_features = features
353
+ if self.expand is True:
354
+ out_features = features // 2
355
+
356
+ self.out_conv = nn.Conv2d(features,
357
+ out_features,
358
+ kernel_size=1,
359
+ stride=1,
360
+ padding=0,
361
+ bias=True,
362
+ groups=1)
363
+
364
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
365
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
366
+
367
+ self.skip_add = nn.quantized.FloatFunctional()
368
+
369
+ def forward(self, *xs):
370
+ """Forward pass.
371
+
372
+ Returns:
373
+ tensor: output
374
+ """
375
+ output = xs[0]
376
+
377
+ if len(xs) == 2:
378
+ res = self.resConfUnit1(xs[1])
379
+ output = self.skip_add.add(output, res)
380
+ # output += res
381
+
382
+ output = self.resConfUnit2(output)
383
+
384
+ output = nn.functional.interpolate(output,
385
+ scale_factor=2,
386
+ mode='bilinear',
387
+ align_corners=self.align_corners)
388
+
389
+ output = self.out_conv(output)
390
+
391
+ return output
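
For reference, a minimal shape check of the custom fusion block defined above. This is a sketch only: it assumes the file is importable as `preprocessing.midas.blocks` and uses illustrative tensor sizes.

```python
# Sketch: fuse a skip connection with a deeper feature map and upsample by 2.
# Assumption: the module path below matches where this file lives in the repo.
import torch
import torch.nn as nn
from preprocessing.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(64, nn.ReLU(False), bn=False)
deeper = torch.randn(1, 64, 24, 24)   # output of the previous (deeper) fusion stage
skip = torch.randn(1, 64, 24, 24)     # encoder feature map at the same resolution
out = block(deeper, skip)             # fused, then bilinearly upsampled by 2
print(out.shape)                      # torch.Size([1, 64, 48, 48])
```
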
preprocessing/midas/depth.py ADDED
@@ -0,0 +1,84 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange
7
+ from PIL import Image
8
+ import cv2
9
+
10
+
11
+
12
+ def convert_to_numpy(image):
13
+ if isinstance(image, Image.Image):
14
+ image = np.array(image)
15
+ elif isinstance(image, torch.Tensor):
16
+ image = image.detach().cpu().numpy()
17
+ elif isinstance(image, np.ndarray):
18
+ image = image.copy()
19
+ else:
20
+ raise TypeError(f'Unsupported datatype {type(image)}; only np.ndarray, torch.Tensor and PIL Image are supported.')
21
+ return image
22
+
23
+ def resize_image(input_image, resolution):
24
+ H, W, C = input_image.shape
25
+ H = float(H)
26
+ W = float(W)
27
+ k = float(resolution) / min(H, W)
28
+ H *= k
29
+ W *= k
30
+ H = int(np.round(H / 64.0)) * 64
31
+ W = int(np.round(W / 64.0)) * 64
32
+ img = cv2.resize(
33
+ input_image, (W, H),
34
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
35
+ return img, k
36
+
37
+
38
+ def resize_image_ori(h, w, image, k):
39
+ img = cv2.resize(
40
+ image, (w, h),
41
+ interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
42
+ return img
43
+
44
+ class DepthAnnotator:
45
+ def __init__(self, cfg, device=None):
46
+ from .api import MiDaSInference
47
+ pretrained_model = cfg['PRETRAINED_MODEL']
48
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
49
+ self.model = MiDaSInference(model_type='dpt_hybrid', model_path=pretrained_model).to(self.device)
50
+ self.a = cfg.get('A', np.pi * 2.0)
51
+ self.bg_th = cfg.get('BG_TH', 0.1)
52
+
53
+ @torch.no_grad()
54
+ @torch.inference_mode()
55
+ @torch.autocast('cuda', enabled=False)
56
+ def forward(self, image):
57
+ image = convert_to_numpy(image)
58
+ image_depth = image
59
+ h, w, c = image.shape
60
+ image_depth, k = resize_image(image_depth,
61
+ 1024 if min(h, w) > 1024 else min(h, w))
62
+ image_depth = torch.from_numpy(image_depth).float().to(self.device)
63
+ image_depth = image_depth / 127.5 - 1.0
64
+ image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
65
+ depth = self.model(image_depth)[0]
66
+
67
+ depth_pt = depth.clone()
68
+ depth_pt -= torch.min(depth_pt)
69
+ depth_pt /= torch.max(depth_pt)
70
+ depth_pt = depth_pt.cpu().numpy()
71
+ depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8)
72
+ depth_image = depth_image[..., None].repeat(3, 2)
73
+
74
+ depth_image = resize_image_ori(h, w, depth_image, k)
75
+ return depth_image
76
+
77
+
78
+ class DepthVideoAnnotator(DepthAnnotator):
79
+ def forward(self, frames):
80
+ ret_frames = []
81
+ for frame in frames:
82
+ anno_frame = super().forward(np.array(frame))
83
+ ret_frames.append(anno_frame)
84
+ return ret_frames
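
A hedged usage sketch for the annotators above; the checkpoint path and the module path are placeholders rather than part of this commit.

```python
# Sketch only: the checkpoint path is hypothetical and must point to a MiDaS dpt_hybrid weight file.
import numpy as np
from preprocessing.midas.depth import DepthVideoAnnotator

cfg = {'PRETRAINED_MODEL': 'ckpts/dpt_hybrid-midas-501f0c75.pt'}  # placeholder path
annotator = DepthVideoAnnotator(cfg)

# frames: a list of HxWx3 uint8 RGB frames (e.g. decoded from the control video)
frames = [np.zeros((480, 832, 3), dtype=np.uint8) for _ in range(4)]
depth_frames = annotator.forward(frames)  # list of HxWx3 uint8 greyscale depth maps
```
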
preprocessing/midas/dpt_depth.py ADDED
@@ -0,0 +1,107 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from .base_model import BaseModel
7
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
8
+ from .vit import forward_vit
9
+
10
+
11
+ def _make_fusion_block(features, use_bn):
12
+ return FeatureFusionBlock_custom(
13
+ features,
14
+ nn.ReLU(False),
15
+ deconv=False,
16
+ bn=use_bn,
17
+ expand=False,
18
+ align_corners=True,
19
+ )
20
+
21
+
22
+ class DPT(BaseModel):
23
+ def __init__(
24
+ self,
25
+ head,
26
+ features=256,
27
+ backbone='vitb_rn50_384',
28
+ readout='project',
29
+ channels_last=False,
30
+ use_bn=False,
31
+ ):
32
+
33
+ super(DPT, self).__init__()
34
+
35
+ self.channels_last = channels_last
36
+
37
+ hooks = {
38
+ 'vitb_rn50_384': [0, 1, 8, 11],
39
+ 'vitb16_384': [2, 5, 8, 11],
40
+ 'vitl16_384': [5, 11, 17, 23],
41
+ }
42
+
43
+ # Instantiate backbone and reassemble blocks
44
+ self.pretrained, self.scratch = _make_encoder(
45
+ backbone,
46
+ features,
47
+ False, # use_pretrained: set to True to initialize the encoder from ImageNet-pretrained weights
48
+ groups=1,
49
+ expand=False,
50
+ exportable=False,
51
+ hooks=hooks[backbone],
52
+ use_readout=readout,
53
+ )
54
+
55
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
56
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
57
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
58
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
59
+
60
+ self.scratch.output_conv = head
61
+
62
+ def forward(self, x):
63
+ if self.channels_last is True:
64
+ x.contiguous(memory_format=torch.channels_last)
65
+
66
+ layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
67
+
68
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
69
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
70
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
71
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
72
+
73
+ path_4 = self.scratch.refinenet4(layer_4_rn)
74
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
75
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
76
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
77
+
78
+ out = self.scratch.output_conv(path_1)
79
+
80
+ return out
81
+
82
+
83
+ class DPTDepthModel(DPT):
84
+ def __init__(self, path=None, non_negative=True, **kwargs):
85
+ features = kwargs['features'] if 'features' in kwargs else 256
86
+
87
+ head = nn.Sequential(
88
+ nn.Conv2d(features,
89
+ features // 2,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1),
93
+ Interpolate(scale_factor=2, mode='bilinear', align_corners=True),
94
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
95
+ nn.ReLU(True),
96
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
97
+ nn.ReLU(True) if non_negative else nn.Identity(),
98
+ nn.Identity(),
99
+ )
100
+
101
+ super().__init__(head, **kwargs)
102
+
103
+ if path is not None:
104
+ self.load(path)
105
+
106
+ def forward(self, x):
107
+ return super().forward(x).squeeze(dim=1)
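
A shape-only sanity check for `DPTDepthModel`, as a sketch under assumptions: no pretrained weights are loaded, the module path is illustrative, and a timm version compatible with this code (e.g. 0.6.x) still provides `vit_base_resnet50_384`.

```python
# Sketch: build the hybrid DPT model without weights and check the output shape.
import torch
from preprocessing.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone='vitb_rn50_384', non_negative=True).eval()
with torch.no_grad():
    depth = model(torch.randn(1, 3, 384, 384))  # NCHW input, side lengths multiples of 32
print(depth.shape)  # torch.Size([1, 384, 384]) - one depth map per input image
```
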
preprocessing/midas/midas_net.py ADDED
@@ -0,0 +1,80 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
4
+ This file contains code that is adapted from
5
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .base_model import BaseModel
11
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
12
+
13
+
14
+ class MidasNet(BaseModel):
15
+ """Network for monocular depth estimation.
16
+ """
17
+ def __init__(self, path=None, features=256, non_negative=True):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 256.
23
+ non_negative (bool, optional): Clamp the predicted depth to non-negative values. Defaults to True.
24
+ """
25
+ print('Loading weights: ', path)
26
+
27
+ super(MidasNet, self).__init__()
28
+
29
+ use_pretrained = False if path is None else True
30
+
31
+ self.pretrained, self.scratch = _make_encoder(
32
+ backbone='resnext101_wsl',
33
+ features=features,
34
+ use_pretrained=use_pretrained)
35
+
36
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
37
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
38
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
39
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
40
+
41
+ self.scratch.output_conv = nn.Sequential(
42
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
43
+ Interpolate(scale_factor=2, mode='bilinear'),
44
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
45
+ nn.ReLU(True),
46
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
47
+ nn.ReLU(True) if non_negative else nn.Identity(),
48
+ )
49
+
50
+ if path:
51
+ self.load(path)
52
+
53
+ def forward(self, x):
54
+ """Forward pass.
55
+
56
+ Args:
57
+ x (tensor): input data (image)
58
+
59
+ Returns:
60
+ tensor: depth
61
+ """
62
+
63
+ layer_1 = self.pretrained.layer1(x)
64
+ layer_2 = self.pretrained.layer2(layer_1)
65
+ layer_3 = self.pretrained.layer3(layer_2)
66
+ layer_4 = self.pretrained.layer4(layer_3)
67
+
68
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
69
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
70
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
71
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
72
+
73
+ path_4 = self.scratch.refinenet4(layer_4_rn)
74
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
75
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
76
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
77
+
78
+ out = self.scratch.output_conv(path_1)
79
+
80
+ return torch.squeeze(out, dim=1)
preprocessing/midas/midas_net_custom.py ADDED
@@ -0,0 +1,167 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
4
+ This file contains code that is adapted from
5
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
6
+ """
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from .base_model import BaseModel
11
+ from .blocks import FeatureFusionBlock_custom, Interpolate, _make_encoder
12
+
13
+
14
+ class MidasNet_small(BaseModel):
15
+ """Network for monocular depth estimation.
16
+ """
17
+ def __init__(self,
18
+ path=None,
19
+ features=64,
20
+ backbone='efficientnet_lite3',
21
+ non_negative=True,
22
+ exportable=True,
23
+ channels_last=False,
24
+ align_corners=True,
25
+ blocks={'expand': True}):
26
+ """Init.
27
+
28
+ Args:
29
+ path (str, optional): Path to saved model. Defaults to None.
30
+ features (int, optional): Number of features. Defaults to 256.
31
+ backbone (str, optional): Backbone network for the encoder. Defaults to efficientnet_lite3
32
+ """
33
+ print('Loading weights: ', path)
34
+
35
+ super(MidasNet_small, self).__init__()
36
+
37
+ use_pretrained = False if path else True
38
+
39
+ self.channels_last = channels_last
40
+ self.blocks = blocks
41
+ self.backbone = backbone
42
+
43
+ self.groups = 1
44
+
45
+ features1 = features
46
+ features2 = features
47
+ features3 = features
48
+ features4 = features
49
+ self.expand = False
50
+ if 'expand' in self.blocks and self.blocks['expand'] is True:
51
+ self.expand = True
52
+ features1 = features
53
+ features2 = features * 2
54
+ features3 = features * 4
55
+ features4 = features * 8
56
+
57
+ self.pretrained, self.scratch = _make_encoder(self.backbone,
58
+ features,
59
+ use_pretrained,
60
+ groups=self.groups,
61
+ expand=self.expand,
62
+ exportable=exportable)
63
+
64
+ self.scratch.activation = nn.ReLU(False)
65
+
66
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(
67
+ features4,
68
+ self.scratch.activation,
69
+ deconv=False,
70
+ bn=False,
71
+ expand=self.expand,
72
+ align_corners=align_corners)
73
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(
74
+ features3,
75
+ self.scratch.activation,
76
+ deconv=False,
77
+ bn=False,
78
+ expand=self.expand,
79
+ align_corners=align_corners)
80
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(
81
+ features2,
82
+ self.scratch.activation,
83
+ deconv=False,
84
+ bn=False,
85
+ expand=self.expand,
86
+ align_corners=align_corners)
87
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(
88
+ features1,
89
+ self.scratch.activation,
90
+ deconv=False,
91
+ bn=False,
92
+ align_corners=align_corners)
93
+
94
+ self.scratch.output_conv = nn.Sequential(
95
+ nn.Conv2d(features,
96
+ features // 2,
97
+ kernel_size=3,
98
+ stride=1,
99
+ padding=1,
100
+ groups=self.groups),
101
+ Interpolate(scale_factor=2, mode='bilinear'),
102
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
103
+ self.scratch.activation,
104
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
105
+ nn.ReLU(True) if non_negative else nn.Identity(),
106
+ nn.Identity(),
107
+ )
108
+
109
+ if path:
110
+ self.load(path)
111
+
112
+ def forward(self, x):
113
+ """Forward pass.
114
+
115
+ Args:
116
+ x (tensor): input data (image)
117
+
118
+ Returns:
119
+ tensor: depth
120
+ """
121
+ if self.channels_last is True:
122
+ print('self.channels_last = ', self.channels_last)
123
+ x.contiguous(memory_format=torch.channels_last)
124
+
125
+ layer_1 = self.pretrained.layer1(x)
126
+ layer_2 = self.pretrained.layer2(layer_1)
127
+ layer_3 = self.pretrained.layer3(layer_2)
128
+ layer_4 = self.pretrained.layer4(layer_3)
129
+
130
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
131
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
132
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
133
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
134
+
135
+ path_4 = self.scratch.refinenet4(layer_4_rn)
136
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
137
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
138
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
139
+
140
+ out = self.scratch.output_conv(path_1)
141
+
142
+ return torch.squeeze(out, dim=1)
143
+
144
+
145
+ def fuse_model(m):
146
+ prev_previous_type = nn.Identity()
147
+ prev_previous_name = ''
148
+ previous_type = nn.Identity()
149
+ previous_name = ''
150
+ for name, module in m.named_modules():
151
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(
152
+ module) == nn.ReLU:
153
+ # print("FUSED ", prev_previous_name, previous_name, name)
154
+ torch.quantization.fuse_modules(
155
+ m, [prev_previous_name, previous_name, name], inplace=True)
156
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
157
+ # print("FUSED ", prev_previous_name, previous_name)
158
+ torch.quantization.fuse_modules(
159
+ m, [prev_previous_name, previous_name], inplace=True)
160
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
161
+ # print("FUSED ", previous_name, name)
162
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
163
+
164
+ prev_previous_type = previous_type
165
+ prev_previous_name = previous_name
166
+ previous_type = type(module)
167
+ previous_name = name
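
`fuse_model` above folds Conv2d+BatchNorm2d(+ReLU) runs in place, which matters mostly before quantization or export. A hedged sketch of how it might be applied; the module path is an assumption, and building the model this way downloads the EfficientNet-Lite3 backbone via torch.hub on first use.

```python
# Sketch: build the small MiDaS model, fuse conv/bn(/relu) patterns, run a dummy frame.
import torch
from preprocessing.midas.midas_net_custom import MidasNet_small, fuse_model

model = MidasNet_small(path=None, exportable=True).eval()
fuse_model(model)  # in-place fusion of the Conv2d+BatchNorm2d(+ReLU) patterns it finds
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)   # torch.Size([1, 256, 256])
```
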
preprocessing/midas/transforms.py ADDED
@@ -0,0 +1,231 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+
9
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
10
+ """Resize the sample to ensure the given minimum size. Keeps aspect ratio.
11
+
12
+ Args:
13
+ sample (dict): sample
14
+ size (tuple): image size
15
+
16
+ Returns:
17
+ tuple: new size
18
+ """
19
+ shape = list(sample['disparity'].shape)
20
+
21
+ if shape[0] >= size[0] and shape[1] >= size[1]:
22
+ return sample
23
+
24
+ scale = [0, 0]
25
+ scale[0] = size[0] / shape[0]
26
+ scale[1] = size[1] / shape[1]
27
+
28
+ scale = max(scale)
29
+
30
+ shape[0] = math.ceil(scale * shape[0])
31
+ shape[1] = math.ceil(scale * shape[1])
32
+
33
+ # resize
34
+ sample['image'] = cv2.resize(sample['image'],
35
+ tuple(shape[::-1]),
36
+ interpolation=image_interpolation_method)
37
+
38
+ sample['disparity'] = cv2.resize(sample['disparity'],
39
+ tuple(shape[::-1]),
40
+ interpolation=cv2.INTER_NEAREST)
41
+ sample['mask'] = cv2.resize(
42
+ sample['mask'].astype(np.float32),
43
+ tuple(shape[::-1]),
44
+ interpolation=cv2.INTER_NEAREST,
45
+ )
46
+ sample['mask'] = sample['mask'].astype(bool)
47
+
48
+ return tuple(shape)
49
+
50
+
51
+ class Resize(object):
52
+ """Resize sample to given size (width, height).
53
+ """
54
+ def __init__(
55
+ self,
56
+ width,
57
+ height,
58
+ resize_target=True,
59
+ keep_aspect_ratio=False,
60
+ ensure_multiple_of=1,
61
+ resize_method='lower_bound',
62
+ image_interpolation_method=cv2.INTER_AREA,
63
+ ):
64
+ """Init.
65
+
66
+ Args:
67
+ width (int): desired output width
68
+ height (int): desired output height
69
+ resize_target (bool, optional):
70
+ True: Resize the full sample (image, mask, target).
71
+ False: Resize image only.
72
+ Defaults to True.
73
+ keep_aspect_ratio (bool, optional):
74
+ True: Keep the aspect ratio of the input sample.
75
+ Output sample might not have the given width and height, and
76
+ resize behaviour depends on the parameter 'resize_method'.
77
+ Defaults to False.
78
+ ensure_multiple_of (int, optional):
79
+ Output width and height is constrained to be multiple of this parameter.
80
+ Defaults to 1.
81
+ resize_method (str, optional):
82
+ "lower_bound": Output will be at least as large as the given size.
83
+ "upper_bound": Output will be at most as large as the given size.
84
+ (Output size might be smaller than the given size.)
85
+ "minimal": Scale as little as possible. (Output size might be smaller than the given size.)
86
+ Defaults to "lower_bound".
87
+ """
88
+ self.__width = width
89
+ self.__height = height
90
+
91
+ self.__resize_target = resize_target
92
+ self.__keep_aspect_ratio = keep_aspect_ratio
93
+ self.__multiple_of = ensure_multiple_of
94
+ self.__resize_method = resize_method
95
+ self.__image_interpolation_method = image_interpolation_method
96
+
97
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
98
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if max_val is not None and y > max_val:
101
+ y = (np.floor(x / self.__multiple_of) *
102
+ self.__multiple_of).astype(int)
103
+
104
+ if y < min_val:
105
+ y = (np.ceil(x / self.__multiple_of) *
106
+ self.__multiple_of).astype(int)
107
+
108
+ return y
109
+
110
+ def get_size(self, width, height):
111
+ # determine new height and width
112
+ scale_height = self.__height / height
113
+ scale_width = self.__width / width
114
+
115
+ if self.__keep_aspect_ratio:
116
+ if self.__resize_method == 'lower_bound':
117
+ # scale such that output size is lower bound
118
+ if scale_width > scale_height:
119
+ # fit width
120
+ scale_height = scale_width
121
+ else:
122
+ # fit height
123
+ scale_width = scale_height
124
+ elif self.__resize_method == 'upper_bound':
125
+ # scale such that output size is upper bound
126
+ if scale_width < scale_height:
127
+ # fit width
128
+ scale_height = scale_width
129
+ else:
130
+ # fit height
131
+ scale_width = scale_height
132
+ elif self.__resize_method == 'minimal':
133
+ # scale as little as possible
134
+ if abs(1 - scale_width) < abs(1 - scale_height):
135
+ # fit width
136
+ scale_height = scale_width
137
+ else:
138
+ # fit height
139
+ scale_width = scale_height
140
+ else:
141
+ raise ValueError(
142
+ f'resize_method {self.__resize_method} not implemented')
143
+
144
+ if self.__resize_method == 'lower_bound':
145
+ new_height = self.constrain_to_multiple_of(scale_height * height,
146
+ min_val=self.__height)
147
+ new_width = self.constrain_to_multiple_of(scale_width * width,
148
+ min_val=self.__width)
149
+ elif self.__resize_method == 'upper_bound':
150
+ new_height = self.constrain_to_multiple_of(scale_height * height,
151
+ max_val=self.__height)
152
+ new_width = self.constrain_to_multiple_of(scale_width * width,
153
+ max_val=self.__width)
154
+ elif self.__resize_method == 'minimal':
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(
159
+ f'resize_method {self.__resize_method} not implemented')
160
+
161
+ return (new_width, new_height)
162
+
163
+ def __call__(self, sample):
164
+ width, height = self.get_size(sample['image'].shape[1],
165
+ sample['image'].shape[0])
166
+
167
+ # resize sample
168
+ sample['image'] = cv2.resize(
169
+ sample['image'],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if 'disparity' in sample:
176
+ sample['disparity'] = cv2.resize(
177
+ sample['disparity'],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if 'depth' in sample:
183
+ sample['depth'] = cv2.resize(sample['depth'], (width, height),
184
+ interpolation=cv2.INTER_NEAREST)
185
+
186
+ sample['mask'] = cv2.resize(
187
+ sample['mask'].astype(np.float32),
188
+ (width, height),
189
+ interpolation=cv2.INTER_NEAREST,
190
+ )
191
+ sample['mask'] = sample['mask'].astype(bool)
192
+
193
+ return sample
194
+
195
+
196
+ class NormalizeImage(object):
197
+ """Normalize image by given mean and std.
198
+ """
199
+ def __init__(self, mean, std):
200
+ self.__mean = mean
201
+ self.__std = std
202
+
203
+ def __call__(self, sample):
204
+ sample['image'] = (sample['image'] - self.__mean) / self.__std
205
+
206
+ return sample
207
+
208
+
209
+ class PrepareForNet(object):
210
+ """Prepare sample for usage as network input.
211
+ """
212
+ def __init__(self):
213
+ pass
214
+
215
+ def __call__(self, sample):
216
+ image = np.transpose(sample['image'], (2, 0, 1))
217
+ sample['image'] = np.ascontiguousarray(image).astype(np.float32)
218
+
219
+ if 'mask' in sample:
220
+ sample['mask'] = sample['mask'].astype(np.float32)
221
+ sample['mask'] = np.ascontiguousarray(sample['mask'])
222
+
223
+ if 'disparity' in sample:
224
+ disparity = sample['disparity'].astype(np.float32)
225
+ sample['disparity'] = np.ascontiguousarray(disparity)
226
+
227
+ if 'depth' in sample:
228
+ depth = sample['depth'].astype(np.float32)
229
+ sample['depth'] = np.ascontiguousarray(depth)
230
+
231
+ return sample
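
The three transforms above are typically chained into a single preprocessing pipeline. A minimal sketch follows; the module path, the `Compose` helper from torchvision and the mean/std values are assumptions, not part of this commit.

```python
# Sketch: compose the MiDaS-style transforms and run one RGB image through them.
import numpy as np
from torchvision.transforms import Compose
from preprocessing.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384,
           resize_target=False,
           keep_aspect_ratio=True,
           ensure_multiple_of=32,
           resize_method='lower_bound'),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

img = np.random.rand(480, 832, 3).astype(np.float32)  # RGB image scaled to [0, 1]
sample = transform({'image': img})
print(sample['image'].shape)  # (3, 384, 672): CHW, short side >= 384, multiples of 32
```
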
preprocessing/midas/utils.py ADDED
@@ -0,0 +1,193 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ """Utils for monoDepth."""
4
+ import re
5
+ import sys
6
+
7
+ import cv2
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ def read_pfm(path):
13
+ """Read pfm file.
14
+
15
+ Args:
16
+ path (str): path to file
17
+
18
+ Returns:
19
+ tuple: (data, scale)
20
+ """
21
+ with open(path, 'rb') as file:
22
+
23
+ color = None
24
+ width = None
25
+ height = None
26
+ scale = None
27
+ endian = None
28
+
29
+ header = file.readline().rstrip()
30
+ if header.decode('ascii') == 'PF':
31
+ color = True
32
+ elif header.decode('ascii') == 'Pf':
33
+ color = False
34
+ else:
35
+ raise Exception('Not a PFM file: ' + path)
36
+
37
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$',
38
+ file.readline().decode('ascii'))
39
+ if dim_match:
40
+ width, height = list(map(int, dim_match.groups()))
41
+ else:
42
+ raise Exception('Malformed PFM header.')
43
+
44
+ scale = float(file.readline().decode('ascii').rstrip())
45
+ if scale < 0:
46
+ # little-endian
47
+ endian = '<'
48
+ scale = -scale
49
+ else:
50
+ # big-endian
51
+ endian = '>'
52
+
53
+ data = np.fromfile(file, endian + 'f')
54
+ shape = (height, width, 3) if color else (height, width)
55
+
56
+ data = np.reshape(data, shape)
57
+ data = np.flipud(data)
58
+
59
+ return data, scale
60
+
61
+
62
+ def write_pfm(path, image, scale=1):
63
+ """Write pfm file.
64
+
65
+ Args:
66
+ path (str): path to file
67
+ image (array): data
68
+ scale (int, optional): Scale. Defaults to 1.
69
+ """
70
+
71
+ with open(path, 'wb') as file:
72
+ color = None
73
+
74
+ if image.dtype.name != 'float32':
75
+ raise Exception('Image dtype must be float32.')
76
+
77
+ image = np.flipud(image)
78
+
79
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
80
+ color = True
81
+ elif (len(image.shape) == 2
82
+ or len(image.shape) == 3 and image.shape[2] == 1): # greyscale
83
+ color = False
84
+ else:
85
+ raise Exception(
86
+ 'Image must have H x W x 3, H x W x 1 or H x W dimensions.')
87
+
88
+ file.write(('PF\n' if color else 'Pf\n').encode())
89
+ file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
90
+
91
+ endian = image.dtype.byteorder
92
+
93
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
94
+ scale = -scale
95
+
96
+ file.write('%f\n'.encode() % scale)
97
+
98
+ image.tofile(file)
99
+
100
+
101
+ def read_image(path):
102
+ """Read image and output RGB image (0-1).
103
+
104
+ Args:
105
+ path (str): path to file
106
+
107
+ Returns:
108
+ array: RGB image (0-1)
109
+ """
110
+ img = cv2.imread(path)
111
+
112
+ if img.ndim == 2:
113
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
114
+
115
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
116
+
117
+ return img
118
+
119
+
120
+ def resize_image(img):
121
+ """Resize image and make it fit for network.
122
+
123
+ Args:
124
+ img (array): image
125
+
126
+ Returns:
127
+ tensor: data ready for network
128
+ """
129
+ height_orig = img.shape[0]
130
+ width_orig = img.shape[1]
131
+
132
+ if width_orig > height_orig:
133
+ scale = width_orig / 384
134
+ else:
135
+ scale = height_orig / 384
136
+
137
+ height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
138
+ width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
139
+
140
+ img_resized = cv2.resize(img, (width, height),
141
+ interpolation=cv2.INTER_AREA)
142
+
143
+ img_resized = (torch.from_numpy(np.transpose(
144
+ img_resized, (2, 0, 1))).contiguous().float())
145
+ img_resized = img_resized.unsqueeze(0)
146
+
147
+ return img_resized
148
+
149
+
150
+ def resize_depth(depth, width, height):
151
+ """Resize depth map and bring to CPU (numpy).
152
+
153
+ Args:
154
+ depth (tensor): depth
155
+ width (int): image width
156
+ height (int): image height
157
+
158
+ Returns:
159
+ array: processed depth
160
+ """
161
+ depth = torch.squeeze(depth[0, :, :, :]).to('cpu')
162
+
163
+ depth_resized = cv2.resize(depth.numpy(), (width, height),
164
+ interpolation=cv2.INTER_CUBIC)
165
+
166
+ return depth_resized
167
+
168
+
169
+ def write_depth(path, depth, bits=1):
170
+ """Write depth map to pfm and png file.
171
+
172
+ Args:
173
+ path (str): filepath without extension
174
+ depth (array): depth
175
+ """
176
+ write_pfm(path + '.pfm', depth.astype(np.float32))
177
+
178
+ depth_min = depth.min()
179
+ depth_max = depth.max()
180
+
181
+ max_val = (2**(8 * bits)) - 1
182
+
183
+ if depth_max - depth_min > np.finfo('float').eps:
184
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
185
+ else:
186
+ out = np.zeros(depth.shape, dtype=depth.dtype)
187
+
188
+ if bits == 1:
189
+ cv2.imwrite(path + '.png', out.astype('uint8'))
190
+ elif bits == 2:
191
+ cv2.imwrite(path + '.png', out.astype('uint16'))
192
+
193
+ return
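
A short round-trip sketch for the PFM helpers above; the file name is a placeholder and the target directory must already exist.

```python
# Sketch: write a fake depth map as .pfm + 16-bit .png, then read the .pfm back.
import numpy as np
from preprocessing.midas.utils import write_depth, read_pfm

depth = np.random.rand(240, 320).astype(np.float32)  # fake depth map
write_depth('depth_sample', depth, bits=2)            # creates depth_sample.pfm and depth_sample.png
data, scale = read_pfm('depth_sample.pfm')
assert data.shape == depth.shape
```
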
preprocessing/midas/vit.py ADDED
@@ -0,0 +1,510 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ import types
5
+
6
+ import timm
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class Slice(nn.Module):
13
+ def __init__(self, start_index=1):
14
+ super(Slice, self).__init__()
15
+ self.start_index = start_index
16
+
17
+ def forward(self, x):
18
+ return x[:, self.start_index:]
19
+
20
+
21
+ class AddReadout(nn.Module):
22
+ def __init__(self, start_index=1):
23
+ super(AddReadout, self).__init__()
24
+ self.start_index = start_index
25
+
26
+ def forward(self, x):
27
+ if self.start_index == 2:
28
+ readout = (x[:, 0] + x[:, 1]) / 2
29
+ else:
30
+ readout = x[:, 0]
31
+ return x[:, self.start_index:] + readout.unsqueeze(1)
32
+
33
+
34
+ class ProjectReadout(nn.Module):
35
+ def __init__(self, in_features, start_index=1):
36
+ super(ProjectReadout, self).__init__()
37
+ self.start_index = start_index
38
+
39
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features),
40
+ nn.GELU())
41
+
42
+ def forward(self, x):
43
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
44
+ features = torch.cat((x[:, self.start_index:], readout), -1)
45
+
46
+ return self.project(features)
47
+
48
+
49
+ class Transpose(nn.Module):
50
+ def __init__(self, dim0, dim1):
51
+ super(Transpose, self).__init__()
52
+ self.dim0 = dim0
53
+ self.dim1 = dim1
54
+
55
+ def forward(self, x):
56
+ x = x.transpose(self.dim0, self.dim1)
57
+ return x
58
+
59
+
60
+ def forward_vit(pretrained, x):
61
+ b, c, h, w = x.shape
62
+
63
+ _ = pretrained.model.forward_flex(x)
64
+
65
+ layer_1 = pretrained.activations['1']
66
+ layer_2 = pretrained.activations['2']
67
+ layer_3 = pretrained.activations['3']
68
+ layer_4 = pretrained.activations['4']
69
+
70
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
71
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
72
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
73
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
74
+
75
+ unflatten = nn.Sequential(
76
+ nn.Unflatten(
77
+ 2,
78
+ torch.Size([
79
+ h // pretrained.model.patch_size[1],
80
+ w // pretrained.model.patch_size[0],
81
+ ]),
82
+ ))
83
+
84
+ if layer_1.ndim == 3:
85
+ layer_1 = unflatten(layer_1)
86
+ if layer_2.ndim == 3:
87
+ layer_2 = unflatten(layer_2)
88
+ if layer_3.ndim == 3:
89
+ layer_3 = unflatten(layer_3)
90
+ if layer_4.ndim == 3:
91
+ layer_4 = unflatten(layer_4)
92
+
93
+ layer_1 = pretrained.act_postprocess1[3:len(pretrained.act_postprocess1)](
94
+ layer_1)
95
+ layer_2 = pretrained.act_postprocess2[3:len(pretrained.act_postprocess2)](
96
+ layer_2)
97
+ layer_3 = pretrained.act_postprocess3[3:len(pretrained.act_postprocess3)](
98
+ layer_3)
99
+ layer_4 = pretrained.act_postprocess4[3:len(pretrained.act_postprocess4)](
100
+ layer_4)
101
+
102
+ return layer_1, layer_2, layer_3, layer_4
103
+
104
+
105
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
106
+ posemb_tok, posemb_grid = (
107
+ posemb[:, :self.start_index],
108
+ posemb[0, self.start_index:],
109
+ )
110
+
111
+ gs_old = int(math.sqrt(len(posemb_grid)))
112
+
113
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old,
114
+ -1).permute(0, 3, 1, 2)
115
+ posemb_grid = F.interpolate(posemb_grid,
116
+ size=(gs_h, gs_w),
117
+ mode='bilinear')
118
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
119
+
120
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
121
+
122
+ return posemb
123
+
124
+
125
+ def forward_flex(self, x):
126
+ b, c, h, w = x.shape
127
+
128
+ pos_embed = self._resize_pos_embed(self.pos_embed, h // self.patch_size[1],
129
+ w // self.patch_size[0])
130
+
131
+ B = x.shape[0]
132
+
133
+ if hasattr(self.patch_embed, 'backbone'):
134
+ x = self.patch_embed.backbone(x)
135
+ if isinstance(x, (list, tuple)):
136
+ x = x[
137
+ -1] # last feature if backbone outputs list/tuple of features
138
+
139
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
140
+
141
+ if getattr(self, 'dist_token', None) is not None:
142
+ cls_tokens = self.cls_token.expand(
143
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
144
+ dist_token = self.dist_token.expand(B, -1, -1)
145
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
146
+ else:
147
+ cls_tokens = self.cls_token.expand(
148
+ B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
149
+ x = torch.cat((cls_tokens, x), dim=1)
150
+
151
+ x = x + pos_embed
152
+ x = self.pos_drop(x)
153
+
154
+ for blk in self.blocks:
155
+ x = blk(x)
156
+
157
+ x = self.norm(x)
158
+
159
+ return x
160
+
161
+
162
+ activations = {}
163
+
164
+
165
+ def get_activation(name):
166
+ def hook(model, input, output):
167
+ activations[name] = output
168
+
169
+ return hook
170
+
171
+
172
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
173
+ if use_readout == 'ignore':
174
+ readout_oper = [Slice(start_index)] * len(features)
175
+ elif use_readout == 'add':
176
+ readout_oper = [AddReadout(start_index)] * len(features)
177
+ elif use_readout == 'project':
178
+ readout_oper = [
179
+ ProjectReadout(vit_features, start_index) for out_feat in features
180
+ ]
181
+ else:
182
+ assert (
183
+ False
184
+ ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"
185
+
186
+ return readout_oper
187
+
188
+
189
+ def _make_vit_b16_backbone(
190
+ model,
191
+ features=[96, 192, 384, 768],
192
+ size=[384, 384],
193
+ hooks=[2, 5, 8, 11],
194
+ vit_features=768,
195
+ use_readout='ignore',
196
+ start_index=1,
197
+ ):
198
+ pretrained = nn.Module()
199
+
200
+ pretrained.model = model
201
+ pretrained.model.blocks[hooks[0]].register_forward_hook(
202
+ get_activation('1'))
203
+ pretrained.model.blocks[hooks[1]].register_forward_hook(
204
+ get_activation('2'))
205
+ pretrained.model.blocks[hooks[2]].register_forward_hook(
206
+ get_activation('3'))
207
+ pretrained.model.blocks[hooks[3]].register_forward_hook(
208
+ get_activation('4'))
209
+
210
+ pretrained.activations = activations
211
+
212
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
213
+ start_index)
214
+
215
+ # 32, 48, 136, 384
216
+ pretrained.act_postprocess1 = nn.Sequential(
217
+ readout_oper[0],
218
+ Transpose(1, 2),
219
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
220
+ nn.Conv2d(
221
+ in_channels=vit_features,
222
+ out_channels=features[0],
223
+ kernel_size=1,
224
+ stride=1,
225
+ padding=0,
226
+ ),
227
+ nn.ConvTranspose2d(
228
+ in_channels=features[0],
229
+ out_channels=features[0],
230
+ kernel_size=4,
231
+ stride=4,
232
+ padding=0,
233
+ bias=True,
234
+ dilation=1,
235
+ groups=1,
236
+ ),
237
+ )
238
+
239
+ pretrained.act_postprocess2 = nn.Sequential(
240
+ readout_oper[1],
241
+ Transpose(1, 2),
242
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
243
+ nn.Conv2d(
244
+ in_channels=vit_features,
245
+ out_channels=features[1],
246
+ kernel_size=1,
247
+ stride=1,
248
+ padding=0,
249
+ ),
250
+ nn.ConvTranspose2d(
251
+ in_channels=features[1],
252
+ out_channels=features[1],
253
+ kernel_size=2,
254
+ stride=2,
255
+ padding=0,
256
+ bias=True,
257
+ dilation=1,
258
+ groups=1,
259
+ ),
260
+ )
261
+
262
+ pretrained.act_postprocess3 = nn.Sequential(
263
+ readout_oper[2],
264
+ Transpose(1, 2),
265
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
266
+ nn.Conv2d(
267
+ in_channels=vit_features,
268
+ out_channels=features[2],
269
+ kernel_size=1,
270
+ stride=1,
271
+ padding=0,
272
+ ),
273
+ )
274
+
275
+ pretrained.act_postprocess4 = nn.Sequential(
276
+ readout_oper[3],
277
+ Transpose(1, 2),
278
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
279
+ nn.Conv2d(
280
+ in_channels=vit_features,
281
+ out_channels=features[3],
282
+ kernel_size=1,
283
+ stride=1,
284
+ padding=0,
285
+ ),
286
+ nn.Conv2d(
287
+ in_channels=features[3],
288
+ out_channels=features[3],
289
+ kernel_size=3,
290
+ stride=2,
291
+ padding=1,
292
+ ),
293
+ )
294
+
295
+ pretrained.model.start_index = start_index
296
+ pretrained.model.patch_size = [16, 16]
297
+
298
+ # We inject this function into the VisionTransformer instances so that
299
+ # we can use it with interpolated position embeddings without modifying the library source.
300
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
301
+ pretrained.model)
302
+ pretrained.model._resize_pos_embed = types.MethodType(
303
+ _resize_pos_embed, pretrained.model)
304
+
305
+ return pretrained
306
+
307
+
308
+ def _make_pretrained_vitl16_384(pretrained, use_readout='ignore', hooks=None):
309
+ model = timm.create_model('vit_large_patch16_384', pretrained=pretrained)
310
+
311
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
312
+ return _make_vit_b16_backbone(
313
+ model,
314
+ features=[256, 512, 1024, 1024],
315
+ hooks=hooks,
316
+ vit_features=1024,
317
+ use_readout=use_readout,
318
+ )
319
+
320
+
321
+ def _make_pretrained_vitb16_384(pretrained, use_readout='ignore', hooks=None):
322
+ model = timm.create_model('vit_base_patch16_384', pretrained=pretrained)
323
+
324
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
325
+ return _make_vit_b16_backbone(model,
326
+ features=[96, 192, 384, 768],
327
+ hooks=hooks,
328
+ use_readout=use_readout)
329
+
330
+
331
+ def _make_pretrained_deitb16_384(pretrained, use_readout='ignore', hooks=None):
332
+ model = timm.create_model('vit_deit_base_patch16_384',
333
+ pretrained=pretrained)
334
+
335
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
336
+ return _make_vit_b16_backbone(model,
337
+ features=[96, 192, 384, 768],
338
+ hooks=hooks,
339
+ use_readout=use_readout)
340
+
341
+
342
+ def _make_pretrained_deitb16_distil_384(pretrained,
343
+ use_readout='ignore',
344
+ hooks=None):
345
+ model = timm.create_model('vit_deit_base_distilled_patch16_384',
346
+ pretrained=pretrained)
347
+
348
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
349
+ return _make_vit_b16_backbone(
350
+ model,
351
+ features=[96, 192, 384, 768],
352
+ hooks=hooks,
353
+ use_readout=use_readout,
354
+ start_index=2,
355
+ )
356
+
357
+
358
+ def _make_vit_b_rn50_backbone(
359
+ model,
360
+ features=[256, 512, 768, 768],
361
+ size=[384, 384],
362
+ hooks=[0, 1, 8, 11],
363
+ vit_features=768,
364
+ use_vit_only=False,
365
+ use_readout='ignore',
366
+ start_index=1,
367
+ ):
368
+ pretrained = nn.Module()
369
+
370
+ pretrained.model = model
371
+
372
+ if use_vit_only is True:
373
+ pretrained.model.blocks[hooks[0]].register_forward_hook(
374
+ get_activation('1'))
375
+ pretrained.model.blocks[hooks[1]].register_forward_hook(
376
+ get_activation('2'))
377
+ else:
378
+ pretrained.model.patch_embed.backbone.stages[0].register_forward_hook(
379
+ get_activation('1'))
380
+ pretrained.model.patch_embed.backbone.stages[1].register_forward_hook(
381
+ get_activation('2'))
382
+
383
+ pretrained.model.blocks[hooks[2]].register_forward_hook(
384
+ get_activation('3'))
385
+ pretrained.model.blocks[hooks[3]].register_forward_hook(
386
+ get_activation('4'))
387
+
388
+ pretrained.activations = activations
389
+
390
+ readout_oper = get_readout_oper(vit_features, features, use_readout,
391
+ start_index)
392
+
393
+ if use_vit_only is True:
394
+ pretrained.act_postprocess1 = nn.Sequential(
395
+ readout_oper[0],
396
+ Transpose(1, 2),
397
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
398
+ nn.Conv2d(
399
+ in_channels=vit_features,
400
+ out_channels=features[0],
401
+ kernel_size=1,
402
+ stride=1,
403
+ padding=0,
404
+ ),
405
+ nn.ConvTranspose2d(
406
+ in_channels=features[0],
407
+ out_channels=features[0],
408
+ kernel_size=4,
409
+ stride=4,
410
+ padding=0,
411
+ bias=True,
412
+ dilation=1,
413
+ groups=1,
414
+ ),
415
+ )
416
+
417
+ pretrained.act_postprocess2 = nn.Sequential(
418
+ readout_oper[1],
419
+ Transpose(1, 2),
420
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
421
+ nn.Conv2d(
422
+ in_channels=vit_features,
423
+ out_channels=features[1],
424
+ kernel_size=1,
425
+ stride=1,
426
+ padding=0,
427
+ ),
428
+ nn.ConvTranspose2d(
429
+ in_channels=features[1],
430
+ out_channels=features[1],
431
+ kernel_size=2,
432
+ stride=2,
433
+ padding=0,
434
+ bias=True,
435
+ dilation=1,
436
+ groups=1,
437
+ ),
438
+ )
439
+ else:
440
+ pretrained.act_postprocess1 = nn.Sequential(nn.Identity(),
441
+ nn.Identity(),
442
+ nn.Identity())
443
+ pretrained.act_postprocess2 = nn.Sequential(nn.Identity(),
444
+ nn.Identity(),
445
+ nn.Identity())
446
+
447
+ pretrained.act_postprocess3 = nn.Sequential(
448
+ readout_oper[2],
449
+ Transpose(1, 2),
450
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
451
+ nn.Conv2d(
452
+ in_channels=vit_features,
453
+ out_channels=features[2],
454
+ kernel_size=1,
455
+ stride=1,
456
+ padding=0,
457
+ ),
458
+ )
459
+
460
+ pretrained.act_postprocess4 = nn.Sequential(
461
+ readout_oper[3],
462
+ Transpose(1, 2),
463
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
464
+ nn.Conv2d(
465
+ in_channels=vit_features,
466
+ out_channels=features[3],
467
+ kernel_size=1,
468
+ stride=1,
469
+ padding=0,
470
+ ),
471
+ nn.Conv2d(
472
+ in_channels=features[3],
473
+ out_channels=features[3],
474
+ kernel_size=3,
475
+ stride=2,
476
+ padding=1,
477
+ ),
478
+ )
479
+
480
+ pretrained.model.start_index = start_index
481
+ pretrained.model.patch_size = [16, 16]
482
+
483
+ # We inject this function into the VisionTransformer instances so that
484
+ # we can use it with interpolated position embeddings without modifying the library source.
485
+ pretrained.model.forward_flex = types.MethodType(forward_flex,
486
+ pretrained.model)
487
+
488
+ # We inject this function into the VisionTransformer instances so that
489
+ # we can use it with interpolated position embeddings without modifying the library source.
490
+ pretrained.model._resize_pos_embed = types.MethodType(
491
+ _resize_pos_embed, pretrained.model)
492
+
493
+ return pretrained
494
+
495
+
496
+ def _make_pretrained_vitb_rn50_384(pretrained,
497
+ use_readout='ignore',
498
+ hooks=None,
499
+ use_vit_only=False):
500
+ model = timm.create_model('vit_base_resnet50_384', pretrained=pretrained)
501
+
502
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
503
+ return _make_vit_b_rn50_backbone(
504
+ model,
505
+ features=[256, 512, 768, 768],
506
+ size=[384, 384],
507
+ hooks=hooks,
508
+ use_vit_only=use_vit_only,
509
+ use_readout=use_readout,
510
+ )
wan/text2video.py CHANGED
@@ -207,18 +207,19 @@ class WanT2V:
207
  def vace_latent(self, z, m):
208
  return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
209
 
210
- def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, trim_video= 0):
211
  image_sizes = []
212
  for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
213
  if sub_src_mask is not None and sub_src_video is not None:
214
  src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
 
 
215
  src_video[i] = src_video[i].to(device)
216
  src_mask[i] = src_mask[i].to(device)
217
  src_video_shape = src_video[i].shape
218
  if src_video_shape[1] != num_frames:
219
  src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
220
  src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
221
-
222
  src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
223
  image_sizes.append(src_video[i].shape[2:])
224
  elif sub_src_video is None:
@@ -228,10 +229,11 @@ class WanT2V:
228
  else:
229
  src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
230
  src_video[i] = src_video[i].to(device)
 
231
  src_video_shape = src_video[i].shape
232
  if src_video_shape[1] != num_frames:
233
  src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
234
- src_mask[i] = torch.ones_like(src_video[i], device=device)
235
  image_sizes.append(src_video[i].shape[2:])
236
 
237
  for i, ref_images in enumerate(src_ref_images):
 
207
  def vace_latent(self, z, m):
208
  return [torch.cat([zz, mm], dim=0) for zz, mm in zip(z, m)]
209
 
210
+ def prepare_source(self, src_video, src_mask, src_ref_images, num_frames, image_size, device, original_video = False, trim_video= 0):
211
  image_sizes = []
212
  for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
213
  if sub_src_mask is not None and sub_src_video is not None:
214
  src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask, max_frames= num_frames, trim_video = trim_video)
215
+ # src_video is in [-1, 1]; 0 marks the inpainting area (i.e. grey 127 in the [0, 255] range)
216
+ # src_mask is [-1, 1], 0 = preserve original video (in fact 127 in [0, 255]) and 1 = Inpainting (in fact 255 in [0, 255])
217
  src_video[i] = src_video[i].to(device)
218
  src_mask[i] = src_mask[i].to(device)
219
  src_video_shape = src_video[i].shape
220
  if src_video_shape[1] != num_frames:
221
  src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
222
  src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
 
223
  src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
224
  image_sizes.append(src_video[i].shape[2:])
225
  elif sub_src_video is None:
 
229
  else:
230
  src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video, max_frames= num_frames, trim_video = trim_video)
231
  src_video[i] = src_video[i].to(device)
232
+ src_mask[i] = torch.zeros_like(src_video[i], device=device) if original_video else torch.ones_like(src_video[i], device=device)
233
  src_video_shape = src_video[i].shape
234
  if src_video_shape[1] != num_frames:
235
  src_video[i] = torch.cat( [src_video[i], src_video[i].new_zeros(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
236
+ src_mask[i] = torch.cat( [src_mask[i], src_mask[i].new_ones(src_video_shape[0], num_frames -src_video_shape[1], *src_video_shape[-2:])], dim=1)
237
  image_sizes.append(src_video[i].shape[2:])
238
 
239
  for i, ref_images in enumerate(src_ref_images):
wan/utils/utils.py CHANGED
@@ -21,6 +21,30 @@ __all__ = ['cache_video', 'cache_image', 'str2bool']
21
 
22
  from PIL import Image
23
 
24
  def get_video_frame(file_name, frame_no):
25
  decord.bridge.set_bridge('torch')
26
  reader = decord.VideoReader(file_name)
 
21
 
22
  from PIL import Image
23
 
24
+
25
+ def resample(video_fps, video_frames_count, max_frames, target_fps):
26
+ import math
27
+
28
+ video_frame_duration = 1 /video_fps
29
+ target_frame_duration = 1 / target_fps
30
+
31
+ cur_time = 0
32
+ target_time = 0
33
+ frame_no = 0
34
+ frame_ids =[]
35
+ while True:
36
+ if max_frames != 0 and len(frame_ids) >= max_frames:
37
+ break
38
+ add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
39
+ frame_no += add_frames_count
40
+ frame_ids.append(frame_no)
41
+ cur_time += add_frames_count * video_frame_duration
42
+ target_time += target_frame_duration
43
+ if frame_no >= video_frames_count -1:
44
+ break
45
+ frame_ids = frame_ids[:video_frames_count]
46
+ return frame_ids
47
+
48
  def get_video_frame(file_name, frame_no):
49
  decord.bridge.set_bridge('torch')
50
  reader = decord.VideoReader(file_name)
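
To illustrate the new `resample` helper added above, which the Vace preprocessor below uses to pick source frame ids: mapping a 30 fps source onto a 16 fps target keeps roughly one frame out of every 1.9 (a quick check, importing the function from `wan.utils.utils` as the commit does).

```python
# Sketch: pick source frame indices so a 30 fps clip plays back at ~16 fps.
from wan.utils.utils import resample

frame_ids = resample(video_fps=30, video_frames_count=200, max_frames=81, target_fps=16)
print(len(frame_ids))   # 81 (capped by max_frames)
print(frame_ids[:6])    # [0, 2, 4, 6, 8, 10] - about 30/16 source frames per kept frame
```
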
wan/utils/vace_preprocessor.py CHANGED
@@ -180,26 +180,17 @@ class VaceVideoProcessor(object):
180
  ), axis=1).tolist()
181
  return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
182
 
183
- def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
184
- import math
185
- target_fps = self.max_fps
186
- video_frames_count = len(frame_timestamps)
187
- video_frame_duration = 1 /fps
188
- target_frame_duration = 1 / target_fps
189
 
190
- cur_time = 0
191
- target_time = 0
192
- frame_no = 0
193
- frame_ids =[]
194
- for i in range(max_frames):
195
- add_frames_count = math.ceil( (target_time -cur_time) / video_frame_duration )
196
- frame_no += add_frames_count
197
- frame_ids.append(frame_no)
198
- cur_time += add_frames_count * video_frame_duration
199
- target_time += target_frame_duration
200
- if frame_no >= video_frames_count -1:
201
- break
202
- frame_ids = frame_ids[:video_frames_count]
203
  x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
204
  h, w = y2 - y1, x2 - x1
205
  ratio = h / w
@@ -235,11 +226,11 @@ class VaceVideoProcessor(object):
235
 
236
  return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
237
 
238
- def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng, max_frames= 0):
239
  if self.keep_last:
240
- return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
241
  else:
242
- return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng, max_frames= max_frames)
243
 
244
  def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
245
  return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
@@ -253,23 +244,37 @@ class VaceVideoProcessor(object):
253
  import decord
254
  decord.bridge.set_bridge('torch')
255
  readers = []
 
256
  for data_k in data_key_batch:
257
- reader = decord.VideoReader(data_k)
258
- readers.append(reader)
259
-
260
- fps = readers[0].get_avg_fps()
261
- length = min([len(r) for r in readers])
262
- frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
263
- frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
264
- # # frame_timestamps = frame_timestamps[ :max_frames]
265
- # if trim_video > 0:
266
- # frame_timestamps = frame_timestamps[ :trim_video]
 
 
 
 
 
 
 
267
  max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
268
- h, w = readers[0].next().shape[:2]
269
- frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng, max_frames=max_frames)
 
 
 
 
270
 
271
  # preprocess video
272
  videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
 
 
273
  videos = [self._video_preprocess(video, oh, ow) for video in videos]
274
  return *videos, frame_ids, (oh, ow), fps
275
  # return videos if len(videos) > 1 else videos[0]
 
180
  ), axis=1).tolist()
181
  return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
182
 
183
+
184
 
185
+ def _get_frameid_bbox_adjust_last(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
186
+ from wan.utils.utils import resample
187
+
188
+ target_fps = self.max_fps
189
+
190
+ # video_frames_count = len(frame_timestamps)
191
+
192
+ frame_ids= resample(fps, video_frames_count, max_frames, target_fps)
193
+
 
 
 
 
194
  x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
195
  h, w = y2 - y1, x2 - x1
196
  ratio = h / w
 
226
 
227
  return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
228
 
229
+ def _get_frameid_bbox(self, fps, video_frames_count, h, w, crop_box, rng, max_frames= 0):
230
  if self.keep_last:
231
+ return self._get_frameid_bbox_adjust_last(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
232
  else:
233
+ return self._get_frameid_bbox_default(fps, video_frames_count, h, w, crop_box, rng, max_frames= max_frames)
234
 
235
  def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
236
  return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
 
244
  import decord
245
  decord.bridge.set_bridge('torch')
246
  readers = []
247
+ src_video = None
248
  for data_k in data_key_batch:
249
+ if torch.is_tensor(data_k):
250
+ src_video = data_k
251
+ else:
252
+ reader = decord.VideoReader(data_k)
253
+ readers.append(reader)
254
+
255
+ if src_video != None:
256
+ fps = 16
257
+ length = src_video.shape[1]
258
+ if len(readers) > 0:
259
+ min_readers = min([len(r) for r in readers])
260
+ length = min(length, min_readers )
261
+ else:
262
+ fps = readers[0].get_avg_fps()
263
+ length = min([len(r) for r in readers])
264
+ # frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
265
+ # frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
266
  max_frames = min(max_frames, trim_video) if trim_video > 0 else max_frames
267
+ if src_video != None:
268
+ src_video = src_video[:max_frames]
269
+ h, w = src_video.shape[1:3]
270
+ else:
271
+ h, w = readers[0].next().shape[:2]
272
+ frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, length, h, w, crop_box, rng, max_frames=max_frames)
273
 
274
  # preprocess video
275
  videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
276
+ if src_video != None:
277
+ videos = [src_video] + videos
278
  videos = [self._video_preprocess(video, oh, ow) for video in videos]
279
  return *videos, frame_ids, (oh, ow), fps
280
  # return videos if len(videos) > 1 else videos[0]
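With this change the batch loader also accepts an already-decoded video passed as a tensor next to the usual file paths; tensors skip decord and are assumed to run at 16 fps. A standalone restatement of that dispatch (the `split_sources` name is illustrative, not part of the codebase):

```python
import decord
import torch

def split_sources(data_key_batch):
    # Tensors are treated as an already-decoded source video (assumed 16 fps);
    # strings are opened with decord as before.
    src_video, readers = None, []
    for data_k in data_key_batch:
        if torch.is_tensor(data_k):
            src_video = data_k
        else:
            readers.append(decord.VideoReader(data_k))
    return src_video, readers
```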
wgp.py CHANGED
@@ -141,12 +141,27 @@ def process_prompt_and_add_tasks(state, model_choice):
141
  res = " and ".join(VACE_SIZE_CONFIGS.keys())
142
  gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
143
  return
144
- if not "I" in video_prompt_type:
 
 
 
 
145
  image_refs = None
146
- if not "V" in video_prompt_type:
 
 
 
 
147
  video_guide = None
148
- if not "M" in video_prompt_type:
 
 
 
 
149
  video_mask = None
 
 
 
150
 
151
  if isinstance(image_refs, list):
152
  image_refs = [ convert_image(tup[0]) for tup in image_refs ]
@@ -260,7 +275,7 @@ def add_video_task(**inputs):
260
  queue = gen["queue"]
261
  task_id += 1
262
  current_task_id = task_id
263
- inputs_to_query = ["image_start", "image_end", "image_refs", "video_guide", "video_mask"]
264
  start_image_data = None
265
  end_image_data = None
266
  for name in inputs_to_query:
@@ -718,7 +733,7 @@ if not Path(server_config_filename).is_file():
718
  "transformer_types": [],
719
  "transformer_quantization": "int8",
720
  "text_encoder_filename" : text_encoder_choices[1],
721
- "save_path": os.path.join(os.getcwd(), "gradio_outputs"),
722
  "compile" : "",
723
  "metadata_type": "metadata",
724
  "default_ui": "t2v",
@@ -726,7 +741,7 @@ if not Path(server_config_filename).is_file():
726
  "clear_file_list" : 0,
727
  "vae_config": 0,
728
  "profile" : profile_type.LowRAM_LowVRAM,
729
- "reload_model": 2 }
730
 
731
  with open(server_config_filename, "w", encoding="utf-8") as writer:
732
  writer.write(json.dumps(server_config))
@@ -860,7 +875,7 @@ if len(args.vae_config) > 0:
860
  reload_needed = False
861
  default_ui = server_config.get("default_ui", "t2v")
862
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
863
- reload_model = server_config.get("reload_model", 2)
864
 
865
 
866
  if args.t2v_14B or args.t2v:
@@ -962,8 +977,8 @@ def download_models(transformer_filename, text_encoder_filename):
962
 
963
  from huggingface_hub import hf_hub_download, snapshot_download
964
  repoId = "DeepBeepMeep/Wan2.1"
965
- sourceFolderList = ["xlm-roberta-large", "", ]
966
- fileList = [ [], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
967
  targetRoot = "ckpts/"
968
  for sourceFolder, files in zip(sourceFolderList,fileList ):
969
  if len(files)==0:
@@ -1166,7 +1181,7 @@ def load_models(model_filename):
1166
 
1167
  return wan_model, offloadobj, pipe["transformer"]
1168
 
1169
- if reload_model ==3 or reload_model ==4:
1170
  wan_model, offloadobj, transformer = None, None, None
1171
  reload_needed = True
1172
  else:
@@ -1254,7 +1269,7 @@ def apply_changes( state,
1254
  quantization_choice,
1255
  boost_choice = 1,
1256
  clear_file_list = 0,
1257
- reload_choice = 1,
1258
  ):
1259
  if args.lock_config:
1260
  return
@@ -1272,7 +1287,7 @@ def apply_changes( state,
1272
  "transformer_quantization" : quantization_choice,
1273
  "boost" : boost_choice,
1274
  "clear_file_list" : clear_file_list,
1275
- "reload_model" : reload_choice,
1276
  }
1277
 
1278
  if Path(server_config_filename).is_file():
@@ -1295,14 +1310,14 @@ def apply_changes( state,
1295
  if v != v_old:
1296
  changes.append(k)
1297
 
1298
- global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, reload_model, transformer_quantization, transformer_types
1299
  attention_mode = server_config["attention_mode"]
1300
  profile = server_config["profile"]
1301
  compile = server_config["compile"]
1302
  text_encoder_filename = server_config["text_encoder_filename"]
1303
  vae_config = server_config["vae_config"]
1304
  boost = server_config["boost"]
1305
- reload_model = server_config["reload_model"]
1306
  transformer_quantization = server_config["transformer_quantization"]
1307
  transformer_types = server_config["transformer_types"]
1308
  transformer_type = get_model_type(transformer_filename)
@@ -1381,7 +1396,8 @@ def abort_generation(state):
1381
 
1382
  gen["abort"] = True
1383
  gen["extra_orders"] = 0
1384
- wan_model._interrupt= True
 
1385
  msg = "Processing Request to abort Current Generation"
1386
  gr.Info(msg)
1387
  return msg, gr.Button(interactive= False)
@@ -1480,24 +1496,68 @@ def expand_slist(slist, num_inference_steps ):
1480
  return new_slist
1481
  def convert_image(image):
1482
 
1483
- from PIL import ExifTags, ImageOps
1484
  from typing import cast
1485
 
1486
  return cast(Image, ImageOps.exif_transpose(image))
1487
- # image = image.convert('RGB')
1488
- # for orientation in ExifTags.TAGS.keys():
1489
- # if ExifTags.TAGS[orientation]=='Orientation':
1490
- # break
1491
- # exif = image.getexif()
1492
- # return image
1493
- # if not orientation in exif:
1494
- # if exif[orientation] == 3:
1495
- # image=image.rotate(180, expand=True)
1496
- # elif exif[orientation] == 6:
1497
- # image=image.rotate(270, expand=True)
1498
- # elif exif[orientation] == 8:
1499
- # image=image.rotate(90, expand=True)
1500
- # return image
1501
 
1502
  def generate_video(
1503
  task_id,
@@ -1551,7 +1611,7 @@ def generate_video(
1551
  # gr.Info("Unable to generate a Video while a new configuration is being applied.")
1552
  # return
1553
 
1554
- if reload_model !=3 and reload_model !=4 :
1555
  while wan_model == None:
1556
  time.sleep(1)
1557
 
@@ -1681,10 +1741,32 @@ def generate_video(
1681
  raise gr.Error("Teacache not supported for this model")
1682
 
1683
  if "Vace" in model_filename:
1684
  src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
1685
  [video_mask],
1686
  [image_refs],
1687
  video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
 
1688
  trim_video=max_frames)
1689
  else:
1690
  src_video, src_mask, src_ref_images = None, None, None
@@ -2539,9 +2621,9 @@ def fill_inputs(state):
2539
 
2540
  return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
2541
 
2542
- def preload_model(state):
2543
  global reload_needed, wan_model, offloadobj
2544
- if reload_model == 1:
2545
  model_filename = state["model_filename"]
2546
  if state["model_filename"] != transformer_filename:
2547
  wan_model = None
@@ -2558,7 +2640,7 @@ def preload_model(state):
2558
 
2559
  def unload_model_if_needed(state):
2560
  global reload_needed, wan_model, offloadobj
2561
- if reload_model == 4:
2562
  if wan_model != None:
2563
  wan_model = None
2564
  if offloadobj is not None:
@@ -2567,7 +2649,39 @@ def unload_model_if_needed(state):
2567
  gc.collect()
2568
  reload_needed= True
2569
2570
 
 
2571
  def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
2572
  global inputs_names #, advanced
2573
 
@@ -2676,19 +2790,36 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2676
  image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
2677
 
2678
  with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
2679
- gr.Markdown("<B>Control conditions: Images References (custom Faces or Objects), Video (Open Pose, Depth maps), Mask (inpainting)")
2680
- video_prompt_type_value= ui_defaults.get("video_prompt_type","I")
2681
- video_prompt_type = gr.Radio( [("Images Ref", "I"),("a Video", "V"), ("Images Refs + a Video", "IV"), ("Video + Video Mask", "VM"), ("Images + Video + Mask", "IVM")], value =video_prompt_type_value, label="Location", show_label= False, scale= 3)
2682
- image_refs = gr.Gallery(
2683
- label="Images Referencse (Custom faces and Objects to be found in the Video)", type ="pil",
2684
- columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value, value= ui_defaults.get("image_refs", None) )
2685
-
2686
- video_guide = gr.Video(label= "Reference Video (an animated Video in the Open Pose format or Depth Map video)", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None) )
2687
- with gr.Row():
2688
- max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Ref. Video (0 = as many as possible)", visible= "V" in video_prompt_type_value, scale = 2 )
2689
- remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Images Ref. Background", visible= "I" in video_prompt_type_value, scale =1 )
2690
 
2691
- video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpaing, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None) )
 
2693
 
2694
  advanced_prompt = advanced_ui
@@ -2923,7 +3054,10 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2923
  target_settings = gr.Text(value = "settings", interactive= False, visible= False)
2924
 
2925
  image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
2926
- video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
 
 
 
2927
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
2928
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
2929
  queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
@@ -2967,7 +3101,7 @@ def generate_video_tab(update_form = False, state_dict = None, ui_defaults = Non
2967
  ).then(fn= fill_inputs,
2968
  inputs=[state],
2969
  outputs=gen_inputs + extra_inputs
2970
- ).then(fn= preload_model,
2971
  inputs=[state],
2972
  outputs=[gen_status])
2973
 
@@ -3149,14 +3283,8 @@ def generate_configuration_tab(header, model_choice):
3149
  value=server_config.get("metadata_type", "metadata"),
3150
  label="Metadata Handling"
3151
  )
3152
- reload_choice = gr.Dropdown(
3153
- choices=[
3154
- ("Load Model When Starting the App and Changing Model if Model Changed", 1),
3155
- ("Load Model When Starting the App and Pressing Generate if Model Changed", 2),
3156
- ("Load Model When Pressing Generate if Model Changed", 3),
3157
- ("Load Model When Pressing Generate and Unload Model when Finished", 4),
3158
- ],
3159
- value=server_config.get("reload_model",2),
3160
  label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
3161
  )
3162
 
@@ -3191,7 +3319,7 @@ def generate_configuration_tab(header, model_choice):
3191
  quantization_choice,
3192
  boost_choice,
3193
  clear_file_list_choice,
3194
- reload_choice,
3195
  ],
3196
  outputs= [msg , header, model_choice]
3197
  )
@@ -3201,10 +3329,16 @@ def generate_about_tab():
3201
  gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
3202
  gr.Markdown("Many thanks to:")
3203
  gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
 
3204
  gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
3205
  gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
3206
  gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
3207
- gr.Markdown("- <B>Remade_AI</B> : for creating their awesome Loras collection")
3208
 
3209
  def generate_info_tab():
3210
  gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
@@ -3231,17 +3365,26 @@ def generate_dropdown_model_list():
3231
  choices= dropdown_choices,
3232
  value= current_model_type,
3233
  show_label= False,
3234
- scale= 2
 
 
3235
  )
3236
 
3237
 
3238
 
3239
  def create_demo():
3240
  css = """
  .title-with-lines {
3242
  display: flex;
3243
  align-items: center;
3244
- margin: 30px 0;
3245
  }
3246
  .line {
3247
  flex-grow: 1;
@@ -3462,7 +3605,7 @@ def create_demo():
3462
  pointer-events: none;
3463
  }
3464
  """
3465
- with gr.Blocks(css=css, theme=gr.themes.Soft(primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
3466
  gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
3467
  global model_list
3468
 
 
141
  res = " and ".join(VACE_SIZE_CONFIGS.keys())
142
  gr.Info(f"Video Resolution for Vace model is not supported. Only {res} resolutions are allowed.")
143
  return
144
+ if "I" in video_prompt_type:
145
+ if image_refs == None:
146
+ gr.Info("You must provide at one Refererence Image")
147
+ return
148
+ else:
149
  image_refs = None
150
+ if "V" in video_prompt_type:
151
+ if video_guide == None:
152
+ gr.Info("You must provide a Control Video")
153
+ return
154
+ else:
155
  video_guide = None
156
+ if "M" in video_prompt_type:
157
+ if video_mask == None:
158
+ gr.Info("You must provide a Video Mask ")
159
+ return
160
+ else:
161
  video_mask = None
162
+ if "O" in video_prompt_type and inputs["max_frames"]==0:
163
+ gr.Info(f"In order to extend a video, you need to indicate how many frames you want to reuse in the source video.")
164
+ return
165
 
166
  if isinstance(image_refs, list):
167
  image_refs = [ convert_image(tup[0]) for tup in image_refs ]
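These checks read single-letter flags out of `video_prompt_type`, which the new UI controls further down assemble. A small illustration with hypothetical values:

```python
# "I" = reference images, "V" = control video, "M" = video mask,
# "O" = extend the control video, "P"/"D"/"C" = pose / depth / recolor options from the dropdown.
video_prompt_type = "PV"                 # e.g. "Transfer Human Motion from the Control Video"
assert "V" in video_prompt_type          # a control video must be provided
assert "M" not in video_prompt_type      # no mask video expected
assert "I" not in video_prompt_type      # no reference images expected
```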
 
275
  queue = gen["queue"]
276
  task_id += 1
277
  current_task_id = task_id
278
+ inputs_to_query = ["image_start", "image_end", "video_guide", "image_refs","video_mask"]
279
  start_image_data = None
280
  end_image_data = None
281
  for name in inputs_to_query:
 
733
  "transformer_types": [],
734
  "transformer_quantization": "int8",
735
  "text_encoder_filename" : text_encoder_choices[1],
736
+ "save_path": "outputs", #os.path.join(os.getcwd(),
737
  "compile" : "",
738
  "metadata_type": "metadata",
739
  "default_ui": "t2v",
 
741
  "clear_file_list" : 0,
742
  "vae_config": 0,
743
  "profile" : profile_type.LowRAM_LowVRAM,
744
+ "preload_model_policy": [] }
745
 
746
  with open(server_config_filename, "w", encoding="utf-8") as writer:
747
  writer.write(json.dumps(server_config))
 
875
  reload_needed = False
876
  default_ui = server_config.get("default_ui", "t2v")
877
  save_path = server_config.get("save_path", os.path.join(os.getcwd(), "gradio_outputs"))
878
+ preload_model_policy = server_config.get("preload_model_policy", [])
879
 
880
 
881
  if args.t2v_14B or args.t2v:
 
977
 
978
  from huggingface_hub import hf_hub_download, snapshot_download
979
  repoId = "DeepBeepMeep/Wan2.1"
980
+ sourceFolderList = ["xlm-roberta-large", "pose", "depth", "", ]
981
+ fileList = [ [], [],[], ["Wan2.1_VAE_bf16.safetensors", "models_clip_open-clip-xlm-roberta-large-vit-huge-14-bf16.safetensors", "flownet.pkl" ] + computeList(text_encoder_filename) + computeList(transformer_filename) ]
982
  targetRoot = "ckpts/"
983
  for sourceFolder, files in zip(sourceFolderList,fileList ):
984
  if len(files)==0:
 
1181
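The two new folders pull the Open Pose and depth annotator checkpoints into `ckpts/pose` and `ckpts/depth`, where `preprocess_video` expects them. A hedged sketch of an equivalent manual download (the exact contents of those repo folders are assumed, not shown in this diff):

```python
from huggingface_hub import snapshot_download

# Fetch only the pose and depth annotator folders of DeepBeepMeep/Wan2.1 into ckpts/,
# so that ckpts/pose/yolox_l.onnx, ckpts/pose/dw-ll_ucoco_384.onnx and
# ckpts/depth/dpt_hybrid-midas-501f0c75.pt are available to preprocess_video().
snapshot_download(repo_id="DeepBeepMeep/Wan2.1",
                  allow_patterns=["pose/*", "depth/*"],
                  local_dir="ckpts")
```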
 
1182
  return wan_model, offloadobj, pipe["transformer"]
1183
 
1184
+ if not "P" in preload_model_policy:
1185
  wan_model, offloadobj, transformer = None, None, None
1186
  reload_needed = True
1187
  else:
 
1269
  quantization_choice,
1270
  boost_choice = 1,
1271
  clear_file_list = 0,
1272
+ preload_model_policy_choice = 1,
1273
  ):
1274
  if args.lock_config:
1275
  return
 
1287
  "transformer_quantization" : quantization_choice,
1288
  "boost" : boost_choice,
1289
  "clear_file_list" : clear_file_list,
1290
+ "preload_model_policy" : preload_model_policy_choice,
1291
  }
1292
 
1293
  if Path(server_config_filename).is_file():
 
1310
  if v != v_old:
1311
  changes.append(k)
1312
 
1313
+ global attention_mode, profile, compile, transformer_filename, text_encoder_filename, vae_config, boost, lora_dir, reload_needed, preload_model_policy, transformer_quantization, transformer_types
1314
  attention_mode = server_config["attention_mode"]
1315
  profile = server_config["profile"]
1316
  compile = server_config["compile"]
1317
  text_encoder_filename = server_config["text_encoder_filename"]
1318
  vae_config = server_config["vae_config"]
1319
  boost = server_config["boost"]
1320
+ preload_model_policy = server_config["preload_model_policy"]
1321
  transformer_quantization = server_config["transformer_quantization"]
1322
  transformer_types = server_config["transformer_types"]
1323
  transformer_type = get_model_type(transformer_filename)
 
1396
 
1397
  gen["abort"] = True
1398
  gen["extra_orders"] = 0
1399
+ if wan_model != None:
1400
+ wan_model._interrupt= True
1401
  msg = "Processing Request to abort Current Generation"
1402
  gr.Info(msg)
1403
  return msg, gr.Button(interactive= False)
 
1496
  return new_slist
1497
  def convert_image(image):
1498
 
1499
+ from PIL import ImageOps
1500
  from typing import cast
1501
 
1502
  return cast(Image, ImageOps.exif_transpose(image))
1503
+
1504
+
1505
+ def preprocess_video(process_type, height, width, video_in, max_frames):
1506
+
1507
+ from wan.utils.utils import resample
1508
+
1509
+ import decord
1510
+ decord.bridge.set_bridge('torch')
1511
+ reader = decord.VideoReader(video_in)
1512
+
1513
+ fps = reader.get_avg_fps()
1514
+
1515
+ frame_nos = resample(fps, len(reader), max_frames= max_frames, target_fps=16)
1516
+ frames_list = reader.get_batch(frame_nos)
1517
+ frame_height, frame_width, _ = frames_list[0].shape
1518
+
1519
+ scale = ((height * width ) / (frame_height * frame_width))**(1/2)
1520
+ # scale = min(height / frame_height, width / frame_width)
1521
+
1522
+ new_height = (int(frame_height * scale) // 16) * 16
1523
+ new_width = (int(frame_width * scale) // 16) * 16
1524
+
1525
+ processed_frames_list = []
1526
+ for frame in frames_list:
1527
+ frame = Image.fromarray(np.clip(frame.cpu().numpy(), 0, 255).astype(np.uint8))
1528
+ frame = frame.resize((new_width,new_height), resample=Image.Resampling.LANCZOS)
1529
+ processed_frames_list.append(frame)
1530
+
1531
+ if process_type=="pose":
1532
+ from preprocessing.dwpose.pose import PoseBodyFaceVideoAnnotator
1533
+ cfg_dict = {
1534
+ "DETECTION_MODEL": "ckpts/pose/yolox_l.onnx",
1535
+ "POSE_MODEL": "ckpts/pose/dw-ll_ucoco_384.onnx",
1536
+ "RESIZE_SIZE": 1024
1537
+ }
1538
+ anno_ins = PoseBodyFaceVideoAnnotator(cfg_dict)
1539
+ elif process_type=="depth":
1540
+ from preprocessing.midas.depth import DepthVideoAnnotator
1541
+ cfg_dict = {
1542
+ "PRETRAINED_MODEL": "ckpts/depth/dpt_hybrid-midas-501f0c75.pt"
1543
+ }
1544
+ anno_ins = DepthVideoAnnotator(cfg_dict)
1545
+ else:
1546
+ from preprocessing.gray import GrayVideoAnnotator
1547
+ cfg_dict = {}
1548
+ anno_ins = GrayVideoAnnotator(cfg_dict)
1549
+
1550
+ np_frames = anno_ins.forward(processed_frames_list)
1551
+
1552
+ # from preprocessing.dwpose.pose import save_one_video
1553
+ # save_one_video("test.mp4", np_frames, fps=8, quality=8, macro_block_size=None)
1554
+
1555
+ torch_frames = []
1556
+ for np_frame in np_frames:
1557
+ torch_frame = torch.from_numpy(np_frame)
1558
+ torch_frames.append(torch_frame)
1559
+
1560
+ return torch.stack(torch_frames)
1561
 
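A usage sketch of the new `preprocess_video` helper (the input file name is hypothetical and assumes the pose checkpoints above have already been downloaded):

```python
# Turn a control clip into an Open Pose video resampled to 16 fps, capped at 81 frames,
# with the frames rescaled to roughly 832x480 and rounded down to multiples of 16.
pose_frames = preprocess_video("pose", height=480, width=832,
                               video_in="control_clip.mp4", max_frames=81)
print(pose_frames.shape)   # stacked annotated frames, one tensor entry per kept source frame
```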
1562
  def generate_video(
1563
  task_id,
 
1611
  # gr.Info("Unable to generate a Video while a new configuration is being applied.")
1612
  # return
1613
 
1614
+ if "P" in preload_model_policy:
1615
  while wan_model == None:
1616
  time.sleep(1)
1617
 
 
1741
  raise gr.Error("Teacache not supported for this model")
1742
 
1743
  if "Vace" in model_filename:
1744
+ # video_prompt_type = video_prompt_type +"G"
1745
+ if any(process in video_prompt_type for process in ("P", "D", "G")) :
1746
+ prompts_max = gen["prompts_max"]
1747
+
1748
+ status = get_generation_status(prompt_no, prompts_max, 1, 1)
1749
+ preprocess_type = None
1750
+ if "P" in video_prompt_type :
1751
+ progress_args = [0, status + " - Extracting Open Pose Information"]
1752
+ preprocess_type = "pose"
1753
+ elif "D" in video_prompt_type :
1754
+ progress_args = [0, status + " - Extracting Depth Information"]
1755
+ preprocess_type = "depth"
1756
+ elif "G" in video_prompt_type :
1757
+ progress_args = [0, status + " - Extracting Gray Level Information"]
1758
+ preprocess_type = "gray"
1759
+
1760
+ if preprocess_type != None :
1761
+ progress(*progress_args )
1762
+ gen["progress_args"] = progress_args
1763
+ video_guide = preprocess_video(preprocess_type, width=width, height=height,video_in=video_guide, max_frames= video_length)
1764
+
1765
  src_video, src_mask, src_ref_images = wan_model.prepare_source([video_guide],
1766
  [video_mask],
1767
  [image_refs],
1768
  video_length, VACE_SIZE_CONFIGS[resolution_reformated], "cpu",
1769
+ original_video= "O" in video_prompt_type,
1770
  trim_video=max_frames)
1771
  else:
1772
  src_video, src_mask, src_ref_images = None, None, None
 
2621
 
2622
  return generate_video_tab(update_form = True, state_dict = state, ui_defaults = ui_defaults)
2623
 
2624
+ def preload_model_when_switching(state):
2625
  global reload_needed, wan_model, offloadobj
2626
+ if "S" in preload_model_policy:
2627
  model_filename = state["model_filename"]
2628
  if state["model_filename"] != transformer_filename:
2629
  wan_model = None
 
2640
 
2641
  def unload_model_if_needed(state):
2642
  global reload_needed, wan_model, offloadobj
2643
+ if "U" in preload_model_policy:
2644
  if wan_model != None:
2645
  wan_model = None
2646
  if offloadobj is not None:
 
2649
  gc.collect()
2650
  reload_needed= True
2651
 
2652
+ def filter_letters(source_str, letters):
2653
+ ret = ""
2654
+ for letter in letters:
2655
+ if letter in source_str:
2656
+ ret += letter
2657
+ return ret
2658
+
2659
+ def add_to_sequence(source_str, letters):
2660
+ ret = source_str
2661
+ for letter in letters:
2662
+ if not letter in source_str:
2663
+ ret += letter
2664
+ return ret
2665
+
2666
+ def del_in_sequence(source_str, letters):
2667
+ ret = source_str
2668
+ for letter in letters:
2669
+ if letter in source_str:
2670
+ ret = ret.replace(letter, "")
2671
+ return ret
2672
+
2673
+
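These helpers treat `video_prompt_type` as an unordered set of single-letter flags, for example:

```python
# filter_letters keeps only the listed flags, add_to_sequence adds them once,
# del_in_sequence strips them; the order of the remaining letters is preserved.
assert filter_letters("IPV", "ODPCMV") == "PV"
assert add_to_sequence("IV", "M") == "IVM"
assert del_in_sequence("IVM", "ODPCMV") == "I"
```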
2674
+ def refresh_video_prompt_type_image_refs(video_prompt_type, video_prompt_type_image_refs):
2675
+ video_prompt_type = add_to_sequence(video_prompt_type, "I") if video_prompt_type_image_refs else del_in_sequence(video_prompt_type, "I")
2676
+ return video_prompt_type, gr.update(visible = video_prompt_type_image_refs),gr.update(visible = video_prompt_type_image_refs)
2677
+
2678
+ def refresh_video_prompt_type_video_guide(video_prompt_type, video_prompt_type_video_guide):
2679
+ video_prompt_type = del_in_sequence(video_prompt_type, "ODPCMV")
2680
+ video_prompt_type = add_to_sequence(video_prompt_type, video_prompt_type_video_guide)
2681
+ visible = "V" in video_prompt_type
2682
+ return video_prompt_type, gr.update(visible = visible), gr.update(visible = visible), gr.update(visible= "M" in video_prompt_type )
2683
 
2684
+
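For instance, selecting "Transfer Human Motion" ("PV") while reference images ("I") are enabled would behave as follows (hypothetical values):

```python
new_type, guide_upd, frames_upd, mask_upd = refresh_video_prompt_type_video_guide("IV", "PV")
# new_type == "IPV": the old control-video flags are stripped, then "P" and "V" are added back.
# guide_upd / frames_upd make the Control Video and frame-count widgets visible,
# mask_upd keeps the Video Mask hidden because "M" is not part of the new flags.
```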
2685
  def generate_video_tab(update_form = False, state_dict = None, ui_defaults = None, model_choice = None, header = None):
2686
  global inputs_names #, advanced
2687
 
 
2790
  image_end = gr.Image(label= "Last Image for a new video", type ="pil", visible="E" in image_prompt_type_value, value= ui_defaults.get("image_end", None))
2791
 
2792
  with gr.Column(visible= "Vace" in model_filename ) as video_prompt_column:
2793
+ video_prompt_type_value= ui_defaults.get("video_prompt_type","")
2794
+ video_prompt_type = gr.Text(value= video_prompt_type_value, visible= False)
2795
+ video_prompt_type_video_guide = gr.Dropdown(
2796
+ choices=[
2797
+ ("None, use only the Text Prompt", ""),
2798
+ ("Extend the Control Video", "OV"),
2799
+ ("Transfer Human Motion from the Control Video", "PV"),
2800
+ ("Transfer Depth from the Control Video", "DV"),
2801
+ ("Recolorize the Control Video", "CV"),
2802
+ ("Control Video contains Open Pose, Depth or Black & White ", "V"),
2803
+ ("Inpainting of Control Video using Mask Video ", "MV"),
2804
+ ],
2805
+ value=filter_letters(video_prompt_type_value, "ODPCMV"),
2806
+ label="Video to Video"
2807
+ )
2808
+ video_prompt_type_image_refs = gr.Checkbox(value="I" in video_prompt_type_value , label= "Use Reference Images (Faces, Objects) to customize the New Video", scale =1 )
2809
+
2810
+ video_guide = gr.Video(label= "Control Video", visible= "V" in video_prompt_type_value, value= ui_defaults.get("video_guide", None),)
2811
+ max_frames = gr.Slider(0, 100, value=ui_defaults.get("max_frames",0), step=1, label="Nb of frames in Control Video to use (0 = max)", visible= "V" in video_prompt_type_value, scale = 2 )
2812
 
2813
+ image_refs = gr.Gallery( label ="Reference Images",
2814
+ type ="pil", show_label= True,
2815
+ columns=[3], rows=[1], object_fit="contain", height="auto", selected_index=0, interactive= True, visible= "I" in video_prompt_type_value,
2816
+ value= ui_defaults.get("image_refs", None) )
2817
+
2818
+ # with gr.Row():
2819
+ remove_background_image_ref = gr.Checkbox(value=ui_defaults.get("remove_background_image_ref",1), label= "Remove Background of Reference Images", visible= "I" in video_prompt_type_value, scale =1 )
2820
+
2821
+
2822
+ video_mask = gr.Video(label= "Video Mask (for Inpainting or Outpainting, white pixels = Mask)", visible= "M" in video_prompt_type_value, value= ui_defaults.get("video_mask", None))
2823
 
2824
 
2825
  advanced_prompt = advanced_ui
 
3054
  target_settings = gr.Text(value = "settings", interactive= False, visible= False)
3055
 
3056
  image_prompt_type.change(fn=refresh_image_prompt_type, inputs=[state, image_prompt_type], outputs=[image_start, image_end])
3057
+ # video_prompt_type.change(fn=refresh_video_prompt_type, inputs=[state, video_prompt_type], outputs=[image_refs, video_guide, video_mask, max_frames, remove_background_image_ref])
3058
+ video_prompt_type_image_refs.input(fn=refresh_video_prompt_type_image_refs, inputs = [video_prompt_type, video_prompt_type_image_refs], outputs = [video_prompt_type, image_refs, remove_background_image_ref ])
3059
+ video_prompt_type_video_guide.input(fn=refresh_video_prompt_type_video_guide, inputs = [video_prompt_type, video_prompt_type_video_guide], outputs = [video_prompt_type, video_guide, max_frames, video_mask])
3060
+
3061
  show_advanced.change(fn=switch_advanced, inputs=[state, show_advanced, lset_name], outputs=[advanced_row, preset_buttons_rows, refresh_lora_btn, refresh2_row ,lset_name ]).then(
3062
  fn=switch_prompt_type, inputs = [state, wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, *prompt_vars], outputs = [wizard_prompt_activated_var, wizard_variables_var, prompt, wizard_prompt, prompt_column_advanced, prompt_column_wizard, prompt_column_wizard_vars, *prompt_vars])
3063
  queue_df.select( fn=handle_celll_selection, inputs=state, outputs=[queue_df, modal_image_display, modal_container])
 
3101
  ).then(fn= fill_inputs,
3102
  inputs=[state],
3103
  outputs=gen_inputs + extra_inputs
3104
+ ).then(fn= preload_model_when_switching,
3105
  inputs=[state],
3106
  outputs=[gen_status])
3107
 
 
3283
  value=server_config.get("metadata_type", "metadata"),
3284
  label="Metadata Handling"
3285
  )
3286
+ preload_model_policy_choice = gr.CheckboxGroup([("Preload Model while Launching the App","P"), ("Preload Model while Switching Model", "S"), ("Unload Model when Queue is Done", "U")],
3287
+ value=server_config.get("preload_model_policy",[]),
3288
  label="RAM Loading / Unloading Model Policy (in any case VRAM will be freed once the queue has been processed)"
3289
  )
3290
 
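The old integer `reload_model` setting is replaced by a list of letter flags persisted in the server config. A short illustration of how the flags are read (hypothetical config value):

```python
# "P" = preload the model when the app launches,
# "S" = preload when switching models in the UI,
# "U" = unload the model once the generation queue is done.
preload_model_policy = ["P", "U"]
if "P" in preload_model_policy:
    print("model weights will be loaded at startup")
```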
 
3319
  quantization_choice,
3320
  boost_choice,
3321
  clear_file_list_choice,
3322
+ preload_model_policy_choice,
3323
  ],
3324
  outputs= [msg , header, model_choice]
3325
  )
 
3329
  gr.Markdown("Original Wan 2.1 Model by <B>Alibaba</B> (<A HREF='https://github.com/Wan-Video/Wan2.1'>GitHub</A>)")
3330
  gr.Markdown("Many thanks to:")
3331
  gr.Markdown("- <B>Alibaba Wan team for the best open source video generator")
3332
+ gr.Markdown("- <B>Alibaba Vace and Fun Teams for their incredible control net models")
3333
  gr.Markdown("- <B>Cocktail Peanuts</B> : QA and simple installation via Pinokio.computer")
3334
  gr.Markdown("- <B>Tophness</B> : created multi tabs and queuing frameworks")
3335
  gr.Markdown("- <B>AmericanPresidentJimmyCarter</B> : added original support for Skip Layer Guidance")
3336
+ gr.Markdown("- <B>Remade_AI</B> : for their awesome Loras collection")
3337
+ gr.Markdown("<BR>Huge acknowlegments to these great open source projects used in WanGP:")
3338
+ gr.Markdown("- <B>Rife</B>: temporal upsampler (https://github.com/hzwer/ECCV2022-RIFE)")
3339
+ gr.Markdown("- <B>DwPose</B>: Open Pose extractor (https://github.com/IDEA-Research/DWPose)")
3340
+ gr.Markdown("- <B>Midas</B>: Depth extractor (https://github.com/isl-org/MiDaS")
3341
+
3342
 
3343
  def generate_info_tab():
3344
  gr.Markdown("<FONT SIZE=3>Welcome to WanGP a super fast and low VRAM AI Video Generator !</FONT>")
 
3365
  choices= dropdown_choices,
3366
  value= current_model_type,
3367
  show_label= False,
3368
+ scale= 2,
3369
+ elem_id="model_list",
3370
+ elem_classes="model_list_class",
3371
  )
3372
 
3373
 
3374
 
3375
  def create_demo():
3376
  css = """
3377
+ #model_list{
3378
+ background-color:black;
3379
+ padding:1px}
3380
+
3381
+ #model_list input {
3382
+ font-size:25px}
3383
+
3384
  .title-with-lines {
3385
  display: flex;
3386
  align-items: center;
3387
+ margin: 25px 0;
3388
  }
3389
  .line {
3390
  flex-grow: 1;
 
3605
  pointer-events: none;
3606
  }
3607
  """
3608
+ with gr.Blocks(css=css, theme=gr.themes.Soft(font=["Verdana"], primary_hue="sky", neutral_hue="slate", text_size="md"), title= "Wan2GP") as demo:
3609
  gr.Markdown("<div align=center><H1>Wan<SUP>GP</SUP> v4.0 <FONT SIZE=4>by <I>DeepBeepMeep</I></FONT> <FONT SIZE=3>") # (<A HREF='https://github.com/deepbeepmeep/Wan2GP'>Updates</A>)</FONT SIZE=3></H1></div>")
3610
  global model_list
3611