import logging

# Quiet noisy HTTP-client loggers. This is done before the remaining imports,
# presumably so messages those libraries emit while importing are filtered too.
logging.getLogger("httpx").setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

import gc
from argparse import ArgumentParser
from datetime import datetime
from fractions import Fraction
from pathlib import Path
from typing import Optional

import gradio as gr
import torch
import torch.hub
import torchaudio

from mmaudio.eval_utils import (ModelConfig, VideoInfo, all_model_cfg, generate, load_image,
                                load_video, make_video, setup_eval_logging)
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.sequence_config import SequenceConfig
from mmaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

log = logging.getLogger()

# Prefer CUDA, then Apple MPS, and fall back to CPU with a warning.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    log.warning('CUDA/MPS are not available, running on CPU')

dtype = torch.float32

# Fine-tuned checkpoint and the network variant passed to get_my_mmaudio.
MY_CHECKPOINT_PATH = './nsfw_gold_8.5k_final.pth'
MY_MODEL_NAME = 'large_44k'

# Directory where the downloaded VAE / Synchformer weights are stored.
EXT_WEIGHTS_DIR = Path('./ext_weights')
EXT_WEIGHTS_DIR.mkdir(exist_ok=True)

VAE_URL = "https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth"
SYNCHFORMER_URL = "https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth"


def download_dependency(url: str, local_path: Path):
    """Download ``url`` to ``local_path`` unless that file already exists.

    Args:
        url: Source URL of the weights file.
        local_path: Destination path; if it exists, nothing is downloaded.
    """
    if not local_path.exists():
        log.info(f"Downloading dependency from {url} to {local_path}...")
        torch.hub.download_url_to_file(url, str(local_path), progress=True)
        log.info("Download complete.")


log.info("Checking for dependencies (VAE and Synchformer)...")
VAE_PATH = EXT_WEIGHTS_DIR / 'v1-44.pth'
SYNCHFORMER_PATH = EXT_WEIGHTS_DIR / 'synchformer_state_dict.pth'
download_dependency(VAE_URL, VAE_PATH)
download_dependency(SYNCHFORMER_URL, SYNCHFORMER_PATH)

# Preset whose sequence config and feature-extractor mode are used below.
model_cfg_for_params: ModelConfig = all_model_cfg['large_44k_v2']
output_dir = Path('./output/gradio')

setup_eval_logging()


def get_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
    """Build the network with the fine-tuned weights and the feature extractors.

    Returns:
        ``(net, feature_utils, seq_cfg)``. ``net`` and ``feature_utils`` are moved
        to ``device``/``dtype`` and put in eval mode; ``seq_cfg`` is the sequence
        config object of ``model_cfg_for_params`` (shared, not a copy).

    Raises:
        FileNotFoundError: If ``MY_CHECKPOINT_PATH`` does not exist.
    """
    # Check for the checkpoint before constructing the network, so a missing
    # file fails immediately instead of after the model has been built.
    if not Path(MY_CHECKPOINT_PATH).exists():
        raise FileNotFoundError(f"FATAL: Your model file was not found at {MY_CHECKPOINT_PATH}")

    seq_cfg = model_cfg_for_params.seq_cfg

    net: MMAudio = get_my_mmaudio(MY_MODEL_NAME).to(device, dtype).eval()
    log.info(f'Loading YOUR fine-tuned weights from {MY_CHECKPOINT_PATH}')
    net.load_weights(torch.load(MY_CHECKPOINT_PATH, map_location=device, weights_only=True))
    log.info('Successfully loaded your weights!')

    feature_utils = FeaturesUtils(tod_vae_ckpt=VAE_PATH,
                                  synchformer_ckpt=SYNCHFORMER_PATH,
                                  enable_conditions=True,
                                  mode=model_cfg_for_params.mode,
                                  bigvgan_vocoder_ckpt=None,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device, dtype).eval()
    return net, feature_utils, seq_cfg


net, feature_utils, seq_cfg = get_model()


