MeanFlowSE β€” One-Step Generative Speech Enhancement

Paper · Hugging Face Model · Code

MeanFlowSE is a conditional generative approach to speech enhancement that learns average velocities over short time spans and performs enhancement in a single step. Instead of rolling out a long ODE trajectory, it applies one backward-in-time displacement directly in the complex STFT domain, delivering competitive quality at a fraction of the compute and latency. The model is trained end-to-end with a local JVP-based objective and remains consistent with conditional flow matching on the diagonalβ€”no teacher models, schedulers, or distillation required. In practice, 1-NFE inference makes real-time deployment on standard hardware straightforward.
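
As a sketch of what 1-NFE inference looks like (the function name and model signature below are illustrative, not the repository's API):

import torch

@torch.no_grad()
def enhance_one_step(model, y_spec, T_rev=1.0, t_eps=0.03):
    """One backward-in-time displacement in the complex STFT domain (sketch)."""
    b = y_spec.shape[0]
    t = torch.full((b,), T_rev, device=y_spec.device)  # reverse starting point
    r = torch.full((b,), t_eps, device=y_spec.device)  # terminal time
    x_t = y_spec  # reverse start; the repo's path may also add noise scaled by sigma_max
    u = model(x_t, t, r, y_spec)  # predicted average velocity over the span [r, t]
    span = (t - r).view(b, *([1] * (y_spec.dim() - 1)))
    return x_t - span * u  # a single displacement replaces the ODE rollout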

  • 🎧 Demo: coming soon.

Table of Contents

  • Highlights
  • What's inside
  • Quick start
  • Training
  • Inference
  • Configuration
  • Repository structure
  • Built upon & related work
  • Pretrained models
  • Acknowledgments
  • Citation

Highlights

  • One-step enhancement (1-NFE): A single displacement update replaces long ODE rolloutsβ€”fast enough for real-time use on standard GPUs/CPUs.
  • No teachers, no distillation: Trains with a local, JVP-based objective; on the diagonal it exactly matches conditional flow matching (see the identity after this list).
  • Same model, two samplers: Use the displacement sampler for 1-step (or few-step) inference; fall back to Euler along the instantaneous field if you prefer multi-step.
  • Competitive & fast: Strong ESTOI / SI-SDR / DNSMOS scores with very low RTF on VoiceBank-DEMAND.
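
For the mathematically inclined: the training signal rests on the MeanFlow identity, which relates the average velocity u over a span [r, t] to the instantaneous velocity v. The notation below is a sketch of that standard identity, not a transcription from the paper:

u(x_t, r, t) = \frac{1}{t - r}\int_r^t v(x_\tau, \tau)\,\mathrm{d}\tau,
\qquad
u(x_t, r, t) = v(x_t, t) - (t - r)\,\frac{\mathrm{d}}{\mathrm{d}t}\,u(x_t, r, t),
\qquad
\frac{\mathrm{d}}{\mathrm{d}t}\,u = \partial_t u + v \cdot \nabla_{x_t} u

The total derivative d/dt u is exactly what the JVP computes during training; on the diagonal r = t the correction term vanishes and u = v, which is why the objective coincides with conditional flow matching there.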

What’s inside

  • Training with average-field supervision (for the 1-step displacement sampler).
  • Inference with euler_mf — single-step displacement along the average field (see the sampler sketch after this list).
  • Audio front-end: complex STFT pipeline; configurable transforms & normalization.
  • Metrics: PESQ, ESTOI, SI-SDR; end-to-end RTF measurement.
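
As a companion to the one-step sketch above, here is what the multi-step fallback looks like: plain Euler along the instantaneous field, obtained from the same network by setting r = t (a sketch with an illustrative signature, not the code in flowmse/sampling/odesolvers.py):

import torch

@torch.no_grad()
def euler_sampler(model, y_spec, steps=30, T_rev=1.0, t_eps=0.03):
    """Multi-step Euler along the instantaneous field (r = t), backward in time."""
    x = y_spec
    ts = torch.linspace(T_rev, t_eps, steps + 1, device=y_spec.device)
    b = y_spec.shape[0]
    for i in range(steps):
        t = ts[i].expand(b)
        v = model(x, t, t, y_spec)       # r = t: the average field reduces to v
        x = x + (ts[i + 1] - ts[i]) * v  # negative step size integrates backward
    return x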

Quick start

Installation

# Python 3.10 recommended
pip install -r requirements.txt
# Use a recent PyTorch + CUDA build for multi-GPU training

Data preparation

Expected layout:

<BASE_DIR>/
  train/clean/*.wav   train/noisy/*.wav
  valid/clean/*.wav   valid/noisy/*.wav
  test/clean/*.wav    test/noisy/*.wav

Defaults assume 16 kHz audio, centered frames, Hann windows, and a complex STFT representation (see SpecsDataModule for knobs).
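
Those defaults correspond roughly to a torch.stft call like the following; the frame sizes here are illustrative, and the actual values plus any transforms/normalization live in SpecsDataModule:

import torch

def complex_stft(wav: torch.Tensor, n_fft: int = 510, hop_length: int = 128) -> torch.Tensor:
    """Complex STFT with centered frames and a Hann window (16 kHz input assumed)."""
    window = torch.hann_window(n_fft, device=wav.device)
    return torch.stft(wav, n_fft=n_fft, hop_length=hop_length, window=window,
                      center=True, return_complex=True)  # -> (freq, frames), complex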

Training

Single machine, multi-GPU (DDP):

# Edit DATA_DIR and GPUs inside the script if needed
bash train_vbd.sh

Or run directly:

torchrun --standalone --nproc_per_node=4 train.py \
  --backbone ncsnpp \
  --ode flowmatching \
  --base_dir <BASE_DIR> \
  --batch_size 2 \
  --num_workers 8 \
  --max_epochs 150 \
  --precision 32 \
  --gradient_clip_val 1.0 \
  --t_eps 0.03 --T_rev 1.0 \
  --sigma_min 0.0 --sigma_max 0.487 \
  --use_mfse \
  --mf_weight_final 0.25 \
  --mf_warmup_frac 0.5 \
  --mf_delta_gamma_start 8.0 --mf_delta_gamma_end 1.0 \
  --mf_delta_warmup_frac 0.7 \
  --mf_r_equals_t_prob 0.1 \
  --mf_jvp_clip 5.0 --mf_jvp_eps 1e-3 \
  --mf_jvp_impl fd --mf_jvp_chunk 1 \
  --mf_skip_weight_thresh 0.05 \
  --val_metrics_every_n_epochs 1 \
  --default_root_dir lightning_logs
  • Logging & checkpoints live under lightning_logs/<exp_name>/version_x/.
  • Heavy validation (PESQ/ESTOI/SI-SDR) runs every N epochs on rank-0; placeholders are logged otherwise so checkpoint monitors remain valid.
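
For reference, the total-derivative term in the MF objective can be formed either with autograd's JVP or with a finite difference, which is the choice --mf_jvp_impl {autograd,fd} exposes; the helpers below are an illustrative sketch, not the code in flowmse/model.py:

import torch
from torch.func import jvp

def du_dt_autograd(u_fn, x, r, t, v):
    """d/dt u(x_t, r, t) along the flow: tangent (dx/dt, dr/dt, dt/dt) = (v, 0, 1)."""
    _, du = jvp(u_fn, (x, r, t), (v, torch.zeros_like(r), torch.ones_like(t)))
    return du

def du_dt_fd(u_fn, x, r, t, v, eps=1e-3):
    """Finite-difference surrogate for the same directional derivative."""
    return (u_fn(x + eps * v, r, t + eps) - u_fn(x, r, t)) / eps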

Inference

Use the helper script:

# MODE = multistep | multistep_mf | onestep
MODE=onestep STEPS=1 \
TEST_DATA_DIR=<BASE_DIR> \
CKPT_INPUT=path/to/best.ckpt \
bash run_inference.sh

Or call the evaluator:

python evaluate.py \
  --test_dir <BASE_DIR> \
  --folder_destination /path/to/output \
  --ckpt path/to/best.ckpt \
  --odesolver euler_mf \
  --reverse_starting_point 1.0 \
  --last_eval_point 0.0 \
  --one_step

evaluate.py writes the enhanced WAVs. If --odesolver is omitted, it is chosen automatically (euler_mf when MF-SE was used; euler otherwise).
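
End-to-end RTF is wall-clock enhancement time divided by audio duration; a minimal way to measure it yourself (an illustrative helper, separate from the repository's built-in RTF measurement):

import time
import torch

def measure_rtf(enhance_fn, y_spec, audio_seconds: float) -> float:
    """Real-time factor: wall-clock processing time / audio duration (lower is better)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # flush queued kernels before timing
    t0 = time.perf_counter()
    enhance_fn(y_spec)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for enhancement to actually finish
    return (time.perf_counter() - t0) / audio_seconds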

Configuration

Common flags you may want to tweak:

  • Time & schedule

    • --T_rev (reverse start, default 1.0), --t_eps (terminal time), --sigma_min, --sigma_max
  • MF-SE stability

    • --mf_jvp_impl {auto,fd,autograd}, --mf_jvp_chunk, --mf_jvp_clip, --mf_jvp_eps
    • Curriculum: --mf_weight_final, --mf_warmup_frac, --mf_delta_*, --mf_r_equals_t_prob (see the sketch after this list)
  • Validation cost

    • --val_metrics_every_n_epochs, --num_eval_files
  • Backbone & front-end

    • Defined in backbones/ and SpecsDataModule (STFT, transforms, normalization)
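
As a rough reading of how the curriculum defaults above interact, assuming linear warmup schedules (the real schedules live in flowmse/model.py and may differ in shape):

def mf_weight(progress: float, final: float = 0.25, warmup_frac: float = 0.5) -> float:
    """Ramp the MF loss weight linearly from 0 to `final` over the warmup fraction."""
    return final * min(progress / warmup_frac, 1.0)

def mf_delta_gamma(progress: float, start: float = 8.0, end: float = 1.0,
                   warmup_frac: float = 0.7) -> float:
    """Anneal the delta-sampling gamma from `start` to `end` over its warmup."""
    a = min(progress / warmup_frac, 1.0)
    return start + a * (end - start)

# progress = global_step / total_steps, in [0, 1]
print(mf_weight(0.25), mf_delta_gamma(0.35))  # -> 0.125 4.5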

Repository structure

MeanFlowSE/
β”œβ”€β”€ train.py                 # Lightning entry
β”œβ”€β”€ evaluate.py              # Enhancement script (WAV out)
β”œβ”€β”€ run_inference.sh         # One-step / few-step convenience runner
β”œβ”€β”€ flowmse/
β”‚   β”œβ”€β”€ model.py             # Losses, JVP, curriculum, logging
β”‚   β”œβ”€β”€ odes.py              # Path definition & registry
β”‚   β”œβ”€β”€ sampling/
β”‚   β”‚   β”œβ”€β”€ __init__.py
β”‚   β”‚   └── odesolvers.py    # Euler (instantaneous) & Euler-MF (displacement)
β”‚   β”œβ”€β”€ backbones/
β”‚   β”‚   β”œβ”€β”€ ncsnpp.py        # U-Net w/ time & delta embeddings
β”‚   β”‚   └── ...
β”‚   β”œβ”€β”€ data_module.py       # STFT I/O pipeline
β”‚   └── util/                # metrics, registry, tensors, inference helpers
β”œβ”€β”€ requirements.txt
└── scripts/
    └── train_vbd.sh

Built upon & related work

This repository builds on several excellent prior open-source projects; many design choices (the complex STFT pipeline, the training infrastructure) are inspired by them.

Pretrained models

  • VoiceBank–DEMAND (16 kHz): pretrained weights are hosted on Google Drive (Google Drive Link).

Acknowledgments

We gratefully acknowledge Prof. Xie Chen’s group (X-LANCE Lab, SJTU) for their valuable guidance and support on training practices and engineering, which greatly helped this work.

Citation

The paper is currently under review. We will add a BibTeX entry and article link once they are available.

Questions or issues? Please open a GitHub issue or pull request. We welcome contributions β€” from bug fixes to new backbones and front-ends.
