Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	
		Rex Cheng
		
	committed on
		
		
					Commit 
							
							·
						
						b0ec3f5
	
1
								Parent(s):
							
							164c335
								
test
Browse files
 - app.py +5 -6
 - demo.py +9 -9
 - mmaudio/eval_utils.py +20 -58
 
    	
        app.py
    CHANGED
    
    | 
         @@ -67,7 +67,10 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int 
     | 
|
| 67 | 
         
             
                rng.manual_seed(seed)
         
     | 
| 68 | 
         
             
                fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
         
     | 
| 69 | 
         | 
| 70 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 71 | 
         
             
                clip_frames = clip_frames.unsqueeze(0)
         
     | 
| 72 | 
         
             
                sync_frames = sync_frames.unsqueeze(0)
         
     | 
| 73 | 
         
             
                seq_cfg.duration = duration
         
     | 
| 
         @@ -87,11 +90,7 @@ def video_to_audio(video: gr.Video, prompt: str, negative_prompt: str, seed: int 
     | 
|
| 87 | 
         
             
                video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
         
     | 
| 88 | 
         
             
                # output_dir.mkdir(exist_ok=True, parents=True)
         
     | 
| 89 | 
         
             
                # video_save_path = output_dir / f'{current_time_string}.mp4'
         
     | 
| 90 | 
         
            -
                make_video( 
     | 
| 91 | 
         
            -
                           video_save_path,
         
     | 
| 92 | 
         
            -
                           audio,
         
     | 
| 93 | 
         
            -
                           sampling_rate=seq_cfg.sampling_rate,
         
     | 
| 94 | 
         
            -
                           duration_sec=seq_cfg.duration)
         
     | 
| 95 | 
         
             
                log.info(f'Saved video to {video_save_path}')
         
     | 
| 96 | 
         
             
                return video_save_path
         
     | 
| 97 | 
         | 
| 
         | 
|
| 67 | 
         
             
                rng.manual_seed(seed)
         
     | 
| 68 | 
         
             
                fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
         
     | 
| 69 | 
         | 
| 70 | 
         
            +
                video_info = load_video(video, duration)
         
     | 
| 71 | 
         
            +
                clip_frames = video_info.clip_frames
         
     | 
| 72 | 
         
            +
                sync_frames = video_info.sync_frames
         
     | 
| 73 | 
         
            +
                duration = video_info.duration_sec
         
     | 
| 74 | 
         
             
                clip_frames = clip_frames.unsqueeze(0)
         
     | 
| 75 | 
         
             
                sync_frames = sync_frames.unsqueeze(0)
         
     | 
| 76 | 
         
             
                seq_cfg.duration = duration
         
     | 
| 
         | 
|
| 90 | 
         
             
                video_save_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
         
     | 
| 91 | 
         
             
                # output_dir.mkdir(exist_ok=True, parents=True)
         
     | 
| 92 | 
         
             
                # video_save_path = output_dir / f'{current_time_string}.mp4'
         
     | 
| 93 | 
         
            +
                make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 94 | 
         
             
                log.info(f'Saved video to {video_save_path}')
         
     | 
| 95 | 
         
             
                return video_save_path
         
     | 
| 96 | 
         | 
    	
        demo.py
    CHANGED
    
    | 
         @@ -5,8 +5,8 @@ from pathlib import Path 
     | 
|
| 5 | 
         
             
            import torch
         
     | 
| 6 | 
         
             
            import torchaudio
         
     | 
| 7 | 
         | 
| 8 | 
         
            -
            from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate,
         
     | 
| 9 | 
         
            -
                                             
     | 
| 10 | 
         
             
            from mmaudio.model.flow_matching import FlowMatching
         
     | 
| 11 | 
         
             
            from mmaudio.model.networks import MMAudio, get_my_mmaudio
         
     | 
| 12 | 
         
             
            from mmaudio.model.utils.features_utils import FeaturesUtils
         
     | 
| 
         @@ -81,12 +81,16 @@ def main(): 
     | 
|
| 81 | 
         
             
                                              synchformer_ckpt=model.synchformer_ckpt,
         
     | 
| 82 | 
         
             
                                              enable_conditions=True,
         
     | 
| 83 | 
         
             
                                              mode=model.mode,
         
     | 
| 84 | 
         
            -
                                              bigvgan_vocoder_ckpt=model.bigvgan_16k_path 
     | 
| 
         | 
|
| 85 | 
         
             
                feature_utils = feature_utils.to(device, dtype).eval()
         
     | 
| 86 | 
         | 
| 87 | 
         
             
                if video_path is not None:
         
     | 
| 88 | 
         
             
                    log.info(f'Using video {video_path}')
         
     | 
| 89 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 90 | 
         
             
                    if mask_away_clip:
         
     | 
| 91 | 
         
             
                        clip_frames = None
         
     | 
| 92 | 
         
             
                    else:
         
     | 
| 
         @@ -121,11 +125,7 @@ def main(): 
     | 
|
| 121 | 
         
             
                log.info(f'Audio saved to {save_path}')
         
     | 
| 122 | 
         
             
                if video_path is not None and not skip_video_composite:
         
     | 
| 123 | 
         
             
                    video_save_path = output_dir / f'{video_path.stem}.mp4'
         
     | 
| 124 | 
         
            -
                    make_video( 
     | 
| 125 | 
         
            -
                               video_save_path,
         
     | 
| 126 | 
         
            -
                               audio,
         
     | 
| 127 | 
         
            -
                               sampling_rate=seq_cfg.sampling_rate,
         
     | 
| 128 | 
         
            -
                               duration_sec=seq_cfg.duration)
         
     | 
| 129 | 
         
             
                    log.info(f'Video saved to {output_dir / video_save_path}')
         
     | 
| 130 | 
         | 
| 131 | 
         
             
                log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
         
     | 
| 
         | 
|
| 5 | 
         
             
            import torch
         
     | 
| 6 | 
         
             
            import torchaudio
         
     | 
| 7 | 
         | 
| 8 | 
         
            +
            from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
         
     | 
| 9 | 
         
            +
                                            setup_eval_logging)
         
     | 
