IceClear committed on
Commit
56d0abb
·
1 Parent(s): f61c793

delete degrad

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  title: SeedVR2-3B
3
- emoji: 🚀
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
 
1
  ---
2
  title: SeedVR2-3B
3
+ emoji: 🎥
4
  colorFrom: blue
5
  colorTo: green
6
  sdk: gradio
projects/video_diffusion_sr/degradation_utils.py DELETED
@@ -1,522 +0,0 @@
1
- # Copyright (c) 2022 BasicSR: Xintao Wang and Liangbin Xie and Ke Yu and Kelvin C.K. Chan and Chen Change Loy and Chao Dong
2
- # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates.
3
- # SPDX-License-Identifier: Apache License, Version 2.0 (the "License")
4
- #
5
- # This file has been modified by ByteDance Ltd. and/or its affiliates. on 1st June 2025
6
- #
7
- # Original file was released under Apache License, Version 2.0 (the "License"), with the full license text
8
- # available at http://www.apache.org/licenses/LICENSE-2.0.
9
- #
10
- # This modified file is released under the same license.
11
-
12
- import io
13
- import math
14
- import random
15
- from typing import Dict
16
- import av
17
- import numpy as np
18
- import torch
19
- from basicsr.data.degradations import (
20
- circular_lowpass_kernel,
21
- random_add_gaussian_noise_pt,
22
- random_add_poisson_noise_pt,
23
- random_mixed_kernels,
24
- )
25
- from basicsr.utils import DiffJPEG, USMSharp
26
- from basicsr.utils.img_process_util import filter2D
27
- from einops import rearrange
28
- from torch import nn
29
- from torch.nn import functional as F
30
-
31
-
32
- def remove_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
33
- for k in list(state_dict.keys()):
34
- if k.startswith("_flops_wrap_module."):
35
- v = state_dict.pop(k)
36
- state_dict[k.replace("_flops_wrap_module.", "")] = v
37
- if k.startswith("module."):
38
- v = state_dict.pop(k)
39
- state_dict[k.replace("module.", "")] = v
40
- return state_dict
41
-
42
-
43
- def clean_memory_bank(module: nn.Module):
44
- if hasattr(module, "padding_bank"):
45
- module.padding_bank = None
46
- for child in module.children():
47
- clean_memory_bank(child)
48
-
49
-
50
- para_dic = {
51
- "kernel_list": [
52
- "iso",
53
- "aniso",
54
- "generalized_iso",
55
- "generalized_aniso",
56
- "plateau_iso",
57
- "plateau_aniso",
58
- ],
59
- "kernel_prob": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
60
- "sinc_prob": 0.1,
61
- "blur_sigma": [0.2, 1.5],
62
- "betag_range": [0.5, 2.0],
63
- "betap_range": [1, 1.5],
64
- "kernel_list2": [
65
- "iso",
66
- "aniso",
67
- "generalized_iso",
68
- "generalized_aniso",
69
- "plateau_iso",
70
- "plateau_aniso",
71
- ],
72
- "kernel_prob2": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
73
- "sinc_prob2": 0.1,
74
- "blur_sigma2": [0.2, 1.0],
75
- "betag_range2": [0.5, 2.0],
76
- "betap_range2": [1, 1.5],
77
- "final_sinc_prob": 0.5,
78
- }
79
-
80
- degrade_dic = {
81
- # "gt_usm": True, # USM the ground-truth
82
- # the first degradation process
83
- "resize_prob": [0.2, 0.7, 0.1], # up, down, keep
84
- "resize_range": [0.3, 1.5],
85
- "gaussian_noise_prob": 0.5,
86
- "noise_range": [1, 15],
87
- "poisson_scale_range": [0.05, 2],
88
- "gray_noise_prob": 0.4,
89
- "jpeg_range": [60, 95],
90
- # the second degradation process
91
- "second_blur_prob": 0.5,
92
- "resize_prob2": [0.3, 0.4, 0.3], # up, down, keep
93
- "resize_range2": [0.6, 1.2],
94
- "gaussian_noise_prob2": 0.5,
95
- "noise_range2": [1, 12],
96
- "poisson_scale_range2": [0.05, 1.0],
97
- "gray_noise_prob2": 0.4,
98
- "jpeg_range2": [60, 95],
99
- "queue_size": 180,
100
- "scale": 4, # output size: ori_h // scale
101
- "sharpen": False,
102
- }
103
-
104
-
105
- def set_para(para_dic):
106
- # blur settings for the first degradation
107
- # blur_kernel_size = opt['blur_kernel_size']
108
- kernel_list = para_dic["kernel_list"]
109
- kernel_prob = para_dic["kernel_prob"]
110
- blur_sigma = para_dic["blur_sigma"]
111
- betag_range = para_dic["betag_range"]
112
- betap_range = para_dic["betap_range"]
113
- sinc_prob = para_dic["sinc_prob"]
114
-
115
- # blur settings for the second degradation
116
- # blur_kernel_size2 = opt['blur_kernel_size2']
117
- kernel_list2 = para_dic["kernel_list2"]
118
- kernel_prob2 = para_dic["kernel_prob2"]
119
- blur_sigma2 = para_dic["blur_sigma2"]
120
- betag_range2 = para_dic["betag_range2"]
121
- betap_range2 = para_dic["betap_range2"]
122
- sinc_prob2 = para_dic["sinc_prob2"]
123
-
124
- # a final sinc filter
125
- final_sinc_prob = para_dic["final_sinc_prob"]
126
-
127
- kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21
128
- pulse_tensor = torch.zeros(
129
- 21, 21
130
- ).float() # convolving with pulse tensor brings no blurry effect
131
- pulse_tensor[10, 10] = 1
132
- kernel_size = random.choice(kernel_range)
133
- if np.random.uniform() < sinc_prob:
134
- # this sinc filter setting is for kernels ranging from [7, 21]
135
- if kernel_size < 13:
136
- omega_c = np.random.uniform(np.pi / 3, np.pi)
137
- else:
138
- omega_c = np.random.uniform(np.pi / 5, np.pi)
139
- kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
140
- else:
141
- kernel = random_mixed_kernels(
142
- kernel_list,
143
- kernel_prob,
144
- kernel_size,
145
- blur_sigma,
146
- blur_sigma,
147
- [-math.pi, math.pi],
148
- betag_range,
149
- betap_range,
150
- noise_range=None,
151
- )
152
- # pad kernel
153
- pad_size = (21 - kernel_size) // 2
154
- kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
155
-
156
- # ------------------------ Generate kernels (used in the second degradation) -------------- #
157
- kernel_size = random.choice(kernel_range)
158
- if np.random.uniform() < sinc_prob2:
159
- if kernel_size < 13:
160
- omega_c = np.random.uniform(np.pi / 3, np.pi)
161
- else:
162
- omega_c = np.random.uniform(np.pi / 5, np.pi)
163
- kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
164
- else:
165
- kernel2 = random_mixed_kernels(
166
- kernel_list2,
167
- kernel_prob2,
168
- kernel_size,
169
- blur_sigma2,
170
- blur_sigma2,
171
- [-math.pi, math.pi],
172
- betag_range2,
173
- betap_range2,
174
- noise_range=None,
175
- )
176
- pad_size = (21 - kernel_size) // 2
177
- kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size)))
178
-
179
- # ------------------------------------- sinc kernel ------------------------------------- #
180
- if np.random.uniform() < final_sinc_prob:
181
- kernel_size = random.choice(kernel_range)
182
- omega_c = np.random.uniform(np.pi / 3, np.pi)
183
- sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21)
184
- sinc_kernel = torch.FloatTensor(sinc_kernel)
185
- else:
186
- sinc_kernel = pulse_tensor
187
- kernel = torch.FloatTensor(kernel)
188
- kernel2 = torch.FloatTensor(kernel2)
189
- return_d = {"kernel1": kernel, "kernel2": kernel2, "sinc_kernel": sinc_kernel}
190
- return return_d
191
-
192
-
193
- def print_stat(a):
194
- print(
195
- f"shape={a.shape}, min={a.min():.2f}, \
196
- max={a.max():.2f}, var={a.var():.2f}, {a.flatten()[0]}"
197
- )
198
-
199
-
200
- @torch.no_grad()
201
- def esr_blur_gpu(image, paras, usm_sharpener, jpeger, device="cpu"):
202
- """
203
- input and output: image is a tensor with shape: b f c h w, range (-1, 1)
204
- """
205
- video_length = image.shape[1]
206
- image = rearrange(image, "b f c h w -> (b f) c h w").to(device)
207
- image = (image + 1) * 0.5
208
- if degrade_dic["sharpen"]:
209
- gt_usm = usm_sharpener(image)
210
- else:
211
- gt_usm = image
212
- ori_h, ori_w = image.size()[2:4]
213
- # ----------------------- The first degradation process ----------------------- #
214
- # blur
215
- out = filter2D(gt_usm, paras["kernel1"].unsqueeze(0).to(device))
216
- # random resize
217
- updown_type = random.choices(["up", "down", "keep"], degrade_dic["resize_prob"])[0]
218
- if updown_type == "up":
219
- scale = np.random.uniform(1, degrade_dic["resize_range"][1])
220
- elif updown_type == "down":
221
- scale = np.random.uniform(degrade_dic["resize_range"][0], 1)
222
- else:
223
- scale = 1
224
- mode = random.choice(["area", "bilinear", "bicubic"])
225
- out = F.interpolate(out, scale_factor=scale, mode=mode)
226
- # noise
227
- gray_noise_prob = degrade_dic["gray_noise_prob"]
228
- out = out.to(torch.float32)
229
- if np.random.uniform() < degrade_dic["gaussian_noise_prob"]:
230
- out = random_add_gaussian_noise_pt(
231
- out,
232
- # video_length=video_length,
233
- sigma_range=degrade_dic["noise_range"],
234
- clip=True,
235
- rounds=False,
236
- gray_prob=gray_noise_prob,
237
- )
238
- else:
239
- out = random_add_poisson_noise_pt(
240
- out,
241
- # video_length=video_length,
242
- scale_range=degrade_dic["poisson_scale_range"],
243
- gray_prob=gray_noise_prob,
244
- clip=True,
245
- rounds=False,
246
- )
247
- # out = out.to(torch.bfloat16)
248
-
249
- # JPEG compression
250
- jpeg_p = out.new_zeros(out.size(0)).uniform_(*degrade_dic["jpeg_range"])
251
- out = torch.clamp(out, 0, 1)
252
- out = jpeger(out, quality=jpeg_p)
253
-
254
- # Video compression 1
255
- # print('Video compression 1')
256
-
257
- # print_stat(out)
258
- if video_length > 1:
259
- out = video_compression(out, device=device)
260
- # print('After video compression 1')
261
-
262
- # print_stat(out)
263
-
264
- # ----------------------- The second degradation process ----------------------- #
265
- # blur
266
- if np.random.uniform() < degrade_dic["second_blur_prob"]:
267
- out = filter2D(out, paras["kernel2"].unsqueeze(0).to(device))
268
- # random resize
269
- updown_type = random.choices(["up", "down", "keep"], degrade_dic["resize_prob2"])[0]
270
- if updown_type == "up":
271
- scale = np.random.uniform(1, degrade_dic["resize_range2"][1])
272
- elif updown_type == "down":
273
- scale = np.random.uniform(degrade_dic["resize_range2"][0], 1)
274
- else:
275
- scale = 1
276
- mode = random.choice(["area", "bilinear", "bicubic"])
277
- out = F.interpolate(
278
- out,
279
- size=(
280
- int(ori_h / degrade_dic["scale"] * scale),
281
- int(ori_w / degrade_dic["scale"] * scale),
282
- ),
283
- mode=mode,
284
- )
285
- # noise
286
- gray_noise_prob = degrade_dic["gray_noise_prob2"]
287
- out = out.to(torch.float32)
288
- if np.random.uniform() < degrade_dic["gaussian_noise_prob2"]:
289
- out = random_add_gaussian_noise_pt(
290
- out,
291
- # video_length=video_length,
292
- sigma_range=degrade_dic["noise_range2"],
293
- clip=True,
294
- rounds=False,
295
- gray_prob=gray_noise_prob,
296
- )
297
- else:
298
- out = random_add_poisson_noise_pt(
299
- out,
300
- # video_length=video_length,
301
- scale_range=degrade_dic["poisson_scale_range2"],
302
- gray_prob=gray_noise_prob,
303
- clip=True,
304
- rounds=False,
305
- )
306
- # out = out.to(torch.bfloat16)
307
-
308
- if np.random.uniform() < 0.5:
309
- # resize back + the final sinc filter
310
- mode = random.choice(["area", "bilinear", "bicubic"])
311
- out = F.interpolate(
312
- out, size=(ori_h // degrade_dic["scale"], ori_w // degrade_dic["scale"]), mode=mode
313
- )
314
- out = filter2D(out, paras["sinc_kernel"].unsqueeze(0).to(device))
315
- # JPEG compression
316
- jpeg_p = out.new_zeros(out.size(0)).uniform_(*degrade_dic["jpeg_range2"])
317
- out = torch.clamp(out, 0, 1)
318
- out = jpeger(out, quality=jpeg_p)
319
- else:
320
- # JPEG compression
321
- jpeg_p = out.new_zeros(out.size(0)).uniform_(*degrade_dic["jpeg_range2"])
322
- out = torch.clamp(out, 0, 1)
323
- out = jpeger(out, quality=jpeg_p)
324
- # resize back + the final sinc filter
325
- mode = random.choice(["area", "bilinear", "bicubic"])
326
- out = F.interpolate(
327
- out, size=(ori_h // degrade_dic["scale"], ori_w // degrade_dic["scale"]), mode=mode
328
- )
329
- out = filter2D(out, paras["sinc_kernel"].unsqueeze(0).to(device))
330
-
331
- # print('Video compression 2')
332
-
333
- # print_stat(out)
334
- if video_length > 1:
335
- out = video_compression(out, device=device)
336
- # print('After video compression 2')
337
-
338
- # print_stat(out)
339
-
340
- out = F.interpolate(out, size=(ori_h, ori_w), mode="bicubic")
341
- blur_image = torch.clamp(out, 0, 1)
342
- # blur_image = ColorJitter(0.1, 0.1, 0.1, 0.05)(blur_image) # 颜色数据增广
343
- # (-1, 1)
344
- blur_image = 2.0 * blur_image - 1
345
- blur_image = rearrange(blur_image, "(b f) c h w->b f c h w", f=video_length)
346
- return blur_image
347
-
348
-
349
- def video_compression(video_in, device="cpu"):
350
- # Shape: (t, c, h, w); channel order: RGB; image range: [0, 1], float32.
351
-
352
- video_in = torch.clamp(video_in, 0, 1)
353
- params = dict(
354
- codec=["libx264", "h264", "mpeg4"],
355
- codec_prob=[1 / 3.0, 1 / 3.0, 1 / 3.0],
356
- bitrate=[1e4, 1e5],
357
- ) # 1e4, 1e5
358
- codec = random.choices(params["codec"], params["codec_prob"])[0]
359
- # print(f"use codec {codec}")
360
-
361
- bitrate = params["bitrate"]
362
- bitrate = np.random.randint(bitrate[0], bitrate[1] + 1)
363
-
364
- h, w = video_in.shape[-2:]
365
- video_in = F.interpolate(video_in, (h // 2 * 2, w // 2 * 2), mode="bilinear")
366
-
367
- buf = io.BytesIO()
368
- with av.open(buf, "w", "mp4") as container:
369
- stream = container.add_stream(codec, rate=1)
370
- stream.height = video_in.shape[-2]
371
- stream.width = video_in.shape[-1]
372
- stream.pix_fmt = "yuv420p"
373
- stream.bit_rate = bitrate
374
-
375
- for img in video_in: # img: C H W; 0-1
376
- img_np = img.permute(1, 2, 0).contiguous() * 255.0
377
- # 1 reference_np = reference.detach(). to (torch.float) .cpu() .numpy ()
378
- img_np = img_np.detach().to(torch.float).cpu().numpy().astype(np.uint8)
379
- frame = av.VideoFrame.from_ndarray(img_np, format="rgb24")
380
- frame.pict_type = "NONE"
381
- for packet in stream.encode(frame):
382
- container.mux(packet)
383
-
384
- # Flush stream
385
- for packet in stream.encode():
386
- container.mux(packet)
387
-
388
- outputs = []
389
- with av.open(buf, "r", "mp4") as container:
390
- if container.streams.video:
391
- for frame in container.decode(**{"video": 0}):
392
- outputs.append(frame.to_rgb().to_ndarray().astype(np.float32))
393
-
394
- video_in = torch.Tensor(np.array(outputs)).permute(0, 3, 1, 2).contiguous() # T C H W
395
- video_in = torch.clamp(video_in / 255.0, 0, 1).to(device) # 0-1
396
- return video_in
397
-
398
-
399
- @torch.no_grad()
400
- def my_esr_blur(images, device="cpu"):
401
- """
402
- images is a list of tensor with shape: b f c h w, range (-1, 1)
403
- """
404
- jpeger = DiffJPEG(differentiable=False).to(device)
405
- usm_sharpener = USMSharp()
406
- if degrade_dic["sharpen"]:
407
- usm_sharpener = usm_sharpener.to(device)
408
- paras = set_para(para_dic)
409
- blur_image = [
410
- esr_blur_gpu(image, paras, usm_sharpener, jpeger, device=device) for image in images
411
- ]
412
-
413
- return blur_image
414
-
415
-
416
- para_dic_latent = {
417
- "kernel_list": [
418
- "iso",
419
- "aniso",
420
- "generalized_iso",
421
- "generalized_aniso",
422
- "plateau_iso",
423
- "plateau_aniso",
424
- ],
425
- "kernel_prob": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
426
- "sinc_prob": 0.1,
427
- "blur_sigma": [0.2, 1.5],
428
- "betag_range": [0.5, 2.0],
429
- "betap_range": [1, 1.5],
430
- "kernel_list2": [
431
- "iso",
432
- "aniso",
433
- "generalized_iso",
434
- "generalized_aniso",
435
- "plateau_iso",
436
- "plateau_aniso",
437
- ],
438
- "kernel_prob2": [0.45, 0.25, 0.12, 0.03, 0.12, 0.03],
439
- "sinc_prob2": 0.1,
440
- "blur_sigma2": [0.2, 1.0],
441
- "betag_range2": [0.5, 2.0],
442
- "betap_range2": [1, 1.5],
443
- "final_sinc_prob": 0.5,
444
- }
445
-
446
-
447
- def set_para_latent(para_dic):
448
- # blur settings for the first degradation
449
- # blur_kernel_size = opt['blur_kernel_size']
450
- kernel_list = para_dic["kernel_list"]
451
- kernel_prob = para_dic["kernel_prob"]
452
- blur_sigma = para_dic["blur_sigma"]
453
- betag_range = para_dic["betag_range"]
454
- betap_range = para_dic["betap_range"]
455
- sinc_prob = para_dic["sinc_prob"]
456
-
457
- # a final sinc filter
458
-
459
- kernel_range = [2 * v + 1 for v in range(1, 11)] # kernel size ranges from 7 to 21
460
- pulse_tensor = torch.zeros(
461
- 21, 21
462
- ).float() # convolving with pulse tensor brings no blurry effect
463
- pulse_tensor[10, 10] = 1
464
- kernel_size = random.choice(kernel_range)
465
- if np.random.uniform() < sinc_prob:
466
- # this sinc filter setting is for kernels ranging from [7, 21]
467
- if kernel_size < 13:
468
- omega_c = np.random.uniform(np.pi / 3, np.pi)
469
- else:
470
- omega_c = np.random.uniform(np.pi / 5, np.pi)
471
- kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False)
472
- else:
473
- kernel = random_mixed_kernels(
474
- kernel_list,
475
- kernel_prob,
476
- kernel_size,
477
- blur_sigma,
478
- blur_sigma,
479
- [-math.pi, math.pi],
480
- betag_range,
481
- betap_range,
482
- noise_range=None,
483
- )
484
- # pad kernel
485
- pad_size = (21 - kernel_size) // 2
486
- kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
487
- kernel = torch.FloatTensor(kernel)
488
- return_d = {"kernel1": kernel}
489
- return return_d
490
-
491
-
492
- @torch.no_grad()
493
- def latent_blur_gpu(image, paras, device="cpu"):
494
- """
495
- input and output: image is a tensor with shape: b f c h w, range (-1, 1)
496
- """
497
- video_length = image.shape[1]
498
- image = rearrange(image, "b f c h w -> (b f) c h w").to(device)
499
- image = (image + 1) * 0.5
500
- gt_usm = image
501
- ori_h, ori_w = image.size()[2:4]
502
- # ----------------------- The first degradation process ----------------------- #
503
- # blur
504
- out = filter2D(gt_usm, paras["kernel1"].unsqueeze(0).to(device))
505
- blur_image = torch.clamp(out, 0, 1)
506
- # blur_image = ColorJitter(0.1, 0.1, 0.1, 0.05)(blur_image) # 颜色数据增广
507
- # (-1, 1)
508
- blur_image = 2.0 * blur_image - 1
509
- blur_image = rearrange(blur_image, "(b f) c h w->b f c h w", f=video_length)
510
- return blur_image
511
-
512
-
513
- @torch.no_grad()
514
- def add_latent_blur(images, device="cpu"):
515
- """
516
- images is a list of tensor with shape: b f c h w, range (-1, 1)
517
- """
518
- paras = set_para_latent(para_dic_latent)
519
- blur_image = [latent_blur_gpu(image, paras, device=device) for image in images]
520
- print("apply blur to the latents")
521
-
522
- return blur_image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
projects/video_diffusion_sr/infer.py CHANGED
@@ -37,31 +37,11 @@ from common.distributed.meta_init_utils import (
37
  # from common.fs import download
38
 
39
  from models.dit_v2 import na
40
- from projects.video_diffusion_sr.degradation_utils import my_esr_blur
41
-
42
 
43
  class VideoDiffusionInfer():
44
  def __init__(self, config: DictConfig):
45
  self.config = config
46
 
47
- @log_on_entry
48
- def configure_blur(self):
49
- # Create degradation.
50
- def _blur_fn(x: List[torch.Tensor]):
51
- if x[0].ndim == 4:
52
- x = my_esr_blur(
53
- [rearrange(i, "c f h w -> 1 f c h w") for i in x], device=get_device()
54
- )
55
- x = [rearrange(i, "1 f c h w -> c f h w") for i in x]
56
- else:
57
- x = my_esr_blur(
58
- [rearrange(i, "c h w -> 1 1 c h w") for i in x], device=get_device()
59
- )
60
- x = [i[0, 0] for i in x]
61
- return x
62
-
63
- self.my_esr_blur = _blur_fn
64
-
65
  def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor:
66
  t, h, w, c = latent.shape
67
  cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)
 
37
  # from common.fs import download
38
 
39
  from models.dit_v2 import na
 
 
40
 
41
  class VideoDiffusionInfer():
42
  def __init__(self, config: DictConfig):
43
  self.config = config
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def get_condition(self, latent: Tensor, latent_blur: Tensor, task: str) -> Tensor:
46
  t, h, w, c = latent.shape
47
  cond = torch.zeros([t, h, w, c + 1], device=latent.device, dtype=latent.dtype)