IceClear committed on
Commit
1fd3071
·
1 Parent(s): c9102eb
app.py CHANGED
@@ -139,9 +139,9 @@ torch.hub.download_url_to_file(
139
  'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
140
  '03.mp4')
141
 
142
- # def configure_sequence_parallel(sp_size):
143
- # if sp_size > 1:
144
- # init_sequence_parallel(sp_size)
145
 
146
  @spaces.GPU(duration=120)
147
  def configure_runner(sp_size):
@@ -150,8 +150,8 @@ def configure_runner(sp_size):
150
  runner = VideoDiffusionInfer(config)
151
  OmegaConf.set_readonly(runner.config, False)
152
 
153
- # init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
154
- # configure_sequence_parallel(sp_size)
155
  runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
156
  runner.configure_vae_model()
157
  # Set memory limit.
 
139
  'https://huggingface.co/datasets/Iceclear/SeedVR_VideoDemos/resolve/main/seedvr_videos_crf23/aigc1k/2_1_lq.mp4',
140
  '03.mp4')
141
 
142
+ def configure_sequence_parallel(sp_size):
143
+ if sp_size > 1:
144
+ init_sequence_parallel(sp_size)
145
 
146
  @spaces.GPU(duration=120)
147
  def configure_runner(sp_size):
 
150
  runner = VideoDiffusionInfer(config)
151
  OmegaConf.set_readonly(runner.config, False)
152
 
153
+ init_torch(cudnn_benchmark=False, timeout=datetime.timedelta(seconds=3600))
154
+ configure_sequence_parallel(sp_size)
155
  runner.configure_dit_model(device="cuda", checkpoint='./ckpts/seedvr2_ema_3b.pth')
156
  runner.configure_vae_model()
157
  # Set memory limit.
common/distributed/basic.py CHANGED
@@ -66,11 +66,11 @@ def init_torch(cudnn_benchmark=True, timeout=timedelta(seconds=600)):
66
  torch.backends.cuda.matmul.allow_tf32 = True
67
  torch.backends.cudnn.allow_tf32 = True
68
  torch.backends.cudnn.benchmark = cudnn_benchmark
69
- torch.cuda.set_device(get_local_rank())
70
  dist.init_process_group(
71
  backend="nccl",
72
- rank=get_global_rank(),
73
- world_size=get_world_size(),
74
  timeout=timeout,
75
  )
76
 
 
66
  torch.backends.cuda.matmul.allow_tf32 = True
67
  torch.backends.cudnn.allow_tf32 = True
68
  torch.backends.cudnn.benchmark = cudnn_benchmark
69
+ torch.cuda.set_device(0)
70
  dist.init_process_group(
71
  backend="nccl",
72
+ rank=0,
73
+ world_size=1,
74
  timeout=timeout,
75
  )
76
 
projects/video_diffusion_sr/infer.py CHANGED
@@ -26,14 +26,14 @@ from common.diffusion import (
26
  create_sampling_timesteps_from_config,
27
  create_schedule_from_config,
28
  )
29
- # from common.distributed import (
30
- # get_device,
31
- # get_global_rank,
32
- # )
33
-
34
- # from common.distributed.meta_init_utils import (
35
- # meta_non_persistent_buffer_init_fn,
36
- # )
37
  # from common.fs import download
38
 
39
  from models.dit_v2 import na
@@ -68,20 +68,20 @@ class VideoDiffusionInfer():
68
  return cond
69
  raise NotImplementedError
70
 
71
- # @log_on_entry
72
- # @log_runtime
73
  def configure_dit_model(self, device="cpu", checkpoint=None):
74
  # Load dit checkpoint.
75
  # For fast init & resume,
76
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
77
  # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
78
- # if self.config.dit.get("init_with_meta_device", False):
79
- # init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
80
- # else:
81
- # init_device = "cpu"
82
 
83
  # Create dit model.
84
- with torch.device("cpu"):
85
  self.dit = create_object(self.config.dit.model)
86
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
87
 
@@ -90,27 +90,27 @@ class VideoDiffusionInfer():
90
  loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
91
  print(f"Loading pretrained ckpt from {checkpoint}")
92
  print(f"Loading info: {loading_info}")
93
- # self.dit = meta_non_persistent_buffer_init_fn(self.dit)
94
 
95
- # if device in [get_device(), "cuda"]:
96
- self.dit.to("cuda")
97
 
98
  # Print model size.
99
  num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
100
  print(f"DiT trainable parameters: {num_params:,}")
101
 
102
- # @log_on_entry
103
- # @log_runtime
104
  def configure_vae_model(self):
105
  # Create vae model.
106
  dtype = getattr(torch, self.config.vae.dtype)
107
  self.vae = create_object(self.config.vae.model)
108
  self.vae.requires_grad_(False).eval()
109
- self.vae.to(device="cuda", dtype=dtype)
110
 
111
  # Load vae checkpoint.
112
  state = torch.load(
113
- self.config.vae.checkpoint, map_location="cuda", mmap=True
114
  )
115
  self.vae.load_state_dict(state)
116
 
@@ -123,12 +123,12 @@ class VideoDiffusionInfer():
123
  def configure_diffusion(self):
124
  self.schedule = create_schedule_from_config(
125
  config=self.config.diffusion.schedule,
126
- device="cuda",
127
  )