| 10 | 
         
             
            from mmaudio.model.flow_matching import FlowMatching
         
     | 
| 11 | 
         
             
            from mmaudio.model.networks import MMAudio, get_my_mmaudio
         
     | 
| 12 | 
         
             
            from mmaudio.model.utils.features_utils import FeaturesUtils
         
     | 
| 
         | 
|
| 81 | 
         
             
                                              synchformer_ckpt=model.synchformer_ckpt,
         
     | 
| 82 | 
         
             
                                              enable_conditions=True,
         
     | 
| 83 | 
         
             
                                              mode=model.mode,
         
     | 
| 84 | 
         
            +
                                              bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
         
     | 
| 85 | 
         
            +
                                              need_vae_encoder=False)
         
     | 
| 86 | 
         
             
                feature_utils = feature_utils.to(device, dtype).eval()
         
     | 
| 87 | 
         | 
| 88 | 
         
             
                if video_path is not None:
         
     | 
| 89 | 
         
             
                    log.info(f'Using video {video_path}')
         
     | 
| 90 | 
         
            +
                    video_info = load_video(video_path, duration)
         
     | 
| 91 | 
         
            +
                    clip_frames = video_info.clip_frames
         
     | 
| 92 | 
         
            +
                    sync_frames = video_info.sync_frames
         
     | 
| 93 | 
         
            +
                    duration = video_info.duration_sec
         
     | 
| 94 | 
         
             
                    if mask_away_clip:
         
     | 
| 95 | 
         
             
                        clip_frames = None
         
     | 
| 96 | 
         
             
                    else:
         
     | 
| 
         | 
|
| 125 | 
         
             
                log.info(f'Audio saved to {save_path}')
         
     | 
| 126 | 
         
             
                if video_path is not None and not skip_video_composite:
         
     | 
| 127 | 
         
             
                    video_save_path = output_dir / f'{video_path.stem}.mp4'
         
     | 
| 128 | 
         
            +
                    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 129 | 
         
             
                    log.info(f'Video saved to {output_dir / video_save_path}')
         
     | 
| 130 | 
         | 
| 131 | 
         
             
                log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
         
     | 
    	
        mmaudio/eval_utils.py
    CHANGED
    
    | 
         @@ -3,12 +3,11 @@ import logging 
     | 
|
| 3 | 
         
             
            from pathlib import Path
         
     | 
| 4 | 
         
             
            from typing import Optional
         
     | 
| 5 | 
         | 
| 6 | 
         
            -
            import av
         
     | 
| 7 | 
         
             
            import torch
         
     | 
| 8 | 
         
             
            from colorlog import ColoredFormatter
         
     | 
| 9 | 
         
             
            from torchvision.transforms import v2
         
     | 
| 10 | 
         
            -
            from torio.io import StreamingMediaDecoder, StreamingMediaEncoder
         
     | 
| 11 | 
         | 
| 
         | 
|
| 12 | 
         
             
            from mmaudio.model.flow_matching import FlowMatching
         
     | 
| 13 | 
         
             
            from mmaudio.model.networks import MMAudio
         
     | 
| 14 | 
         
             
            from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
         
     | 
