GMT: Graph Matching Transformer

GMT (Graph Matching Transformer) is a PyTorch-based framework for matching and aligning 2D curves (graphs) using rich geometric embeddings and a cross-attention Transformer architecture. It supports four model variants (tiny, small, medium, and large) to scale computational complexity and capacity.


Key Features

  • Multi-Geometry Support: Generates and processes sinusoids, circles, ellipses, and random polylines.
  • Curvature & Ray Embeddings: Computes curvature, ray distances, incidence angles, and hit flags for each point.
  • Index & Initial Shift Embedding: Includes normalized index, curvature, and initial displacement as features.
  • Cross-Attention Transformer: Two-stream self-attention on target & baseline, followed by cross-attention for fine-grained alignment.
  • Variants: Four predefined configurations (tiny, small, medium, large) with adjustable d_model, depth, and feed-forward dimensions.
  • Metal/CUDA/CPU: Auto-selects the MPS (Apple Silicon), CUDA, or CPU device (see the sketch after this list).
  • Visualizations: Built-in training loss curves, inference progression plots, and error distribution histograms.
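
Device selection follows the usual PyTorch preference order; a minimal sketch (the helper name is an assumption, not the exact function in gmt):

import torch

def select_device() -> torch.device:
    # Hypothetical helper: prefer Apple-Silicon MPS, then CUDA, then fall back to CPU
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")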

Repository Structure

weights/                 # Weights folder
README.md
train.py                 # Entry-point for training all variants
infer.py                 # CLI for inference and mapping extraction
gmt/                     # Core package
  __init__.py
  variants.py            # Model configurations
  utils.py               # Geometry & resampling utilities
  embeddings.py          # Ray-segment embedding functions
  dataset.py             # ThreadedRayDataset & helpers
  model.py               # Transformer definitions
  trainer.py             # Training loop and checkpointing
experiment.ipynb         # Jupyter notebook demo
LICENSE
requirements.txt         # Python dependencies

Installation

# Clone repository
git clone https://github.com/raildart/gmt.git
cd gmt

# (Optional) Create virtual environment
python -m venv .venv
source .venv/bin/activate  # or .venv\Scripts\activate on Windows

# Install dependencies
pip install -r requirements.txt

Quick Start

Training All Variants

python train.py \
  --epochs 30 \
  --batch_size 64 \
  --lr 5e-5

This will train tiny, small, medium, and large sequentially and save checkpoints as GMT_<variant>.pth.

Running Inference with External Geometries

python infer.py \
  --variant medium \
  --external path/to/geoms.npz \
  --samples 5 \
  --batch_size 16 \
  --save

This loads your own .npz with baseline and target arrays, runs the model, plots 5 sample alignments, and saves mappings_medium.npz.
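
The expected .npz simply bundles the two point arrays under the names baseline and target; a minimal sketch of producing one (the curve and array shapes are illustrative assumptions):

import numpy as np

# Hypothetical example: a unit circle as baseline, a displaced copy as target
t = np.linspace(0, 2 * np.pi, 256)
baseline = np.stack([np.cos(t), np.sin(t)], axis=-1)  # (256, 2) points
target = baseline + np.array([0.3, -0.1])             # same curve, shifted
np.savez("geoms.npz", baseline=baseline, target=target)
# Then pass its path to infer.py via --external geoms.npz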


Model Variants & Performance

Below is a summary of each variant’s architecture along with its final test MSE (mean squared error). The entry for large is a placeholder pending results.

Variant   d_model   Layers   FF Dim   Dropout   Test MSE
tiny      128       2        256      0.10      0.0034
small     256       3        512      0.15      0.0028
medium    512       4        1024     0.20      0.0026
large     768       5        1536     0.20      X
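
These hyperparameters live in gmt/variants.py; a sketch of how the table might map to configuration dicts (the exact structure returned by define_variants is an assumption):

# Hypothetical mirror of the table above; field names are assumptions
VARIANTS = {
    "tiny":   dict(d_model=128, depth=2, ff_dim=256,  dropout=0.10),
    "small":  dict(d_model=256, depth=3, ff_dim=512,  dropout=0.15),
    "medium": dict(d_model=512, depth=4, ff_dim=1024, dropout=0.20),
    "large":  dict(d_model=768, depth=5, ff_dim=1536, dropout=0.20),
}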

Mean Squared Error (MSE)

The Mean Squared Error (MSE) is our primary training and evaluation metric. For a single predicted sequence $\hat{\mathbf{y}} = [\hat{y}_1, \hat{y}_2, \dots, \hat{y}_N]$ and its ground-truth sequence $\mathbf{y} = [y_1, y_2, \dots, y_N]$, the MSE is computed as:

$$\mathrm{MSE}(\mathbf{y}, \hat{\mathbf{y}}) \;=\; \frac{1}{N} \sum_{i=1}^{N} \bigl(y_i - \hat{y}_i\bigr)^{2}.$$

In our setting, each sequence consists of 2-D displacements for $N$ resampled points, so the squared error is summed over both coordinates and averaged over the $N$ points:

$$\mathrm{MSE} = \frac{1}{N}\sum_{i=1}^{N}\Bigl[(\Delta x_i - \widehat{\Delta x}_i)^2 + (\Delta y_i - \widehat{\Delta y}_i)^2\Bigr].$$

During training, we report the batch-averaged MSE each epoch, and at the end we compute the dataset-wide MSE by averaging over all samples. Lower MSE indicates that the model’s predicted alignment shifts more closely match the true geometric offsets.
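
In PyTorch this reduces to a one-liner; a minimal sketch over batched displacement tensors (the function name and shapes are assumptions):

import torch

def displacement_mse(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """MSE over 2-D displacements; pred and true have shape (batch, N, 2)."""
    # Sum squared error over x/y, then average over points and batch
    return ((pred - true) ** 2).sum(dim=-1).mean()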


API Usage

from gmt.dataset import ThreadedRayDataset
from gmt.model import ComplexCrossTransformer
from gmt.trainer import train
from gmt.variants import define_variants

# Create dataset
ds = ThreadedRayDataset(num_samples=5000, max_workers=8)
feat_dim = ds.tgt_feats.shape[-1]

# Choose a variant
variant = 'medium'
model = ComplexCrossTransformer(tgt_dim=feat_dim, base_dim=3, variant=variant)

# Train
trained_model = train(ds, model, variant=variant, epochs=20, batch_size=64, lr=5e-5)
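
Checkpoints follow the GMT_<variant>.pth naming used by train.py; a sketch of saving and reloading (assuming the checkpoint is a plain state_dict):

import torch

# Save to the weights/ folder using the train.py naming convention
torch.save(trained_model.state_dict(), f"weights/GMT_{variant}.pth")

# Reload later for inference
model.load_state_dict(torch.load(f"weights/GMT_{variant}.pth", map_location="cpu"))
model.eval()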

GitHub

https://github.com/raildart/GMT


License

This project is licensed under the MIT License.
