MeanFlowSE is a conditional generative approach to speech enhancement that learns average velocities over short time spans and performs enhancement in a single step. Instead of rolling out a long ODE trajectory, it applies one backward-in-time displacement directly in the complex STFT domain, delivering competitive quality at a fraction of the compute and latency of multi-step samplers. The model is trained end-to-end with a local JVP-based objective and remains consistent with conditional flow matching on the diagonal, where the averaging interval collapses to a single time and the average velocity equals the instantaneous one. No teacher models, schedulers, or distillation are required. In practice, 1-NFE inference makes real-time deployment on standard hardware straightforward.
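For orientation, the relations below sketch the average-velocity idea in the notation of the MeanFlow framework, which the description above appears to follow. The symbols are our own labels (state z_t at time t, instantaneous field v, average field u over an interval [r, t] with r < t), and the sign convention assumes enhancement runs backward from t = T_rev toward r ≈ 0; the path defined in `flowmse/odes.py` may use different conventions.

```latex
% Average velocity over [r, t] and the one-step (backward-in-time) displacement
u(z_t, r, t) = \frac{1}{t - r} \int_r^t v(z_\tau, \tau)\, d\tau,
\qquad
z_r = z_t - (t - r)\, u(z_t, r, t)

% Differentiating (t - r)\, u(z_t, r, t) = \int_r^t v(z_\tau, \tau)\, d\tau with respect to t:
u(z_t, r, t) = v(z_t, t) - (t - r)\, \frac{d}{dt} u(z_t, r, t),
\qquad
\frac{d}{dt} u(z_t, r, t) = \partial_z u \; v(z_t, t) + \partial_t u

% On the diagonal r = t the correction term vanishes, so u(z_t, t, t) = v(z_t, t):
% there, the average-field objective reduces to ordinary conditional flow matching.
```

The total derivative in the second line is a Jacobian-vector product of u along the tangent (v, 0, 1) in (z, r, t), which is what makes a local, JVP-based training objective possible without integrating a trajectory.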
- Demo: demo page coming soon.
Highlights
- One-step enhancement (1-NFE): A single displacement update replaces long ODE rollouts and is fast enough for real-time use on standard GPUs/CPUs.
- No teachers, no distillation: Trains with a local, JVP-based objective; on the diagonal it exactly matches conditional flow matching.
- Same model, two samplers: Use the displacement sampler for 1-step (or few-step) inference; fall back to Euler along the instantaneous field if you prefer multi-step.
- Competitive & fast: strong ESTOI / SI-SDR / DNSMOS with very low RTF on VoiceBank-DEMAND.
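To make the two samplers concrete, here is a minimal NumPy sketch. It is not the code in `flowmse/sampling/odesolvers.py`: the callables `v_fn(z, t)` and `u_fn(z, r, t)` are hypothetical stand-ins for the network's instantaneous and average outputs, and the toy field is chosen only because its exact average velocity is known in closed form.

```python
import numpy as np
from typing import Callable

# Hypothetical signatures (not the repository API):
#   v_fn(z, t)    -> instantaneous velocity at time t
#   u_fn(z, r, t) -> average velocity over the interval [r, t]
VField = Callable[[np.ndarray, float], np.ndarray]
UField = Callable[[np.ndarray, float, float], np.ndarray]


def euler_sampler(v_fn: VField, z: np.ndarray, t_start: float = 1.0,
                  t_end: float = 0.0, steps: int = 30) -> np.ndarray:
    """Multi-step Euler along the instantaneous field, integrating backward in time."""
    ts = np.linspace(t_start, t_end, steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        z = z - (t - t_next) * v_fn(z, t)  # one network evaluation per step
    return z


def euler_mf_sampler(u_fn: UField, z: np.ndarray, t_start: float = 1.0,
                     t_end: float = 0.0, steps: int = 1) -> np.ndarray:
    """Displacement sampler: each step jumps by the average velocity over [t_next, t]."""
    ts = np.linspace(t_start, t_end, steps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        z = z - (t - t_next) * u_fn(z, t_next, t)  # steps=1 gives 1-NFE inference
    return z


# Toy field dz/dt = a * z, whose exact average velocity over [r, t] is known.
a = 2.0


def v_toy(z, t):
    return a * z


def u_toy(z, r, t):
    return z * (1.0 - np.exp(-a * (t - r))) / (t - r)


z1 = np.array([1.0, -0.5])
print(euler_mf_sampler(u_toy, z1))           # one displacement step
print(euler_sampler(v_toy, z1, steps=1000))  # many small Euler steps
print(z1 * np.exp(-a))                       # exact backward solution at t = 0
```

With the exact average field, one displacement step lands on the endpoint that Euler only approaches with many small steps. A trained network only approximates u, which is why the displacement sampler also accepts a few steps.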
What's inside
- Training with average-field supervision (for the 1-step displacement sampler).
- Inference with `euler_mf`: single-step displacement along the average field.
- Audio front-end: complex STFT pipeline; configurable transforms & normalization.
- Metrics: PESQ, ESTOI, SI-SDR; end-to-end RTF measurement.
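Two of these metrics have short closed-form definitions, so they can be illustrated directly. This is a generic NumPy sketch of the standard SI-SDR and real-time-factor formulas, not the repository's metric code under `util/`; `si_sdr` and `real_time_factor` are hypothetical helper names, and PESQ/ESTOI are omitted because they require dedicated implementations.

```python
import time

import numpy as np


def si_sdr(estimate: np.ndarray, reference: np.ndarray, eps: float = 1e-8) -> float:
    """Scale-invariant SDR in dB; both signals are made zero-mean first."""
    estimate = estimate - estimate.mean()
    reference = reference - reference.mean()
    # Project the estimate onto the reference to find the optimal scaling.
    alpha = np.dot(estimate, reference) / (np.dot(reference, reference) + eps)
    target = alpha * reference
    noise = estimate - target
    return float(10.0 * np.log10((np.sum(target ** 2) + eps) / (np.sum(noise ** 2) + eps)))


def real_time_factor(enhance_fn, audio: np.ndarray, sample_rate: int = 16000) -> float:
    """Wall-clock processing time divided by audio duration (RTF < 1 is faster than real time)."""
    start = time.perf_counter()
    enhance_fn(audio)
    elapsed = time.perf_counter() - start
    return elapsed / (len(audio) / sample_rate)


# Zero-mean reference plus an orthogonal perturbation with 1/100 of its energy.
s = np.array([1.0, -1.0, 1.0, -1.0])
n = np.array([1.0, 1.0, -1.0, -1.0])
print(si_sdr(s + 0.1 * n, s))  # about 20 dB; rescaling the estimate leaves it unchanged
```

For GPU inference, synchronize the device (for example with `torch.cuda.synchronize()`) before stopping the timer; otherwise asynchronous kernels make the measured RTF look too small.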
Quick start
Installation
```bash
# Python 3.10 recommended
pip install -r requirements.txt
# Use a recent PyTorch + CUDA build for multi-GPU training
```
Data preparation
Expected layout:
```text
<BASE_DIR>/
  train/clean/*.wav   train/noisy/*.wav
  valid/clean/*.wav   valid/noisy/*.wav
  test/clean/*.wav    test/noisy/*.wav
```
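Before training, it can help to check that every noisy file has a clean partner. The helper below is a hypothetical sketch, not part of the repository, and it assumes clean and noisy files are paired by identical file names.

```python
from pathlib import Path


def check_layout(base_dir, splits=("train", "valid", "test")):
    """Count clean/noisy pairs per split; raise ValueError if the file names do not match.

    Assumption: a clean file and its noisy counterpart share the same file name.
    """
    counts = {}
    for split in splits:
        clean = {p.name for p in (Path(base_dir) / split / "clean").glob("*.wav")}
        noisy = {p.name for p in (Path(base_dir) / split / "noisy").glob("*.wav")}
        if not clean:
            raise ValueError(f"{split}: no clean WAV files found")
        if clean != noisy:
            unmatched = sorted(clean ^ noisy)[:5]  # show a few offenders
            raise ValueError(f"{split}: unmatched files, e.g. {unmatched}")
        counts[split] = len(clean)
    return counts


# Example with a hypothetical path:
# print(check_layout("/data/VoiceBank-DEMAND"))
```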
Defaults assume 16 kHz audio, centered frames, Hann windows, and a complex STFT representation (see `SpecsDataModule` for the configurable options).
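As a rough picture of such a front-end, the NumPy sketch below computes a centered complex STFT with a periodic Hann window and applies a magnitude-compressing, phase-preserving transform of the kind used in SGMSE-style pipelines. The FFT size, hop length, exponent, and scale factor are illustrative assumptions; the actual values and transforms are set in `SpecsDataModule`.

```python
import numpy as np


def stft_centered(x: np.ndarray, n_fft: int = 510, hop: int = 128) -> np.ndarray:
    """Complex STFT with centered (reflect-padded) frames and a periodic Hann window.

    Returns an array of shape (n_fft // 2 + 1, n_frames).
    """
    window = 0.5 - 0.5 * np.cos(2.0 * np.pi * np.arange(n_fft) / n_fft)
    x = np.pad(x, n_fft // 2, mode="reflect")  # frame k is centered on sample k * hop
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[k * hop : k * hop + n_fft] * window for k in range(n_frames)])
    return np.fft.rfft(frames, axis=-1).T


def compress(spec, exponent=0.5, factor=0.15):
    """Compress magnitudes to factor * |X|**exponent while keeping the phase of X."""
    return factor * np.abs(spec) ** exponent * np.exp(1j * np.angle(spec))


def decompress(spec, exponent=0.5, factor=0.15):
    """Invert compress() to recover the linear-magnitude complex spectrogram."""
    spec = spec / factor
    return np.abs(spec) ** (1.0 / exponent) * np.exp(1j * np.angle(spec))


# One second of 16 kHz noise as a stand-in for a waveform.
x = np.random.default_rng(0).standard_normal(16000)
X = compress(stft_centered(x))  # complex network input under the assumed settings
```

With these assumed settings, one second of 16 kHz audio maps to a 256 x 126 complex array, and `decompress` undoes the compression exactly before an inverse STFT.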
Training
Single machine, multi-GPU (DDP):
```bash
# Edit DATA_DIR and GPUs inside the script if needed
bash train_vbd.sh
```
Or run directly:
```bash
torchrun --standalone --nproc_per_node=4 train.py \
  --backbone ncsnpp \
  --ode flowmatching \
  --base_dir <BASE_DIR> \
  --batch_size 2 \
  --num_workers 8 \
  --max_epochs 150 \
  --precision 32 \
  --gradient_clip_val 1.0 \
  --t_eps 0.03 --T_rev 1.0 \
  --sigma_min 0.0 --sigma_max 0.487 \
  --use_mfse \
  --mf_weight_final 0.25 \
  --mf_warmup_frac 0.5 \
  --mf_delta_gamma_start 8.0 --mf_delta_gamma_end 1.0 \
  --mf_delta_warmup_frac 0.7 \
  --mf_r_equals_t_prob 0.1 \
  --mf_jvp_clip 5.0 --mf_jvp_eps 1e-3 \
  --mf_jvp_impl fd --mf_jvp_chunk 1 \
  --mf_skip_weight_thresh 0.05 \
  --val_metrics_every_n_epochs 1 \
  --default_root_dir lightning_logs
```
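The MF-SE flags suggest a curriculum over the average-field loss weight and over how far apart r and t are sampled. The sketch below is one hypothetical reading of those flags, not the logic in `flowmse/model.py`: it assumes the loss weight ramps linearly from 0 to `mf_weight_final` over the first `mf_warmup_frac` of training, that a fraction `mf_r_equals_t_prob` of pairs uses r = t (plain flow matching), and that the gap t - r is drawn as (t - t_eps) * U**gamma with gamma annealed from `mf_delta_gamma_start` to `mf_delta_gamma_end`, so early training sees mostly short spans. All function names are illustrative.

```python
import numpy as np


def linear_ramp(progress: float, start: float, end: float, warmup_frac: float) -> float:
    """Move linearly from start to end over the first warmup_frac of training, then hold end."""
    if warmup_frac <= 0:
        return end
    alpha = min(max(progress, 0.0) / warmup_frac, 1.0)
    return start + alpha * (end - start)


def mf_weight(progress: float, weight_final: float = 0.25, warmup_frac: float = 0.5) -> float:
    """Hypothetical weight on the average-field loss: 0 at the start, weight_final after warmup."""
    return linear_ramp(progress, 0.0, weight_final, warmup_frac)


def sample_r_t(rng: np.random.Generator, batch: int, progress: float,
               t_eps: float = 0.03, t_max: float = 1.0, r_equals_t_prob: float = 0.1,
               gamma_start: float = 8.0, gamma_end: float = 1.0,
               delta_warmup_frac: float = 0.7):
    """Hypothetical (r, t) sampler: mostly short spans early in training, wider spans later."""
    t = rng.uniform(t_eps, t_max, size=batch)
    gamma = linear_ramp(progress, gamma_start, gamma_end, delta_warmup_frac)
    # U**gamma with a large gamma concentrates near 0, i.e. small gaps t - r.
    delta = (t - t_eps) * rng.uniform(size=batch) ** gamma
    # A fraction of pairs sits on the diagonal r == t (plain flow matching).
    delta[rng.uniform(size=batch) < r_equals_t_prob] = 0.0
    return t - delta, t


# progress = fraction of training completed, e.g. global_step / total_steps.
rng = np.random.default_rng(0)
r, t = sample_r_t(rng, batch=4, progress=0.1)
weight = mf_weight(0.1)  # 0.05 under the assumed linear ramp
```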
- Logging & checkpoints live under `lightning_logs/<exp_name>/version_x/`.
- Heavy validation (PESQ/ESTOI/SI-SDR) runs every N epochs on rank 0; otherwise, placeholders are logged so checkpoint monitors remain valid.
Inference
Use the helper script:
```bash
# MODE = multistep | multistep_mf | onestep
MODE=onestep STEPS=1 \
TEST_DATA_DIR=<BASE_DIR> \
CKPT_INPUT=path/to/best.ckpt \
bash run_inference.sh
```
Or call the evaluator:
```bash
python evaluate.py \
  --test_dir <BASE_DIR> \
  --folder_destination /path/to/output \
  --ckpt path/to/best.ckpt \
  --odesolver euler_mf \
  --reverse_starting_point 1.0 \
  --last_eval_point 0.0 \
  --one_step
```
`evaluate.py` writes enhanced WAVs. If `--odesolver` is not given, it picks one automatically: `euler_mf` when MF-SE was used, otherwise `euler`.
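The default-selection rule amounts to a small function. This sketch is hypothetical: the name `pick_odesolver` and the `use_mfse` key are assumptions standing in for however `evaluate.py` actually detects that a checkpoint was trained with MF-SE.

```python
from typing import Any, Mapping, Optional


def pick_odesolver(odesolver: Optional[str], hparams: Mapping[str, Any]) -> str:
    """Return the explicit --odesolver if given; otherwise euler_mf for MF-SE models, else euler."""
    if odesolver:
        return odesolver
    # Assumption: the training flag is stored as `use_mfse`. In a PyTorch Lightning checkpoint
    # written after `save_hyperparameters()`, such values usually sit under
    # torch.load(ckpt_path, map_location="cpu")["hyper_parameters"].
    return "euler_mf" if hparams.get("use_mfse", False) else "euler"


print(pick_odesolver(None, {"use_mfse": True}))
```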
Configuration
Common flags you may want to tweak:
- Time & schedule: `--T_rev` (reverse start, default 1.0), `--t_eps` (terminal time), `--sigma_min`, `--sigma_max`
- MF-SE stability: `--mf_jvp_impl {auto,fd,autograd}`, `--mf_jvp_chunk`, `--mf_jvp_clip`, `--mf_jvp_eps`
- Curriculum: `--mf_weight_final`, `--mf_warmup_frac`, `--mf_delta_*`, `--mf_r_equals_t_prob`
- Validation cost: `--val_metrics_every_n_epochs`, `--num_eval_files`
- Backbone & front-end: defined in `flowmse/backbones/` and `SpecsDataModule` (STFT, transforms, normalization)
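The MF-SE stability flags concern how the JVP inside the training target is computed. Below is a minimal NumPy sketch of a MeanFlow-style target, u_tgt = v - (t - r) * du/dt, with du/dt estimated by a forward difference along the tangent (v, 0, 1). It is an assumption-laden illustration, not `flowmse/model.py`: we assume `fd` means a forward difference with step `--mf_jvp_eps`, that `--mf_jvp_clip` clips the JVP elementwise, and that `--mf_jvp_chunk` sets how many samples are processed at once. `fd_jvp` and `mf_target` are hypothetical names; an `autograd` variant could instead use `torch.func.jvp`.

```python
import numpy as np


def fd_jvp(u_fn, z, r, t, v, eps=1e-3):
    """Forward-difference estimate of du/dt along the tangent (dz, dr, dt) = (v, 0, 1)."""
    return (u_fn(z + eps * v, r, t + eps) - u_fn(z, r, t)) / eps


def mf_target(u_fn, z, r, t, v, eps=1e-3, clip=5.0, chunk=1):
    """MeanFlow-style target u_tgt = v - (t - r) * du/dt, evaluated `chunk` samples at a time.

    z and v have shape (batch, dim); r and t have shape (batch,).
    """
    target = np.empty_like(z)
    for start in range(0, len(z), chunk):
        sl = slice(start, start + chunk)
        dudt = fd_jvp(u_fn, z[sl], r[sl], t[sl], v[sl], eps)
        dudt = np.clip(dudt, -clip, clip)  # assumption: elementwise clipping
        target[sl] = v[sl] - (t[sl] - r[sl])[:, None] * dudt
    return target


# Toy check with dz/dt = a * z: the exact average field satisfies the identity,
# so the target reproduces u and the regression loss is close to zero.
a = 2.0


def u_exact(z, r, t):
    d = (t - r)[:, None]
    return z * (1.0 - np.exp(-a * d)) / d


z = np.array([[1.0, -0.5], [0.3, 2.0]])
r = np.array([0.1, 0.4])
t = np.array([0.9, 0.6])
target = mf_target(u_exact, z, r, t, v=a * z, eps=1e-6)
loss = np.mean((u_exact(z, r, t) - target) ** 2)
```

In MeanFlow-style training the target is held fixed (no gradient flows through it), so only the network's prediction of u is updated; chunking trades speed for lower peak memory during the extra forward passes.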
Repository structure
```text
MeanFlowSE/
├── train.py                # Lightning entry
├── evaluate.py             # Enhancement script (WAV out)
├── run_inference.sh        # One-step / few-step convenience runner
├── flowmse/
│   ├── model.py            # Losses, JVP, curriculum, logging
│   ├── odes.py             # Path definition & registry
│   ├── sampling/
│   │   ├── __init__.py
│   │   └── odesolvers.py   # Euler (instantaneous) & Euler-MF (displacement)
│   ├── backbones/
│   │   ├── ncsnpp.py       # U-Net w/ time & delta embeddings
│   │   └── ...
│   ├── data_module.py      # STFT I/O pipeline
│   └── util/               # metrics, registry, tensors, inference helpers
├── requirements.txt
└── scripts/
    └── train_vbd.sh
```
Built upon & related work
This repository builds on the following prior work:
- SGMSE: https://github.com/sp-uhh/sgmse
- SGMSE-CRP: https://github.com/sp-uhh/sgmse_crp
- SGMSE-BBED: https://github.com/sp-uhh/sgmse-bbed
- FLOWMSE (FlowSE): https://github.com/seongq/flowmse
Many design choices (complex STFT pipeline, training infrastructure) are inspired by these excellent projects.
Pretrained models
- VoiceBank-DEMAND (16 kHz): the weight files are hosted on Google Drive (Google Drive Link).
Acknowledgments
We gratefully acknowledge Prof. Xie Chen's group (X-LANCE Lab, SJTU) for their valuable guidance and support on training practices and engineering tips, which helped this work considerably.
Citation
- Citation: The paper is currently under review. We will add a BibTeX entry and article link once available.
Questions or issues? Please open a GitHub issue or pull request. We welcome contributions, from bug fixes to new backbones and front-ends.