MViT-Stroke-CycleGAN - Pretrained Weights
Pretrained weights for MViT-Stroke-CycleGAN, a CycleGAN variant using Mobile Vision Transformer (MViT) for hand-drawn stroke synthesis.
🔗 Code: [GitHub Repository](https://github.com/HJ-Peng/MViT-Stroke-GAN)
📦 Checkpoint
File | Description |
---|---|
MViT-Stroke-GAN.pth | Trained model weights. |
⚠️ Inference: Use train() Mode Only
Do not use model.eval(): it causes severe artifacts due to MViT's design.
✅ Correct:
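import torch

# Load the model and weights (point the path at your downloaded checkpoint)
model = MViTCycleGANModel(3, 3)  # or your model class
checkpoint = torch.load("Stroke_MViTGAN.pth", map_location="cpu")
model.netG_A.load_state_dict(checkpoint["netG_A"])
model.train()  # Keep training mode; do not call model.eval()

def disable_dropout(m):
    # Switch only the dropout layers to eval mode so they stop dropping activations
    if isinstance(m, (torch.nn.Dropout, torch.nn.Dropout2d)):
        m.eval()

model.apply(disable_dropout)

# Run the generator without tracking gradients
with torch.no_grad():
    output = model.netG_A(input_tensor)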
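To check that the pattern above does what is intended, the following minimal sketch applies the same `disable_dropout` idea to a small stand-in network and inspects each layer's `training` flag. The `toy` network is hypothetical and is not the MViT generator; it only illustrates that `train()` followed by a selective `eval()` leaves dropout inactive while every other layer stays in training mode.

```python
import torch
from torch import nn


def disable_dropout(module: nn.Module) -> None:
    """Put dropout layers in eval mode and leave every other layer untouched."""
    if isinstance(module, (nn.Dropout, nn.Dropout2d)):
        module.eval()


# Hypothetical toy network standing in for the real generator (assumption).
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Dropout2d(p=0.5),
    nn.Conv2d(8, 3, kernel_size=3, padding=1),
)

toy.train()                 # sets training=True on every submodule
toy.apply(disable_dropout)  # sets training=False on dropout layers only

# Record the mode of each direct child layer by class name.
modes = {type(m).__name__: m.training for m in toy.children()}

# With dropout inactive, two passes over the same input should agree.
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    first = toy(x)
    second = toy(x)
same_output = torch.allclose(first, second)
```

Printing `modes` before running the real generator is a quick way to confirm that only the dropout layers were switched off.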
📦 Installation
pip install torch torchvision timm
📄 Citation
If you use this model or weights in your work, please cite:
@misc{mvit_stroke_cyclegan_2025,
  author = {},
  title = {},
  year = {2025},
  publisher = {Hugging Face},
  howpublished = {\url{https://github.com/HJ-Peng/MViT-Stroke-GAN}}
}