128
  self.sampling_timesteps = create_sampling_timesteps_from_config(
129
  config=self.config.diffusion.timesteps.sampling,
130
  schedule=self.schedule,
131
- device="cuda",
132
  )
133
  self.sampler = create_sampler_from_config(
134
  config=self.config.diffusion.sampler,
@@ -143,7 +143,7 @@ class VideoDiffusionInfer():
143
  use_sample = self.config.vae.get("use_sample", True)
144
  latents = []
145
  if len(samples) > 0:
146
- device = "cuda"
147
  dtype = getattr(torch, self.config.vae.dtype)
148
  scale = self.config.vae.scaling_factor
149
  shift = self.config.vae.get("shifting_factor", 0.0)
@@ -186,7 +186,7 @@ class VideoDiffusionInfer():
186
  def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
187
  samples = []
188
  if len(latents) > 0:
189
- device = "cuda"
190
  dtype = getattr(torch, self.config.vae.dtype)
191
  scale = self.config.vae.scaling_factor
192
  shift = self.config.vae.get("shifting_factor", 0.0)
@@ -340,9 +340,9 @@ class VideoDiffusionInfer():
340
  self.dit.to("cpu")
341
 
342
  # Vae decode.
343
- self.vae.to("cuda")
344
  samples = self.vae_decode(latents)
345
 
346
  if dit_offload:
347
- self.dit.to("cuda")
348
  return samples
 
26
  create_sampling_timesteps_from_config,
27
  create_schedule_from_config,
28
  )
29
+ from common.distributed import (
30
+ get_device,
31
+ get_global_rank,
32
+ )
33
+
34
+ from common.distributed.meta_init_utils import (
35
+ meta_non_persistent_buffer_init_fn,
36
+ )
37
  # from common.fs import download
38
 
39
  from models.dit_v2 import na
 
68
  return cond
69
  raise NotImplementedError
70
 
71
+ @log_on_entry
72
+ @log_runtime
73
  def configure_dit_model(self, device="cpu", checkpoint=None):
74
  # Load dit checkpoint.
75
  # For fast init & resume,
76
  # when training from scratch, rank0 init DiT on cpu, then sync to other ranks with FSDP.
77
  # otherwise, all ranks init DiT on meta device, then load_state_dict with assign=True.
78
+ if self.config.dit.get("init_with_meta_device", False):
79
+ init_device = "cpu" if get_global_rank() == 0 and checkpoint is None else "meta"
80
+ else:
81
+ init_device = "cpu"
82
 
83
  # Create dit model.
84
+ with torch.device(init_device):
85
  self.dit = create_object(self.config.dit.model)
86
  self.dit.set_gradient_checkpointing(self.config.dit.gradient_checkpoint)
87
 
 
90
  loading_info = self.dit.load_state_dict(state, strict=True, assign=True)
91
  print(f"Loading pretrained ckpt from {checkpoint}")
92
  print(f"Loading info: {loading_info}")
93
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
94
 
95
+ if device in [get_device(), "cuda"]:
96
+ self.dit.to(get_device())
97
 
98
  # Print model size.
99
  num_params = sum(p.numel() for p in self.dit.parameters() if p.requires_grad)
100
  print(f"DiT trainable parameters: {num_params:,}")
101
 
102
+ @log_on_entry
103
+ @log_runtime
104
  def configure_vae_model(self):
105
  # Create vae model.
106
  dtype = getattr(torch, self.config.vae.dtype)
107
  self.vae = create_object(self.config.vae.model)
108
  self.vae.requires_grad_(False).eval()
109
+ self.vae.to(device=get_device(), dtype=dtype)
110
 
111
  # Load vae checkpoint.
112
  state = torch.load(
113
+ self.config.vae.checkpoint, map_location=get_device(), mmap=True
114
  )
115
  self.vae.load_state_dict(state)
116
 
 
123
  def configure_diffusion(self):
124
  self.schedule = create_schedule_from_config(
125
  config=self.config.diffusion.schedule,
126
+ device=get_device(),
127
  )
128
  self.sampling_timesteps = create_sampling_timesteps_from_config(
129
  config=self.config.diffusion.timesteps.sampling,
130
  schedule=self.schedule,
131
+ device=get_device(),
132
  )
133
  self.sampler = create_sampler_from_config(
134
  config=self.config.diffusion.sampler,
 
143
  use_sample = self.config.vae.get("use_sample", True)
144
  latents = []
145
  if len(samples) > 0:
146
+ device = get_device()
147
  dtype = getattr(torch, self.config.vae.dtype)
148
  scale = self.config.vae.scaling_factor
149
  shift = self.config.vae.get("shifting_factor", 0.0)
 
186
  def vae_decode(self, latents: List[Tensor]) -> List[Tensor]:
187
  samples = []
188
  if len(latents) > 0:
189
+ device = get_device()
190
  dtype = getattr(torch, self.config.vae.dtype)
191
  scale = self.config.vae.scaling_factor
192
  shift = self.config.vae.get("shifting_factor", 0.0)
 
340
  self.dit.to("cpu")
341
 
342
  # Vae decode.
343
+ self.vae.to(get_device())
344
  samples = self.vae_decode(latents)
345
 
346
  if dit_offload:
347
+ self.dit.to(get_device())
348
  return samples