def _make_rng(seed: Optional[int]) -> torch.Generator:
    """Return a generator on ``device`` seeded with ``seed``, or randomly seeded.

    A negative seed means "random". ``None`` (e.g. a cleared Seed field) is
    treated as random as well instead of failing on the ``>=`` comparison.
    """
    rng = torch.Generator(device=device)
    if seed is not None and seed >= 0:
        rng.manual_seed(int(seed))
    else:
        rng.seed()
    return rng


def _make_fm(num_steps: int) -> FlowMatching:
    """Return the Euler flow-matching sampler used by all three handlers."""
    return FlowMatching(min_sigma=0, inference_mode='euler', num_steps=int(num_steps))


def _set_duration(duration: float) -> None:
    """Set the shared ``seq_cfg`` duration and resize ``net``'s sequence lengths to match.

    NOTE(review): this mutates module-level state shared by every handler. If
    Gradio runs handlers from different tabs concurrently these updates could
    interleave; confirm the queue/concurrency settings before relying on it.
    """
    seq_cfg.duration = duration
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)


def _new_output_path(suffix: str) -> Path:
    """Create ``output_dir`` if needed and return a fresh timestamped path ending in ``suffix``."""
    output_dir.mkdir(exist_ok=True, parents=True)
    # Microseconds are included because whole-second stamps let two results
    # finishing in the same second share a name, overwriting the earlier file.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    return output_dir / f'{stamp}{suffix}'


@torch.inference_mode()
def video_to_audio(video: str, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    """Generate audio for a video and write the video with that audio to ``output_dir``.

    Args:
        video: Path of the uploaded video, as delivered by the ``gr.Video`` input.
        prompt / negative_prompt: Text conditioning and negative text conditioning.
        seed: RNG seed; negative (or ``None``) means random.
        num_steps: Number of Euler flow-matching steps.
        cfg_strength: Classifier-free guidance strength.
        duration: Requested duration in seconds, passed to ``load_video``.

    Returns:
        Path of the written ``.mp4`` file.
    """
    rng = _make_rng(seed)
    fm = _make_fm(num_steps)

    video_info = load_video(video, duration)
    # Add a leading batch dimension of size 1.
    clip_frames = video_info.clip_frames.unsqueeze(0)
    sync_frames = video_info.sync_frames.unsqueeze(0)
    # Size the model from the duration reported by load_video, not the request.
    _set_duration(video_info.duration_sec)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    video_save_path = _new_output_path('.mp4')
    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return video_save_path


@torch.inference_mode()
def image_to_audio(image: str, prompt: str, negative_prompt: str, seed: int, num_steps: int,
                   cfg_strength: float, duration: float):
    """Generate audio for a still image and write it as a video to ``output_dir``.

    Args:
        image: Path of the uploaded image (the ``gr.Image`` input uses ``type='filepath'``).
        prompt / negative_prompt: Text conditioning and negative text conditioning.
        seed: RNG seed; negative (or ``None``) means random.
        num_steps: Number of Euler flow-matching steps.
        cfg_strength: Classifier-free guidance strength.
        duration: Output duration in seconds.

    Returns:
        Path of the written ``.mp4`` file.
    """
    rng = _make_rng(seed)
    fm = _make_fm(num_steps)

    image_info = load_image(image)
    # Add a leading batch dimension of size 1.
    clip_frames = image_info.clip_frames.unsqueeze(0)
    sync_frames = image_info.sync_frames.unsqueeze(0)
    _set_duration(duration)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength,
                      image_input=True)
    audio = audios.float().cpu()[0]

    video_save_path = _new_output_path('.mp4')
    # Build a VideoInfo from the image (requested duration, 1 fps) for make_video.
    video_info = VideoInfo.from_image_info(image_info, duration, fps=Fraction(1))
    make_video(video_info, video_save_path, audio, sampling_rate=seq_cfg.sampling_rate)
    gc.collect()
    return video_save_path


