Model Card for AlignDRAW (MNIST Reimplementation)

Model Details

Model Description

AlignDRAW was the first model to generate images from text, but the original implementation is no longer available online, so we reimplemented it in Python. This version is trained on handwritten digits (MNIST) paired with simple text prompts.

  • Developed by: Bertug Gunel
  • Funded by [optional]: None
  • Shared by [optional]: None
  • Model type: Attention + VAE + RNN
  • Language(s) (NLP): EN
  • License: cc-by-nc-sa-4.0
  • Finetuned from model [optional]: None (trained from scratch)

Model Sources [optional]

  • Repository: Coming soon!
  • Paper [optional]: Coming soon!
  • Demo [optional]: Coming soon!

Uses

Direct Use

You can download the model weights and the caption-embedding head and run generation locally with the script below; a hosted demo for direct use is coming soon.

CODE:

import torch
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from safetensors.torch import load_file

# —— Configurations ——
IMG_SIZE = 28
INPUT_DIM = IMG_SIZE * IMG_SIZE
LATENT_DIM = 100
TIMESTEPS = 10
CAPTION_EMBED_DIM = 50
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# —— Model Definitions ——
import torch.nn as nn
import torch.nn.functional as F

class CaptionEmbed(nn.Module):
    def __init__(self, num_classes=10, embed_dim=CAPTION_EMBED_DIM):
        super().__init__()
        self.embed = nn.Embedding(num_classes, embed_dim)
    def forward(self, labels):
        return self.embed(labels)

class DRAWTextModel(nn.Module):
    def __init__(self, input_dim, latent_dim, timesteps, caption_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim + caption_dim, 256)
        self.decoder = nn.LSTM(latent_dim + caption_dim, 256)
        self.fc_mu = nn.Linear(256, latent_dim)
        self.fc_logvar = nn.Linear(256, latent_dim)
        self.fc_dec = nn.Linear(256, input_dim)

    def forward(self, x_seq, cap_seq):
        # x_seq: (timesteps, batch, input_dim); cap_seq: (timesteps, batch, caption_dim).
        batch = x_seq.size(1)
        canvas = torch.zeros_like(x_seq)
        h_enc = (torch.zeros(1, batch, 256, device=x_seq.device),
                 torch.zeros(1, batch, 256, device=x_seq.device))
        h_dec = (torch.zeros(1, batch, 256, device=x_seq.device),
                 torch.zeros(1, batch, 256, device=x_seq.device))
        mus, logvars = [], []
        for t in range(x_seq.size(0)):
            # Encoder reads the error image (input minus current canvas) plus the caption.
            diff = x_seq[t] - torch.sigmoid(canvas[t])
            diff_cap = torch.cat([diff, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_enc = self.encoder(diff_cap, h_enc)
            enc_h = h_enc[0].squeeze(0)
            # Reparameterization trick: sample the latent from N(mu, sigma^2).
            mu = self.fc_mu(enc_h); logvar = self.fc_logvar(enc_h)
            std = torch.exp(0.5 * logvar)
            z = mu + std * torch.randn_like(std)
            mus.append(mu); logvars.append(logvar)
            # Decoder writes to the step-t canvas, conditioned on the latent and the caption.
            z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
            _, h_dec = self.decoder(z_cap, h_dec)
            dec_h = h_dec[0].squeeze(0)
            canvas[t] = canvas[t] + self.fc_dec(dec_h)
        # The latent statistics are returned so the KL term can be computed during training.
        return canvas, mus, logvars

# —— Load Pretrained Models ——
caption_model = CaptionEmbed().to(DEVICE)
model = DRAWTextModel(INPUT_DIM, LATENT_DIM, TIMESTEPS, CAPTION_EMBED_DIM).to(DEVICE)

caption_state = load_file("caption_embed.safetensors")  # path to the caption-embedding weights on your machine
model_state = load_file("draw_model.safetensors")        # path to the DRAW model weights on your machine
caption_model.load_state_dict(caption_state)
model.load_state_dict(model_state)

caption_model.eval()
model.eval()

# —— Prompt Mapping ——
prompt2digit = {
    "number zero": 0,
    "number one": 1,
    "number two": 2,
    "number three": 3,
    "number four": 4,
    "number five": 5,
    "number six": 6,
    "number seven": 7,
    "number eight": 8,
    "number nine": 9
}
# —— Interactive Generation Loop ——
while True:
    prompt = input("Enter a prompt (e.g. 'number two', or 'q' to quit): ").strip().lower()
    if prompt == 'q':
        print('Exiting.')
        break
    if prompt not in prompt2digit:
        print(f"Unknown prompt: {prompt}")
        continue

    digit = prompt2digit[prompt]
    labels = torch.tensor([digit], device=DEVICE)
    caption_vec = caption_model(labels)
    cap_seq = caption_vec.unsqueeze(0).repeat(TIMESTEPS, 1, 1)

    # Generation
    h_dec = (torch.zeros(1, 1, 256, device=DEVICE),
             torch.zeros(1, 1, 256, device=DEVICE))
    canvas = torch.zeros(TIMESTEPS, 1, INPUT_DIM, device=DEVICE)
    for t in range(TIMESTEPS):
        z = torch.randn(1, LATENT_DIM, device=DEVICE)
        z_cap = torch.cat([z, cap_seq[t]], dim=-1).unsqueeze(0)
        _, h_dec = model.decoder(z_cap, h_dec)
        dec_h = h_dec[0].squeeze(0)
        canvas[t] = canvas[t] + model.fc_dec(dec_h)

    # Only the final timestep's canvas is used for the output image.
    img = torch.sigmoid(canvas[-1]).view(1, 1, IMG_SIZE, IMG_SIZE)
    grid = make_grid(img.cpu(), normalize=True)
    plt.figure(figsize=(3,3)); plt.axis('off'); plt.imshow(grid.permute(1,2,0)); plt.show()

Downstream Use [optional]

No fine-tuning pipeline is provided and fine-tuning has not been tested, but in principle the released weights could be further trained on MNIST-style data; see the hedged sketch below.
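
As a rough illustration only (not the original training recipe): the sketch below further trains the released weights with a VAE-style loss, binary cross-entropy on the final canvas plus per-step KL terms. It assumes the classes, constants, and loaded models from the Direct Use script are in scope, that forward returns the latent statistics as in the listing above, and that a train_loader has been built as in the Preprocessing sketch further down; the optimizer, learning rate, and loss weighting are assumptions.

import torch
import torch.nn.functional as F

# Optimize both the DRAW model and the caption-embedding head.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(caption_model.parameters()), lr=1e-3  # lr is an assumption
)

model.train()
caption_model.train()
for images, labels in train_loader:                                  # one epoch over MNIST
    images = images.view(images.size(0), -1).to(DEVICE)              # (B, 784), values in [0, 1]
    x_seq = images.unsqueeze(0).repeat(TIMESTEPS, 1, 1)              # (T, B, 784)
    cap_seq = caption_model(labels.to(DEVICE)).unsqueeze(0).repeat(TIMESTEPS, 1, 1)  # (T, B, 50)

    canvas, mus, logvars = model(x_seq, cap_seq)

    # Reconstruction term on the final canvas plus the per-step KL terms.
    recon = F.binary_cross_entropy_with_logits(canvas[-1], images, reduction="sum")
    kl = sum(-0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp()) for mu, lv in zip(mus, logvars))
    loss = (recon + kl) / images.size(0)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()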

Out-of-Scope Use

This model was trained only on the MNIST dataset (handwritten digits), so it can only generate images of digits. Full text-to-image generation is coming soon.

Embedding IDs: 0 = "0", 1 = "1", 2 = "2", 3 = "3", 4 = "4", 5 = "5", 6 = "6", 7 = "7", 8 = "8", 9 = "9"

No tokenization is needed: the embedding ID is simply the digit class.
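
As a small illustration (assuming the objects from the Direct Use script are in scope), this is how a prompt maps to its class ID and to the 50-dimensional caption embedding that conditions the decoder:

# Map a prompt string to its class ID, then to the 50-dim caption embedding.
# Assumes prompt2digit, caption_model, and DEVICE from the script above.
labels = torch.tensor([prompt2digit["number seven"]], device=DEVICE)  # tensor([7])
caption_vec = caption_model(labels)                                   # shape (1, 50)
print(labels.item(), caption_vec.shape)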

Bias, Risks, and Limitations

The model only produces 28×28 grayscale MNIST-style digit images, so the risk surface is minimal; it cannot generate photorealistic or otherwise sensitive content.

How to Get Started with the Model

Download the weight files referenced in the script (draw_model.safetensors and caption_embed.safetensors) and run the inference script in the Direct Use section above. A minimal non-interactive helper is sketched below.
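
The following is an illustrative helper only; generate_digit is not part of the released code, and it assumes the models and constants from the Direct Use script are already loaded:

# Illustrative helper (not part of the released code): generate one digit image.
@torch.no_grad()
def generate_digit(digit: int) -> torch.Tensor:
    labels = torch.tensor([digit], device=DEVICE)
    cap = caption_model(labels)                                  # (1, 50)
    h_dec = (torch.zeros(1, 1, 256, device=DEVICE),
             torch.zeros(1, 1, 256, device=DEVICE))
    canvas = torch.zeros(1, INPUT_DIM, device=DEVICE)
    for _ in range(TIMESTEPS):
        z = torch.randn(1, LATENT_DIM, device=DEVICE)
        z_cap = torch.cat([z, cap], dim=-1).unsqueeze(0)         # (1, 1, 150)
        _, h_dec = model.decoder(z_cap, h_dec)
        canvas = model.fc_dec(h_dec[0].squeeze(0))               # keep only the final write, as in the script above
    return torch.sigmoid(canvas).view(IMG_SIZE, IMG_SIZE)        # 28x28 image in [0, 1]

img = generate_digit(3).cpu()  # a handwritten-style "3"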

Training Details

Training Data

The model was trained for 1 epoch on the MNIST dataset of 28×28 grayscale handwritten digits (60,000 training images, labels 0–9).

Preprocessing [optional]
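
The exact preprocessing script is not published. The sketch below is a plausible reconstruction inferred from the model code: images are flattened to 784-dimensional vectors and, like the caption embeddings, repeated across the TIMESTEPS axis to form the sequences that forward() expects. The batch size is an assumption.

# Plausible preprocessing sketch (inferred from the model code, not the official script).
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

train_set = datasets.MNIST(root="./data", train=True, download=True,
                           transform=transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=64, shuffle=True)   # batch size is an assumption

images, labels = next(iter(train_loader))
images = images.view(images.size(0), -1).to(DEVICE)                 # (64, 784), values in [0, 1]
x_seq = images.unsqueeze(0).repeat(TIMESTEPS, 1, 1)                 # (10, 64, 784)
cap_seq = caption_model(labels.to(DEVICE)).unsqueeze(0).repeat(TIMESTEPS, 1, 1)  # (10, 64, 50)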

Evaluation

No separate evaluation or test runs were performed.

Testing Data, Factors & Metrics

Testing Data

None; no held-out test set was used.

Metrics

  • Accuracy: training accuracy only (no held-out metrics were computed).

Results

Examples:

  • Input ID 0 (prompt "number zero"): sample generated image of a handwritten "0"
  • Input ID 3 (prompt "number three"): sample generated image of a handwritten "3"

Summary

The model generates recognizable handwritten digits (0–9) from the ten supported prompts. A full version is coming soon!

Environmental Impact

Carbon emissions can be estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019).

  • Hardware Type: NVIDIA A100 (40 GB)
  • Hours used: <0.1
  • Cloud Provider: Google Colab
  • Compute Region: -
  • Carbon Emitted: -

Technical Specifications [optional]

Model Architecture and Objective

RNN (LSTM encoder and decoder) + VAE + attention, in the spirit of DRAW/AlignDRAW: over 10 timesteps the encoder LSTM reads the error between the input and the current canvas together with the caption embedding, a 100-dimensional latent is sampled with the reparameterization trick, and the decoder LSTM writes to the canvas through a linear layer.
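
The exact training loss is not documented in this card; assuming the standard DRAW-style variational objective (which matches the reparameterized latents in the code), it would take the form

\mathcal{L} \;=\; \underbrace{-\log p\!\left(x \mid c_T\right)}_{\text{reconstruction: binary cross-entropy on the final canvas}}
\;+\; \underbrace{\sum_{t=1}^{T} D_{\mathrm{KL}}\!\left( q\!\left(z_t \mid x\right) \,\middle\|\, \mathcal{N}(0, I) \right)}_{\text{latent regularization}}

with T = 10 timesteps, as set by TIMESTEPS in the script above.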

Model Card Authors [optional]

Bertug Gunel

Model Card Contact

[email protected]
