SMDM (Scaling up Masked Diffusion Models on Text)

This is the official implementation of the paper "Scaling up Masked Diffusion Models on Text" (https://arxiv.org/abs/2410.18514).

Model Description

SMDM is a family of masked diffusion models (MDMs) trained on the SlimPajama dataset. The models perform competitively with autoregressive models (ARMs) of comparable or larger size and offer advantages in bidirectional reasoning (for example, on reverse-curse tasks) and in adapting to temporal shifts in the data.

Key features:

  • Model family scaling up to 1.1B parameters
  • Unsupervised classifier-free guidance that exploits unpaired data to improve conditional inference
  • Competitive zero-shot performance on language understanding tasks
  • Math reasoning on GSM8K competitive with larger ARMs
  • Flexible trade-off between sampling speed and generation quality, controlled by the number of sampling steps
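The classifier-free guidance feature can be illustrated with the standard guidance rule applied to one token's prediction. This is a minimal sketch, assuming guidance takes the form log p_w(x) = (1 + w) log p(x | c) - w log p(x), renormalized, and that the unconditional prediction comes from the same model with the prompt masked out. The function names are hypothetical; this is not the paper's implementation.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def cfg_log_probs(cond_logits, uncond_logits, w):
    """Combine conditional and unconditional predictions for a single token.

    Computes (1 + w) * log p(x | c) - w * log p(x) and renormalizes.
    w = 0 returns the conditional distribution unchanged.
    """
    cond = log_softmax(cond_logits)
    uncond = log_softmax(uncond_logits)
    mixed = [(1 + w) * c - w * u for c, u in zip(cond, uncond)]
    return log_softmax(mixed)


# Hypothetical example: the unconditional prediction is uniform,
# so guidance simply sharpens the conditional prediction.
guided = cfg_log_probs([2.0, 1.0, 0.0], [1.0, 1.0, 1.0], w=1.0)
```

With w = 0 the conditional prediction is returned as-is; larger w pushes probability toward tokens that the prompt makes more likely than the unconditional model does.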

Usage

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

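# Note: SMDM is a masked diffusion model, not an autoregressive one. This snippet
# assumes the checkpoint loads through the transformers Auto classes; if it does
# not, use the loading and sampling code from the official repository instead.
# Load model and tokenizer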
model_name = "nieshen/SMDM"  # Replace with your model name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Generate text
input_text = "Once upon a time"
inputs = tokenizer(input_text, return_tensors="pt")
outputs = model.generate(**inputs, max_length=100)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
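The snippet above goes through the generic `generate` API of a causal LM, which decodes left to right. Sampling from a masked diffusion model instead starts from a fully masked continuation and reveals tokens over a fixed number of denoising steps. The sketch below shows that reverse process under a linear schedule (time t goes from 1 to 0 in `num_steps` equal steps). `MASK` and `toy_denoiser` are hypothetical placeholders for the mask-token id and the trained network; they are not part of the SMDM code.

```python
import random

MASK = -1  # hypothetical mask-token id


def toy_denoiser(seq):
    """Hypothetical stand-in for the trained network.

    A real MDM predicts a distribution over the vocabulary at every position,
    conditioned on the whole partially masked sequence. This toy version just
    returns the token id position % 5, so the output is easy to check.
    """
    return [i % 5 for i in range(len(seq))]


def sample_mdm(prompt, gen_length, num_steps, denoiser=toy_denoiser, seed=0):
    """Reveal a fully masked continuation of `prompt` over `num_steps` steps."""
    rng = random.Random(seed)
    seq = list(prompt) + [MASK] * gen_length
    calls = 0
    for step in range(num_steps, 0, -1):
        t, s = step / num_steps, (step - 1) / num_steps
        preds = denoiser(seq)  # one network evaluation per step
        calls += 1
        for i, tok in enumerate(seq):
            # Going from time t to time s, a masked token stays masked with
            # probability s / t, so it is revealed with probability (t - s) / t.
            # At the final step s = 0, so every remaining mask is revealed.
            if tok == MASK and rng.random() < (t - s) / t:
                seq[i] = preds[i]
    return seq, calls


seq, calls = sample_mdm(prompt=[7, 8], gen_length=6, num_steps=3)
```

Prompt tokens are never masked, so they are left untouched. Halving `num_steps` halves the number of network evaluations but reveals about twice as many tokens per step, which is the speed/quality trade-off described in the key features.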

Training Details

The model was trained on the SlimPajama dataset using the following key components:

  • Architecture: Masked Diffusion Model
  • Training objective: Masked diffusion (denoising randomly masked tokens)
  • Dataset: SlimPajama
  • Hardware: Multiple GPUs
  • Framework: PyTorch with Lightning
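The masked diffusion objective is commonly written as a reweighted cross-entropy on masked tokens. The sketch below assumes that standard form: draw a masking ratio t uniformly, mask each token independently with probability t, and weight the negative log-likelihood of the masked tokens by 1/t. The lower bound `eps`, the `MASK` id, and the helper names are assumptions for illustration, not taken from the SMDM training code.

```python
import math
import random

MASK = -1  # hypothetical mask-token id


def masked_diffusion_loss(tokens, predict_log_probs, rng, eps=1e-3):
    """One-sample Monte Carlo estimate of the masked diffusion training loss.

    `predict_log_probs(noisy)` must return, for every position, a list of
    log-probabilities over the vocabulary.
    """
    # 1. Masking ratio t ~ U(eps, 1); eps keeps the 1 / t weight bounded.
    t = eps + (1 - eps) * rng.random()
    # 2. Mask each token independently with probability t.
    noisy = [MASK if rng.random() < t else tok for tok in tokens]
    log_probs = predict_log_probs(noisy)
    # 3. Negative log-likelihood of the original tokens at masked positions only.
    nll = -sum(log_probs[i][tok] for i, tok in enumerate(tokens) if noisy[i] == MASK)
    return nll / t


# Sanity check with a hypothetical uniform predictor over V tokens.
V = 4
uniform = lambda seq: [[-math.log(V)] * V for _ in seq]
rng = random.Random(0)
tokens = [0, 1, 2, 3, 0, 1, 2, 3]
estimate = sum(masked_diffusion_loss(tokens, uniform, rng) for _ in range(20000)) / 20000
```

Because about L * t of the L tokens are masked at ratio t, the 1/t weight makes the expected loss for a uniform predictor exactly L * log V, independent of t; the estimate above should land close to 8 * log 4.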

Evaluation

The model has been evaluated on various benchmarks:

  • Language understanding tasks (HellaSwag, OpenBookQA, ARC-Easy, etc.)
  • Math reasoning (GSM8K)
  • Conditional generation (MT-Bench)
  • Reverse curse tasks
  • Temporal quality degradation
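For multiple-choice language understanding benchmarks, an MDM has no exact left-to-right likelihood, so one option is to score each candidate answer with a Monte Carlo estimate of the evidence lower bound (ELBO) on log p(answer | prompt). The sketch below assumes that scoring scheme with a linear masking schedule (mask ratio t uniform, log-likelihood of masked answer tokens weighted by 1/t). It is an illustration, not the paper's evaluation code; `elbo_score`, `pick_choice`, `MASK`, and the toy predictor are hypothetical.

```python
import math
import random

MASK = -1  # hypothetical mask-token id


def elbo_score(prompt, answer, predict_log_probs, num_samples=64, seed=0):
    """Monte Carlo estimate of a lower bound on log p(answer | prompt).

    Only answer tokens are masked; the prompt is always visible.
    """
    rng = random.Random(seed)
    offset = len(prompt)
    total = 0.0
    for _ in range(num_samples):
        t = rng.uniform(1e-3, 1.0)
        noisy_answer = [MASK if rng.random() < t else tok for tok in answer]
        log_probs = predict_log_probs(list(prompt) + noisy_answer)
        ll = sum(log_probs[offset + i][tok]
                 for i, tok in enumerate(answer) if noisy_answer[i] == MASK)
        total += ll / t
    return total / num_samples


def pick_choice(prompt, choices, predict_log_probs):
    """Return the index of the choice with the highest estimated score.

    Every choice uses the same seed, which reuses the same random draws and
    tends to reduce noise when comparing choices of equal length.
    """
    scores = [elbo_score(prompt, c, predict_log_probs) for c in choices]
    return max(range(len(choices)), key=scores.__getitem__)


# Hypothetical predictor that prefers token 1 at every position.
def toy_predictor(seq):
    return [[math.log(p) for p in (0.1, 0.7, 0.1, 0.1)] for _ in seq]


best = pick_choice([5], [[1, 1], [2, 2]], toy_predictor)
```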

Citation

If you use this model, please cite our paper:

@article{smdm2024,
  title={Scaling up Masked Diffusion Models on Text},
  author={[Authors]},
  journal={arXiv preprint arXiv:2410.18514},
  year={2024}
}

License

This model is released under the Apache 2.0 license.