| 
         @@ -154,7 +153,7 @@ def setup_eval_logging(log_level: int = logging.INFO): 
     | 
|
| 154 | 
         
             
                log.addHandler(stream)
         
     | 
| 155 | 
         | 
| 156 | 
         | 
| 157 | 
         
            -
            def load_video(video_path: Path, duration_sec: float) ->  
     | 
| 158 | 
         
             
                _CLIP_SIZE = 384
         
     | 
| 159 | 
         
             
                _CLIP_FPS = 8.0
         
     | 
| 160 | 
         | 
| 
         @@ -175,26 +174,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor 
     | 
|
| 175 | 
         
             
                    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         
     | 
| 176 | 
         
             
                ])
         
     | 
| 177 | 
         | 
| 178 | 
         
            -
                 
     | 
| 179 | 
         
            -
             
     | 
| 180 | 
         
            -
             
     | 
| 181 | 
         
            -
             
     | 
| 182 | 
         
            -
             
     | 
| 183 | 
         
            -
                    format='rgb24',
         
     | 
| 184 | 
         
            -
                )
         
     | 
| 185 | 
         
            -
                reader.add_basic_video_stream(
         
     | 
| 186 | 
         
            -
                    frames_per_chunk=int(_SYNC_FPS * duration_sec),
         
     | 
| 187 | 
         
            -
                    buffer_chunk_size=-1,
         
     | 
| 188 | 
         
            -
                    frame_rate=_SYNC_FPS,
         
     | 
| 189 | 
         
            -
                    format='rgb24',
         
     | 
| 190 | 
         
            -
                )
         
     | 
| 191 | 
         | 
| 192 | 
         
            -
                 
     | 
| 193 | 
         
            -
                 
     | 
| 194 | 
         
            -
                 
     | 
| 195 | 
         
            -
                sync_chunk = data_chunk[1]
         
     | 
| 196 | 
         
            -
                assert clip_chunk is not None
         
     | 
| 197 | 
         
            -
                assert sync_chunk is not None
         
     | 
| 198 | 
         | 
| 199 | 
         
             
                clip_frames = clip_transform(clip_chunk)
         
     | 
| 200 | 
         
             
                sync_frames = sync_transform(sync_chunk)
         
     | 
| 
         @@ -215,41 +203,15 @@ def load_video(video_path: Path, duration_sec: float) -> tuple[torch.Tensor, tor 
     | 
|
| 215 | 
         
             
                clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
         
     | 
| 216 | 
         
             
                sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
         
     | 
| 217 | 
         | 
| 218 | 
         
            -
                 
     | 
| 219 | 
         
            -
             
     | 
| 220 | 
         
            -
             
     | 
| 221 | 
         
            -
             
     | 
| 222 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 223 | 
         | 
| 224 | 
         
            -
                av_video = av.open(video_path)
         
     | 
| 225 | 
         
            -
                frame_rate = av_video.streams.video[0].guessed_rate
         
     | 
| 226 | 
         | 
| 227 | 
         
            -
             
     | 
| 228 | 
         
            -
                 
     | 
| 229 | 
         
            -
                reader.add_basic_video_stream(
         
     | 
| 230 | 
         
            -
                    frames_per_chunk=approx_max_length,
         
     | 
| 231 | 
         
            -
                    buffer_chunk_size=-1,
         
     | 
| 232 | 
         
            -
                    format='rgb24',
         
     | 
| 233 | 
         
            -
                )
         
     | 
| 234 | 
         
            -
                reader.fill_buffer()
         
     | 
| 235 | 
         
            -
                video_chunk = reader.pop_chunks()[0]
         
     | 
| 236 | 
         
            -
                assert video_chunk is not None
         
     | 
| 237 | 
         
            -
             
     | 
| 238 | 
         
            -
                h, w = video_chunk.shape[-2:]
         
     | 
| 239 | 
         
            -
                video_chunk = video_chunk[:int(frame_rate * duration_sec)]
         
     | 
| 240 | 
         
            -
             
     | 
| 241 | 
         
            -
                writer = StreamingMediaEncoder(output_path)
         
     | 
| 242 | 
         
            -
                writer.add_audio_stream(
         
     | 
| 243 | 
         
            -
                    sample_rate=sampling_rate,
         
     | 
| 244 | 
         
            -
                    num_channels=audio.shape[0],
         
     | 
| 245 | 
         
            -
                    encoder='aac',  # 'flac' does not work for some reason?
         
     | 
| 246 | 
         
            -
                )
         
     | 
| 247 | 
         
            -
                writer.add_video_stream(frame_rate=frame_rate,
         
     | 
| 248 | 
         
            -
                                        width=w,
         
     | 
| 249 | 
         
            -
                                        height=h,
         
     | 
| 250 | 
         
            -
                                        format='rgb24',
         
     | 
| 251 | 
         
            -
                                        encoder='libx264',
         
     | 
| 252 | 
         
            -
                                        encoder_format='yuv420p')
         
     | 
| 253 | 
         
            -
                with writer.open():
         
     | 
| 254 | 
         
            -
                    writer.write_audio_chunk(0, audio.float().transpose(0, 1))
         
     | 
| 255 | 
         
            -
                    writer.write_video_chunk(1, video_chunk)
         
     | 
| 
         | 
|
| 3 | 
         
             
            from pathlib import Path
         
     | 
| 4 | 
         
             
            from typing import Optional
         
     | 
| 5 | 
         | 
| 
         | 
|
| 6 | 
         
             
            import torch
         
     | 
| 7 | 
         
             
            from colorlog import ColoredFormatter
         
     | 
| 8 | 
         
             
            from torchvision.transforms import v2
         
     | 
| 
         | 
|
| 9 | 
         | 
| 10 | 
         
            +
            from mmaudio.data.av_utils import VideoInfo, read_frames, reencode_with_audio
         
     | 
| 11 | 
         
             
            from mmaudio.model.flow_matching import FlowMatching
         
     | 
| 12 | 
         
             
            from mmaudio.model.networks import MMAudio
         
     | 
| 13 | 
         
             
            from mmaudio.model.sequence_config import (CONFIG_16K, CONFIG_44K, SequenceConfig)
         
     | 
| 
         | 
|
| 153 | 
         
             
                log.addHandler(stream)
         
     | 
| 154 | 
         | 
| 155 | 
         | 
| 156 | 
         
            +
            def load_video(video_path: Path, duration_sec: float, load_all_frames: bool = True) -> VideoInfo:
         
     | 
| 157 | 
         
             
                _CLIP_SIZE = 384
         
     | 
| 158 | 
         
             
                _CLIP_FPS = 8.0
         
     | 
| 159 | 
         | 
| 
         | 
|
| 174 | 
         
             
                    v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
         
     | 
| 175 | 
         
             
                ])
         
     | 
