---
license: mit
---
# Model Card for Omni-DNA
## Requirement

```bash
pip install datasets ai2-olmo
```
## Overview
Omni-DNA is a cross-modal, multi-task genomic foundation model designed to generalize across diverse genomic tasks. Unlike previous Genomic Foundation Models (GFMs), which require separate fine-tuning for each task, Omni-DNA leverages auto-regressive transformer-based training and multi-task fine-tuning, enabling a single model to perform a wide range of genomic tasks with state-of-the-art performance.
Omni-DNA models range from 20M to 1B parameters and support tasks such as sequence annotation, regulatory element classification, acetylation/methylation prediction, and DNA2Function/DNA2Image mapping.
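Once `ai2-olmo` is installed, the checkpoints load through the standard `transformers` auto classes. The sketch below is a minimal example of scoring a raw DNA sequence with a base model; the repo id `zehui127/Omni-DNA-116M` is an assumption inferred from the naming of the fine-tuned checkpoint used later in this card, so substitute the actual base-model id if it differs.

```python
# Minimal sketch: load a base Omni-DNA checkpoint and compute next-token
# logits for a raw DNA sequence. The repo id below is an assumption; replace
# it with the published id for the model size you want.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "zehui127/Omni-DNA-116M"  # assumed naming, see note above
tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id)
model.eval()

inputs = tokenizer("ATGCGTACGTTAGC", return_tensors="pt", return_token_type_ids=False)
with torch.no_grad():
    logits = model(**inputs).logits  # shape: (1, sequence_length, vocab_size)
print(logits.shape)
```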
## Base Model Details
| Size | Training Tokens | Layers | Hidden Size | Attention Heads | Context Length |
|---|---|---|---|---|---|
| Omni-DNA 20M | 300B | 8 | 256 | 8 | 250 |
| Omni-DNA 60M | 300B | 8 | 512 | 8 | 250 |
| Omni-DNA 116M | 300B | 12 | 768 | 16 | 250 |
| Omni-DNA 300M | 300B | 16 | 1024 | 16 | 250 |
| Omni-DNA 700M | 300B | 16 | 1536 | 16 | 250 |
| Omni-DNA 1B | 300B | 16 | 2048 | 16 | 250 |
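All sizes share the same context length of 250, so a prompt plus any newly generated tokens must fit inside that window. Below is a minimal, illustrative length check (it assumes the limit is counted in tokenizer tokens; the `max_new_tokens` budget mirrors the generation call in the Usage section below, and the helper itself is not part of the released code).

```python
# Illustrative helper: check that a DNA prompt plus the generation budget
# fits inside the shared context window listed above.
MAX_CONTEXT = 250  # context length from the table, assumed to be in tokens

def fits_context(sequence, tokenizer, max_new_tokens=110):
    n_prompt_tokens = len(tokenizer(sequence, add_special_tokens=True)["input_ids"])
    # Assumes the prompt and the newly generated tokens share one window.
    return n_prompt_tokens + max_new_tokens <= MAX_CONTEXT
```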
## Model Description
- Supported by: Microsoft Research Asia
- Model type: Auto-regressive transformer-based genomic model
- License: MIT
- Date cutoff: 2024
- Contact: Research inquiries at [email protected]
## Model Sources
- Paper: Omni-DNA: Scaling Auto-Regressive Transformer to Multi-Tasking Genomic Foundation Model
- Codebase: https://github.com/Zehui127/Omni-DNA
- Dataset: Pretrained on 300B nucleotides from multi-species genome datasets
## Capabilities
Omni-DNA is trained to perform multiple genomic tasks including:
- Regulatory Element Classification: Enhancer/promoter/splice site detection
- Histone Modification Prediction: Acetylation and methylation state identification
- Genomic Function Annotation: DNA-to-text mapping (DNA2Function)
- Cross-modal Learning: DNA-to-image mapping (DNA2Image)
- Multi-task Learning: A single model can solve multiple tasks simultaneously (see the formatting sketch below)
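The multi-task behaviour comes from casting each supervised task as text generation: the input sequence is followed by a delimiter, and the model emits the answer after it. The sketch below illustrates that framing using the `[MASK]` delimiter from the Usage section; the task labels shown are placeholders, not the exact strings used in the released fine-tuning data.

```python
# Illustrative sketch of casting a supervised genomic task as generation:
# prompt = DNA sequence + delimiter, answer = text emitted after it.
# The [MASK] delimiter mirrors the Usage section below; the label string
# is a placeholder, not the exact vocabulary used in the released data.
def format_example(dna, label, mask_token="[MASK]"):
    return f"{dna}{mask_token}{label}"

print(format_example("ATGCGTACGT", "promoter"))
# -> ATGCGTACGT[MASK]promoter
```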
## Usage
```python
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def preprocess_response(response, mask_token="[MASK]"):
    """
    Extract the generated text that follows the [MASK] token.

    Args:
        response (str): The raw model output.
        mask_token (str): The delimiter after which the answer is extracted.

    Returns:
        str: Processed response text.
    """
    if mask_token in response:
        response = response.split(mask_token, 1)[1]
    # Strip any leading whitespace or echoed nucleotides before the answer.
    response = re.sub(r'^[\sATGC]+', '', response)
    return response


def generate(message, model, tokenizer):
    # Append the [MASK] delimiter so the model generates the answer after it.
    message = message + "[MASK]"
    tokenized_message = tokenizer(
        [message], return_tensors='pt', return_token_type_ids=False, add_special_tokens=True
    ).to('cuda')
    response = model.generate(**tokenized_message, max_new_tokens=110, do_sample=False)
    reply = tokenizer.batch_decode(response, skip_special_tokens=True)[0]
    return preprocess_response(reply)


model_tokenizer_path = "zehui127/Omni-DNA-DNA2Function"
tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path).to('cuda')

# Define the input DNA sequence
dna = "TGCTGGCTTCAGGGGCACAGATGCTAACATTGGAGCGATACAGAGAAGATTAACGTGGCCACTGCGCAAGCATGACATGCAAACTCGTAAAGCATTCTTTTAATTT"
print(generate(dna, model, tokenizer))
```
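The call returns the text the model generates after the `[MASK]` delimiter, which for this checkpoint is a natural-language description of the input sequence's function. On a machine without a GPU, replacing the `'cuda'` device strings with `'cpu'` should work as well, only more slowly.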