Model Card for Omni-DNA

Requirements

pip install datasets ai2-olmo

Overview

Omni-DNA is a cross-modal, multi-task genomic foundation model designed to generalize across diverse genomic tasks. Unlike previous Genomic Foundation Models (GFMs), which require separate fine-tuning for each task, Omni-DNA leverages auto-regressive transformer-based training and multi-task fine-tuning, enabling a single model to perform a wide range of genomic tasks with state-of-the-art performance.

Omni-DNA models range from 20M to 1B parameters and support tasks such as sequence annotation, regulatory element classification, acetylation/methylation prediction, and DNA2Function/DNA2Image mapping.

Multi-Task Performance

| Histone Mark | Accuracy | F1 Score | Matthews Correlation | Precision | Recall |
|--------------|----------|----------|----------------------|-----------|--------|
| H3 | 0.8778 | 0.8775 | 0.7558 | 0.8787 | 0.8771 |
| H4 | 0.8919 | 0.8897 | 0.7808 | 0.8934 | 0.8873 |
| H3K9ac | 0.8543 | 0.8531 | 0.7069 | 0.8522 | 0.8548 |
| H3K14ac | 0.9540 | 0.9532 | 0.9065 | 0.9526 | 0.9538 |
| H4ac | 0.9396 | 0.9392 | 0.8785 | 0.9390 | 0.9395 |
| H3K4me1 | 0.7926 | 0.7912 | 0.5825 | 0.7909 | 0.7916 |
| H3K4me2 | 0.8664 | 0.8623 | 0.7248 | 0.8611 | 0.8638 |
| H3K4me3 | 0.9046 | 0.9041 | 0.8084 | 0.9049 | 0.9034 |
| H3K36me3 | 0.8581 | 0.8571 | 0.7142 | 0.8571 | 0.8570 |
| H3K79me3 | 0.8651 | 0.8641 | 0.7291 | 0.8660 | 0.8631 |

Model Description

  • Supported by: Microsoft Research Asia
  • Model type: Auto-regressive transformer-based genomic model
  • License: MIT
  • Date cutoff: 2024
  • Contact: Research inquiries at [email protected]

Model Sources

Capabilities

Omni-DNA is trained to perform multiple genomic tasks including:

  • Regulatory Element Classification: Enhancer/promoter/splice site detection
  • Histone Modification Prediction: Acetylation and methylation state identification
  • Genomic Function Annotation: DNA-to-text mapping (DNA2Function)
  • Cross-modal Learning: DNA-to-image mapping (DNA2Image)
  • Multi-task Learning: A single model can solve multiple tasks simultaneously

Usage

Inference on 10 tasks at the same time

import argparse
import os
import torch
import re
import numpy as np
import sklearn
from tqdm import tqdm
from datasets import load_dataset, DatasetDict
from transformers import AutoModelForCausalLM, AutoTokenizer

def parse_args():
    """Parse the command-line options of the inference script.

    Returns:
        The ``argparse.Namespace`` produced from the process arguments. Its
        only option, ``model_tokenizer_path``, defaults to
        ``zehui127/Omni-DNA-Multitask``.
    """
    cli = argparse.ArgumentParser(
        description="Run DNA task inference with a specified model and tokenizer."
    )
    # Keep the option settings in one mapping so the default and its help
    # text sit next to each other.
    path_option = {
        "type": str,
        "default": "zehui127/Omni-DNA-Multitask",
        "help": "Path to the pretrained model and tokenizer. Default: zehui127/Omni-DNA-Multitask",
    }
    cli.add_argument("--model_tokenizer_path", **path_option)
    namespace = cli.parse_args()
    return namespace

