Spaces:
Running
on
Zero
Running
on
Zero
IceClear
commited on
Commit
·
1fd3071
1
Parent(s):
c9102eb
update
Browse files- app.py +5 -5
- common/distributed/basic.py +3 -3
- projects/video_diffusion_sr/infer.py +28 -28
app.py
CHANGED
@@ -139,9 +139,9 @@ torch.hub.download_url_to_file(
|
|
139 |
'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
|
140 |
'03.mp4')
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
|
146 |
@spaces.GPU(duration=120)
|
147 |
def configure_runner(sp_size):
|
@@ -150,8 +150,8 @@ def configure_runner(sp_size):
|
|
150 |
runner = VideoDiffusionInfer(config)
|
151 |
OmegaConf.set_readonly(runner.config, False)
|
152 |
|
153 |
-
|
154 |
-
|
155 |
runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
|
156 |
runner.configure_vae_model()
|
157 |
# Set memory limit.
|
|
|
139 |
'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
|
140 |
'03.mp4')
|
141 |
|
142 |
+
def configure_sequence_parallel(sp_size):
|
143 |
+
if sp_size > 1:
|
144 |
+
init_sequence_parallel(sp_size)
|
145 |
|
146 |
@spaces.GPU(duration=120)
|
147 |
def configure_runner(sp_size):
|
|
|
150 |
runner = VideoDiffusionInfer(config)
|
151 |
OmegaConf.set_readonly(runner.config, False)
|
152 |
|
153 |
+
init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
|
154 |
+
configure_sequence_parallel(sp_size)
|
155 |
runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
|
156 |
runner.configure_vae_model()
|
157 |
# Set memory limit.
|
common/distributed/basic.py
CHANGED
@@ -66,11 +66,11 @@ def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
|
|
66 |
torch.backends.cuda.matmul.allow_tf32 = True
|
67 |
torch.backends.cudnn.allow_tf32 = True
|
68 |
torch.backends.cudnn.benchmark = cudnn_benchmark
|
69 |
-
torch.cuda.set_device(
|
70 |
dist.init_process_group(
|
71 |
backend="nccl",
|
72 |
-
rank=
|
73 |
-
world_size=
|
74 |
timeout=timeout,
|
75 |
)
|
76 |
|
|
|
66 |
torch.backends.cuda.matmul.allow_tf32 = True
|
67 |
torch.backends.cudnn.allow_tf32 = True
|
68 |
torch.backends.cudnn.benchmark = cudnn_benchmark
|
69 |
+
torch.cuda.set_device(0)
|
70 |
dist.init_process_group(
|
71 |
backend="nccl",
|
72 |
+
rank=0,
|
73 |
+
world_size=1,
|
74 |
timeout=timeout,
|
75 |
)
|
76 |
|
projects/video_diffusion_sr/infer.py
CHANGED
@@ -26,14 +26,14 @@ from common.diffusion import (
|
|
26 |
create_sampling_timesteps_from_config,
|
27 |
create_schedule_from_config,
|
28 |
)
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
# from common.fs import download
|
38 |
|
39 |
from models.dit_v2 import na
|
@@ -68,20 +68,20 @@ class VideoDiffusionInfer():
|
|
68 |
return cond
|
69 |
raise NotImplementedError
|
70 |
|
71 |
-
|
72 |
-
|
73 |
def configure_dit_model(self, device="cpu", checkpoint=None):
|
74 |
# Load dit checkpoint.
|
75 |
# For fast init & resume,
|
76 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
77 |
# otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
|
83 |
# Create dit model.
|
84 |
-
with torch.device(
|
85 |
self.dit = create_object(self.config.dit.model)
|
86 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
87 |
|
@@ -90,27 +90,27 @@ class VideoDiffusionInfer():
|
|
90 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
91 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
92 |
print(f"Loading info: {loading_info}")
|
93 |
-
|
94 |
|
95 |
-
|
96 |
-
|
97 |
|
98 |
# Print model size.
|
99 |
num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
|
100 |
print(f"DiT trainable parameters: {num_params:,}")
|
101 |
|
102 |
-
|
103 |
-
|
104 |
def configure_vae_model(self):
|
105 |
# Create vae model.
|
106 |
dtype = getattr(torch, self.config.vae.dtype)
|
107 |
self.vae = create_object(self.config.vae.model)
|
108 |
self.vae.requires_grad_(False).eval()
|
109 |
-
self.vae.to(device=
|
110 |
|
111 |
# Load vae checkpoint.
|
112 |
state = torch.load(
|
113 |
-
self.config.vae.checkpoint, map_location=
|
114 |
)
|
115 |
self.vae.load_state_dict(state)
|
116 |
|
@@ -123,12 +123,12 @@ class VideoDiffusionInfer():
|
|
123 |
def configure_diffusion(self):
|
124 |
self.schedule = create_schedule_from_config(
|
125 |
config=self.config.diffusion.schedule,
|
126 |
-
device=
|
127 |
)
|
128 |
self.sampling_timesteps = create_sampling_timesteps_from_config(
|
129 |
config=self.config.diffusion.timesteps.sampling,
|
130 |
schedule=self.schedule,
|
131 |
-
device=
|
132 |
)
|
133 |
self.sampler = create_sampler_from_config(
|
134 |
config=self.config.diffusion.sampler,
|
@@ -143,7 +143,7 @@ class VideoDiffusionInfer():
|
|
143 |
use_sample = self.config.vae.get("use_sample", True)
|
144 |
latents = []
|
145 |
if len(samples) > 0:
|
146 |
-
device =
|
147 |
dtype = getattr(torch, self.config.vae.dtype)
|
148 |
scale = self.config.vae.scaling_factor
|
149 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
@@ -186,7 +186,7 @@ class VideoDiffusionInfer():
|
|
186 |
def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
|
187 |
samples = []
|
188 |
if len(latents) > 0:
|
189 |
-
device =
|
190 |
dtype = getattr(torch, self.config.vae.dtype)
|
191 |
scale = self.config.vae.scaling_factor
|
192 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
@@ -340,9 +340,9 @@ class VideoDiffusionInfer():
|
|
340 |
self.dit.to("cpu")
|
341 |
|
342 |
# Vae decode.
|
343 |
-
self.vae.to(
|
344 |
samples = self.vae_decode(latents)
|
345 |
|
346 |
if dit_offload:
|
347 |
-
self.dit.to(
|
348 |
return samples
|
|
|
26 |
create_sampling_timesteps_from_config,
|
27 |
create_schedule_from_config,
|
28 |
)
|
29 |
+
from common.distributed import (
|
30 |
+
get_device,
|
31 |
+
get_global_rank,
|
32 |
+
)
|
33 |
+
|
34 |
+
from common.distributed.meta_init_utils import (
|
35 |
+
meta_non_persistent_buffer_init_fn,
|
36 |
+
)
|
37 |
# from common.fs import download
|
38 |
|
39 |
from models.dit_v2 import na
|
|
|
68 |
return cond
|
69 |
raise NotImplementedError
|
70 |
|
71 |
+
@log_on_entry
|
72 |
+
@log_runtime
|
73 |
def configure_dit_model(self, device="cpu", checkpoint=None):
|
74 |
# Load dit checkpoint.
|
75 |
# For fast init & resume,
|
76 |
# when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
|
77 |
# otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
|
78 |
+
if self.config.dit.get("init_with_meta_device", False):
|
79 |
+
init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
|
80 |
+
else:
|
81 |
+
init_device = "cpu"
|
82 |
|
83 |
# Create dit model.
|
84 |
+
with torch.device(init_device):
|
85 |
self.dit = create_object(self.config.dit.model)
|
86 |
self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
|
87 |
|
|
|
90 |
loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
|
91 |
print(f"Loading pretrained ckpt from {checkpoint}")
|
92 |
print(f"Loading info: {loading_info}")
|
93 |
+
self.dit = meta_non_persistent_buffer_init_fn(self.dit)
|
94 |
|
95 |
+
if device in [get_device(), "cuda"]:
|
96 |
+
self.dit.to(get_device())
|
97 |
|
98 |
# Print model size.
|
99 |
num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
|
100 |
print(f"DiT trainable parameters: {num_params:,}")
|
101 |
|
102 |
+
@log_on_entry
|
103 |
+
@log_runtime
|
104 |
def configure_vae_model(self):
|
105 |
# Create vae model.
|
106 |
dtype = getattr(torch, self.config.vae.dtype)
|
107 |
self.vae = create_object(self.config.vae.model)
|
108 |
self.vae.requires_grad_(False).eval()
|
109 |
+
self.vae.to(device=get_device(), dtype=dtype)
|
110 |
|
111 |
# Load vae checkpoint.
|
112 |
state = torch.load(
|
113 |
+
self.config.vae.checkpoint, map_location=get_device(), mmap=True
|
114 |
)
|
115 |
self.vae.load_state_dict(state)
|
116 |
|
|
|
123 |
def configure_diffusion(self):
|
124 |
self.schedule = create_schedule_from_config(
|
125 |
config=self.config.diffusion.schedule,
|
126 |
+
device=get_device(),
|
127 |
)
|
128 |
self.sampling_timesteps = create_sampling_timesteps_from_config(
|
129 |
config=self.config.diffusion.timesteps.sampling,
|
130 |
schedule=self.schedule,
|
131 |
+
device=get_device(),
|
132 |
)
|
133 |
self.sampler = create_sampler_from_config(
|
134 |
config=self.config.diffusion.sampler,
|
|
|
143 |
use_sample = self.config.vae.get("use_sample", True)
|
144 |
latents = []
|
145 |
if len(samples) > 0:
|
146 |
+
device = get_device()
|
147 |
dtype = getattr(torch, self.config.vae.dtype)
|
148 |
scale = self.config.vae.scaling_factor
|
149 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
|
186 |
def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
|
187 |
samples = []
|
188 |
if len(latents) > 0:
|
189 |
+
device = get_device()
|
190 |
dtype = getattr(torch, self.config.vae.dtype)
|
191 |
scale = self.config.vae.scaling_factor
|
192 |
shift = self.config.vae.get("shifting_factor", 0.0)
|
|
|
340 |
self.dit.to("cpu")
|
341 |
|
342 |
# Vae decode.
|
343 |
+
self.vae.to(get_device())
|
344 |
samples = self.vae_decode(latents)
|
345 |
|
346 |
if dit_offload:
|
347 |
+
self.dit.to(get_device())
|
348 |
return samples
|