@torch.inference_mode()
def text_to_audio(prompt: str, negative_prompt: str, seed: int, num_steps: int,
                  cfg_strength: float, duration: float):
    """Generate audio from text alone and write it as a FLAC file to ``output_dir``.

    Args:
        prompt / negative_prompt: Text conditioning and negative text conditioning.
        seed: RNG seed; negative (or ``None``) means random.
        num_steps: Number of Euler flow-matching steps.
        cfg_strength: Classifier-free guidance strength.
        duration: Output duration in seconds.

    Returns:
        Path of the written ``.flac`` file.
    """
    rng = _make_rng(seed)
    fm = _make_fm(num_steps)

    # No visual conditioning: both frame inputs to generate() are None.
    clip_frames = sync_frames = None
    _set_duration(duration)

    audios = generate(clip_frames,
                      sync_frames, [prompt],
                      negative_text=[negative_prompt],
                      feature_utils=feature_utils,
                      net=net,
                      fm=fm,
                      rng=rng,
                      cfg_strength=cfg_strength)
    audio = audios.float().cpu()[0]

    audio_save_path = _new_output_path('.flac')
    torchaudio.save(audio_save_path, audio, seq_cfg.sampling_rate)
    gc.collect()
    return audio_save_path


# Tab descriptions (kept verbatim; Gradio renders them as Markdown).
_VIDEO_DESCRIPTION = """ Fine-tuned model: cloud19/NSFW_MMaudio
Based on the original project: https://github.com/hkchengrex/MMAudio

NOTE: It takes longer to process high-resolution videos (>384 px on the shorter side). Doing so does not improve results. """

_TEXT_DESCRIPTION = """ Fine-tuned model: cloud19/NSFW_MMaudio
Based on the original project: https://github.com/hkchengrex/MMAudio """

_IMAGE_DESCRIPTION = """ Fine-tuned model: cloud19/NSFW_MMaudio
Based on the original project: https://github.com/hkchengrex/MMAudio

NOTE: It takes longer to process high-resolution images (>384 px on the shorter side). Doing so does not improve results. """

video_to_audio_tab = gr.Interface(
    fn=video_to_audio,
    description=_VIDEO_DESCRIPTION,
    inputs=[
        gr.Video(),
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt', value='music'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='playable_video',
    cache_examples=False,
    title='MMAudio — Video-to-Audio Synthesis',
)

text_to_audio_tab = gr.Interface(
    fn=text_to_audio,
    description=_TEXT_DESCRIPTION,
    inputs=[
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='audio',
    cache_examples=False,
    title='MMAudio — Text-to-Audio Synthesis',
)

image_to_audio_tab = gr.Interface(
    fn=image_to_audio,
    description=_IMAGE_DESCRIPTION,
    inputs=[
        gr.Image(type='filepath'),
        gr.Text(label='Prompt'),
        gr.Text(label='Negative prompt'),
        gr.Number(label='Seed (-1: random)', value=-1, precision=0, minimum=-1),
        gr.Number(label='Num steps', value=25, precision=0, minimum=1),
        gr.Number(label='Guidance Strength', value=4.5, minimum=1),
        gr.Number(label='Duration (sec)', value=8, minimum=1),
    ],
    outputs='playable_video',
    cache_examples=False,
    title='MMAudio — Image-to-Audio Synthesis (experimental)',
)

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--port', type=int, default=7860)
    # Default keeps the previous behaviour (listen on all interfaces);
    # pass --host 127.0.0.1 to restrict the app to the local machine.
    parser.add_argument('--host', default='0.0.0.0',
                        help='Address to bind the server to (default: 0.0.0.0)')
    parser.add_argument('--share', action='store_true', help='Create a public link')
    args = parser.parse_args()

    app = gr.TabbedInterface([video_to_audio_tab, text_to_audio_tab, image_to_audio_tab],
                             ['Video-to-Audio', 'Text-to-Audio', 'Image-to-Audio (experimental)'])
    app.launch(server_name=args.host,
               server_port=args.port,
               share=args.share,
               allowed_paths=[str(output_dir)])