| 176 | 
         | 
| 177 | 
         
            +
                output_frames, all_frames, orig_fps = read_frames(video_path,
         
     | 
| 178 | 
         
            +
                                                                  list_of_fps=[_CLIP_FPS, _SYNC_FPS],
         
     | 
| 179 | 
         
            +
                                                                  start_sec=0,
         
     | 
| 180 | 
         
            +
                                                                  end_sec=duration_sec,
         
     | 
| 181 | 
         
            +
                                                                  need_all_frames=load_all_frames)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 182 | 
         | 
| 183 | 
         
            +
                clip_chunk, sync_chunk = output_frames
         
     | 
| 184 | 
         
            +
                clip_chunk = torch.from_numpy(clip_chunk).permute(0, 3, 1, 2)
         
     | 
| 185 | 
         
            +
                sync_chunk = torch.from_numpy(sync_chunk).permute(0, 3, 1, 2)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 186 | 
         | 
| 187 | 
         
             
                clip_frames = clip_transform(clip_chunk)
         
     | 
| 188 | 
         
             
                sync_frames = sync_transform(sync_chunk)
         
     | 
| 
         | 
|
| 203 | 
         
             
                clip_frames = clip_frames[:int(_CLIP_FPS * duration_sec)]
         
     | 
| 204 | 
         
             
                sync_frames = sync_frames[:int(_SYNC_FPS * duration_sec)]
         
     | 
| 205 | 
         | 
| 206 | 
         
            +
                video_info = VideoInfo(
         
     | 
| 207 | 
         
            +
                    duration_sec=duration_sec,
         
     | 
| 208 | 
         
            +
                    fps=orig_fps,
         
     | 
| 209 | 
         
            +
                    clip_frames=clip_frames,
         
     | 
| 210 | 
         
            +
                    sync_frames=sync_frames,
         
     | 
| 211 | 
         
            +
                    all_frames=all_frames if load_all_frames else None,
         
     | 
| 212 | 
         
            +
                )
         
     | 
| 213 | 
         
            +
                return video_info
         
     | 
| 214 | 
         | 
| 
         | 
|
| 
         | 
|
| 215 | 
         | 
| 216 | 
         
            +
            def make_video(video_info: VideoInfo, output_path: Path, audio: torch.Tensor, sampling_rate: int):
         
     | 
| 217 | 
         
            +
                reencode_with_audio(video_info, output_path, audio, sampling_rate)
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         |