def load_model_and_tokenizer(model_tokenizer_path):
    """Load the causal language model and its tokenizer from one location.

    Args:
        model_tokenizer_path: Identifier or path passed unchanged to both
            ``AutoTokenizer.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model is moved to ``'cuda'`` when
        CUDA is available and to ``'cpu'`` otherwise.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_tokenizer_path)
    # An unconditional .to('cuda') raises on machines without a usable GPU,
    # so pick the device from what torch reports as available.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = AutoModelForCausalLM.from_pretrained(model_tokenizer_path).to(device)
    return model, tokenizer

def generate(message, task_type, model, tokenizer, sample_num=1):
    """Greedily decode a continuation of ``message`` and extract its label.

    Args:
        message: Prompt text; it is tokenized as a batch of one.
        task_type: Forwarded unchanged to ``extract_label``.
        model: Model providing ``generate`` and a ``device`` attribute.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        sample_num: Number of new tokens to generate (``max_new_tokens``).

    Returns:
        The result of ``extract_label`` applied to the decoded sequence with
        all space characters removed.
    """
    encoded = tokenizer(
        [message],
        return_tensors='pt',
        return_token_type_ids=False,
        add_special_tokens=True,
    )
    # Send the inputs to wherever the model actually lives rather than to a
    # hard-coded 'cuda', so the call works for any device placement.
    encoded = encoded.to(model.device)
    output_ids = model.generate(**encoded, max_new_tokens=sample_num, do_sample=False)
    # Special tokens are kept because the '[MASK]' marker is needed downstream.
    decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=False)[0]
    reply = decoded.replace(" ", "")
    return extract_label(reply, task_type)

def extract_label(message, task_type):
    """Return the first run of digits that follows the ``[MASK]`` marker.

    Args:
        message: Decoded model output containing the prompt, the ``[MASK]``
            marker and the generated continuation.
        task_type: Accepted for interface compatibility; the marker is always
            ``[MASK]`` and does not depend on the task.

    Returns:
        The first digit sequence found between the first ``[MASK]`` and the
        next one (or the end of the text) as a string, or ``None`` when the
        marker is missing or no digits follow it.
    """
    marker = '[MASK]'
    # maxsplit=2 is enough: only the segment right after the first marker is
    # inspected, exactly as with an unbounded split.
    segments = message.split(marker, 2)
    if len(segments) < 2:
        # No marker in the text: report "no label" instead of raising
        # IndexError, so callers can treat it like any other missing result.
        return None
    match = re.search(r'\d+', segments[1])
    return match.group() if match else None

def load_and_format_dataset():
    """Load the test split of the multitask dataset and turn it into prompts.

    Returns:
        The mapped dataset whose original columns are replaced by
        ``formatted_text`` (instruction followed by ``[MASK]``), ``label``
        (last character of each ``output`` value) and ``task_type`` (the
        ``task`` column).
    """

    def formatting_prompts_func(example):
        # `example` is a batch: every column holds a list of values.
        prompts = []
        for instr in example['instruction']:
            prompts.append(f"{instr}[MASK]")
        last_chars = [output[-1] for output in example['output']]
        return {
            'formatted_text': prompts,
            'label': last_chars,
            'task_type': example['task'],
        }

    test_split = load_dataset("zehui127/Omni-DNA-dataset-nt-downstream-multitask")['test']
    return test_split.map(
        formatting_prompts_func,
        batched=True,
        remove_columns=test_split.column_names,
        desc="Formatting dataset",
    )

def group_by_task_type(dataset):
    """Split ``dataset`` into one sub-dataset per distinct ``task_type`` value.

    Args:
        dataset: Dataset with a ``task_type`` column and a ``filter`` method.

    Returns:
        A ``DatasetDict`` keyed by task type; only non-empty subsets are
        stored, and each stored subset's size is printed.
    """
    grouped = DatasetDict()

    for task_type in set(dataset['task_type']):

        def belongs_to_task(row):
            return row['task_type'] == task_type

        subset = dataset.filter(
            belongs_to_task,
            num_proc=1,
            desc=f"Filtering {task_type} examples",
        )
        if len(subset) == 0:
            continue
        grouped[task_type] = subset
        print(f"\nTask type '{task_type}': {len(subset)} examples")

    return grouped

def calculate_metrics(predictions, labels):
    """Compute classification metrics over the entries with a usable label.

    Args:
        predictions: Array of predicted class ids, indexable by a boolean mask.
        labels: Array of reference class ids; positions equal to ``-100`` are
            excluded (presumably an ignore marker -- confirm with callers).

    Returns:
        A dict with ``accuracy``, ``f1``, ``matthews_correlation``,
        ``precision`` and ``recall``. F1, precision and recall are
        macro-averaged with ``zero_division=0``.
    """
    # A bare `import sklearn` is not guaranteed to load the `metrics`
    # submodule, so bind it explicitly here instead of relying on some other
    # package having imported it first.
    from sklearn import metrics

    keep = labels != -100
    y_pred = predictions[keep]
    y_true = labels[keep]
    macro = {"average": "macro", "zero_division": 0}
    return {
        "accuracy": metrics.accuracy_score(y_true, y_pred),
        "f1": metrics.f1_score(y_true, y_pred, **macro),
        "matthews_correlation": metrics.matthews_corrcoef(y_true, y_pred),
        "precision": metrics.precision_score(y_true, y_pred, **macro),
        "recall": metrics.recall_score(y_true, y_pred, **macro),
    }

def inference(dataset, model, tokenizer):
    """Predict a label for every example and score the predictions.

    Each example is first decoded with the default budget of ``generate``.
    If that yields no label, decoding is retried with budgets 2 through 19;
    when every attempt fails the prediction falls back to ``'0'`` and a
    warning is printed.

    Args:
        dataset: Iterable of examples with ``formatted_text``, ``task_type``
            and ``label`` entries.
        model: Model forwarded to ``generate``.
        tokenizer: Tokenizer forwarded to ``generate``.

    Returns:
        The metrics dict produced by ``calculate_metrics``.
    """
    predictions, labels = [], []

    for element in tqdm(dataset):
        prompt = element['formatted_text']
        task_type = element['task_type']

        prediction = generate(prompt, task_type, model, tokenizer)
        if prediction is None:
            for sample_num in range(2, 20):
                prediction = generate(prompt, task_type, model, tokenizer, sample_num)
                if prediction is not None:
                    # Keep a label found on any attempt, including the last
                    # one; the fallback below applies only when all fail.
                    break
            else:
                prediction = '0'
                print("Warning: No valid result")

        # Only the first digit of the extracted number is used as the class.
        predictions.append(int(str(prediction)[0]))
        labels.append(int(element['label']))

    return calculate_metrics(np.array(predictions), np.array(labels))

def main():
    """Run inference for each of the ten histone tasks and print the metrics.

    Loads the model and tokenizer named on the command line, formats the
    test data, groups it by task type, then prints either the metrics or a
    "no data" notice for every task in the fixed task list.
    """
    cli_args = parse_args()
    model, tokenizer = load_model_and_tokenizer(cli_args.model_tokenizer_path)
    per_task = group_by_task_type(load_and_format_dataset())

    tasks = (
        'H3', 'H4', 'H3K9ac', 'H3K14ac', 'H4ac',
        'H3K4me1', 'H3K4me2', 'H3K4me3', 'H3K36me3', 'H3K79me3',
    )

    for task in tasks:
        print(f"==========={task}=========")
        dataset_test = per_task.get(task, None)
        if not dataset_test:
            # Covers both a missing task and an empty subset.
            print(f"No data for task {task}")
            continue
        print(inference(dataset_test, model, tokenizer))


if __name__ == "__main__":
    main()
Downloads last month
31
Safetensors
Model size
1.09B params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Space using zehui127/Omni-DNA-Multitask 1

Collection including zehui127/Omni-DNA-Multitask