Genomic Language Models for Metagenomic Sequence Analysis
We provide genomic language models fine-tuned for the following tasks:
- Hierarchical taxonomic classification
- Antimicrobial resistance (AMR) gene identification
- Pathogenicity detection
See the code repository for details on fine-tuning, evaluation, and implementation.
These are the official models described in "Evaluating the Effectiveness of Parameter-Efficient Fine-Tuning in Genomic Classification Tasks".
Pretrained Foundation Models
Our models are built upon several pretrained genomic foundation models:
Nucleotide Transformer (NT)
- InstaDeepAI/nucleotide-transformer-v2-50m-multi-species
- InstaDeepAI/nucleotide-transformer-v2-100m-multi-species
- InstaDeepAI/nucleotide-transformer-v2-250m-multi-species
DNABERT
- zhihan1996/DNABERT-2-117M
- zhihan1996/DNABERT-S
HyenaDNA
- LongSafari/hyenadna-large-1m-seqlen-hf
- LongSafari/hyenadna-medium-450k-seqlen-hf
- LongSafari/hyenadna-medium-160k-seqlen-hf
- LongSafari/hyenadna-small-32k-seqlen-hf
We sincerely thank the teams behind NT, DNABERT, and HyenaDNA for making their tokenizers and pre-trained models available for use :)
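For orientation, here is a minimal sketch of loading one of these base checkpoints directly from the Hugging Face Hub to inspect its embeddings. The checkpoint name is only an example, the call pattern follows the public Nucleotide Transformer model card, and the exact loading code differs slightly across the NT, DNABERT, and HyenaDNA repositories (see each model card); it is independent of our fine-tuning code.

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Any base checkpoint listed above; this one is just an example
base_id = "InstaDeepAI/nucleotide-transformer-v2-50m-multi-species"

# trust_remote_code is needed because these repositories ship custom model code
tokenizer = AutoTokenizer.from_pretrained(base_id, trust_remote_code=True)
base_model = AutoModelForMaskedLM.from_pretrained(base_id, trust_remote_code=True)

# Tokenize a short DNA sequence and pull out the last-layer embeddings
tokens = tokenizer("ATTCCGATTCCGATTCCG", return_tensors="pt")
with torch.no_grad():
    outputs = base_model(tokens["input_ids"], output_hidden_states=True)
embeddings = outputs["hidden_states"][-1]  # [1, sequence_length, hidden_dim]
print(embeddings.shape)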
Available Fine-Tuned Models
We provide the following fine-tuned models:
- taxonomy/DNABERT-2-117M-taxonomy
- taxonomy/hyenadna-large-1m-seqlen-hf-taxonomy
- taxonomy/nucleotide-transformer-v2-50m-multi-species-taxonomy
- amr/binary/hyenadna-small-32k-seqlen-hf
- amr/binary/nucleotide-transformer-v2-100m-multi-species
- amr/multiclass/DNABERT-S
- amr/multiclass/hyenadna-medium-450k-seqlen-hf
- amr/multiclass/nucleotide-transformer-v2-250m-multi-species
- pathogenicity/hyenadna-small-32k-seqlen-hf-DeePaC-fungal
- pathogenicity/hyenadna-small-32k-seqlen-hf-DeePaC-viral
- pathogenicity/hyenadna-small-32k-seqlen-hf-DeepSim-bacterial
- pathogenicity/hyenadna-small-32k-seqlen-hf-DeepSim-viral
- pathogenicity/nucleotide-transformer-v2-50m-multi-species-DeePaC-fungal
- pathogenicity/nucleotide-transformer-v2-50m-multi-species-DeePaC-viral
- pathogenicity/nucleotide-transformer-v2-50m-multi-species-DeepSim-bacterial
- pathogenicity/nucleotide-transformer-v2-50m-multi-species-DeepSim-viral
To use these models, download the corresponding model directories from this repository and follow the installation instructions in our code repository. Two setup modes are supported: building from source and using our pre-built Docker image. The walkthrough below assumes the source setup and that the model directories are available locally.
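If the model directories are hosted on the Hugging Face Hub, one way to fetch a single directory is huggingface_hub.snapshot_download. This is only a sketch: the repo_id below is a placeholder for wherever this repository lives, and the directory pattern is one of the fine-tuned model paths listed above.

from huggingface_hub import snapshot_download

# repo_id is a placeholder; replace it with the Hub repository hosting these model directories
snapshot_download(
    repo_id="your-org/your-repo-id",
    allow_patterns="taxonomy/hyenadna-large-1m-seqlen-hf-taxonomy/*",  # fetch only one model directory
    local_dir="data",
)

With the environment set up and a model directory downloaded, here is sample code to run inference: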
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from safetensors.torch import load_file

from analysis.experiment.utils.data_processor import DataProcessor
from analysis.experiment.models.hierarchical_model import (
    HierarchicalClassificationModel,
)

# Replace with the base directory containing the data processor, base model tokenizer,
# and trained model weight files
model_dir = Path("data/LongSafari__hyenadna-large-1m-seqlen-hf")
data_processor_dir = model_dir / "data_processor"  # directory containing your data processor
metadata_path = data_processor_dir / "metadata.json"
base_model_dir = model_dir / "base_model"  # directory containing your base model files
trained_model_dir = model_dir / "model"  # directory containing your trained model files
trained_model_path = trained_model_dir / "model.safetensors"

# Load metadata
with open(metadata_path, "r") as f:
    metadata = json.load(f)
sequence_column = metadata["sequence_column"]
labels = metadata["labels"]
data_processor_filename = "data_processor.pkl"

# Load the data processor
data_processor = DataProcessor(
    sequence_column=sequence_column,
    labels=labels,
    save_file=data_processor_filename,
)
data_processor.load_processor(data_processor_dir)

# Get metadata-driven values
num_labels = data_processor.num_labels
class_weights = data_processor.class_weights

# Load the tokenizer from a local path (or the Hugging Face Hub)
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=base_model_dir.as_posix(),
    trust_remote_code=True,
    local_files_only=True,
)

# Load the fine-tuned model weights
model = HierarchicalClassificationModel(base_model_dir.as_posix(), num_labels, class_weights)
state_dict = load_file(trained_model_path)
model.load_state_dict(state_dict, strict=False)
model.eval()

# Run inference on a single sequence
sequence = "ATCG"
tokenized_input = tokenizer(
    sequence,
    return_tensors="pt",  # return results as PyTorch tensors
)
with torch.no_grad():
    outputs = model(**tokenized_input)

# Decode the top prediction for each label
for idx, col in enumerate(labels):
    logits = outputs["logits"][idx]  # [num_classes]
    probs = F.softmax(logits, dim=-1).cpu()
    topk = torch.topk(probs, k=1, dim=-1)
    topk_index = topk.indices.numpy().ravel()
    topk_prob = topk.values.numpy().ravel()
    topk_label = data_processor.encoders[col].inverse_transform(topk_index)
    print(f"{col}: {topk_label[0]} (probability {topk_prob[0]:.3f})")
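The loop above decodes a single sequence. As an illustrative extension (not part of the official pipeline), the sketch below reuses the objects loaded above to score a small batch of sequences; it assumes the tokenizer defines a padding token and that the model's forward pass accepts batched inputs.

# Batch-inference sketch, reusing tokenizer, model, labels, and data_processor from above
sequences = ["ATCGATCGATCG", "GGGCCCATATAT"]  # example sequences

batch = tokenizer(
    sequences,
    return_tensors="pt",
    padding=True,  # pad sequences in the batch to a common length
)
with torch.no_grad():
    batch_outputs = model(**batch)

for idx, col in enumerate(labels):
    logits = batch_outputs["logits"][idx]  # [batch_size, num_classes]
    probs = F.softmax(logits, dim=-1).cpu()
    top = torch.topk(probs, k=1, dim=-1)
    decoded = data_processor.encoders[col].inverse_transform(top.indices.numpy().ravel())
    for i, (label, prob) in enumerate(zip(decoded, top.values.numpy().ravel())):
        print(f"sequence {i}, {col}: {label} (probability {prob:.3f})")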
Authors & Contact
- Daniel Berman - [email protected]
- Daniel Jimenez - [email protected]
- Stanley Ta - [email protected]
- Brian Merritt - [email protected]
- Jeremy Ratcliff - [email protected]
- Vijay Narayan - [email protected]
- Molly Gallagher - [email protected]
Acknowledgement
This work was supported by funding from the U.S. Centers for Disease Control and Prevention through the Office of Readiness and Response under Contract # 75D30124C20202.