BERT for Medical Named Entity Recognition (Disease Extraction)

Model Description

This model is a fine-tuned version of emilyalsentzer/Bio_ClinicalBERT for Named Entity Recognition (NER), designed specifically to extract disease names from medical text. It uses the BIO tagging scheme to mark where each disease mention begins and ends in clinical narratives.

Model Details

  • Base Model: emilyalsentzer/Bio_ClinicalBERT
  • Task: Token Classification (Named Entity Recognition)
  • Domain: Medical/Healthcare
  • Target Entities: Diseases
  • Tagging Scheme: BIO (Beginning-Inside-Outside)
  • Labels:
    • O: Outside (not a disease entity)
    • B-DISEASE: Beginning of a disease entity
    • I-DISEASE: Inside/continuation of a disease entity
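
To make the three labels concrete, here is a minimal sketch of how a word-level BIO sequence maps back to entity spans. The sentence and its tags are hand-written for illustration and are not model output; `bio_to_spans` is a hypothetical helper name, not part of this model's code.

```python
# Hand-written example: one tag per word (not produced by the model).
tokens = ["Patient", "has", "type", "2", "diabetes", "and", "hypertension", "."]
tags = ["O", "O", "B-DISEASE", "I-DISEASE", "I-DISEASE", "O", "B-DISEASE", "O"]


def bio_to_spans(tags):
    """Return (start, end) word-index pairs, end exclusive, one per entity."""
    spans = []
    start = None  # index where the currently open entity began, if any
    for i, tag in enumerate(tags):
        if tag == "B-DISEASE":
            if start is not None:
                spans.append((start, i))  # a new entity closes the previous one
            start = i
        elif tag == "I-DISEASE" and start is not None:
            continue  # the open entity keeps growing
        else:
            # "O", or an I-DISEASE with no open entity: close any open span
            if start is not None:
                spans.append((start, i))
            start = None
    if start is not None:
        spans.append((start, len(tags)))
    return spans


spans = bio_to_spans(tags)
entities = [" ".join(tokens[s:e]) for s, e in spans]
print(spans)
print(entities)
```

An I-DISEASE tag that does not follow a B-DISEASE or another I-DISEASE is ignored here, which matches how the extraction code in the Usage section treats it.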

Training Details

  • Training Epochs: 50
  • Batch Size: 16
  • Learning Rate: 2e-5
  • Optimizer: AdamW
  • Scheduler: Linear schedule with warmup
  • Max Sequence Length: 128 tokens
  • Train/Validation Split: 80/20
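
As a rough illustration of what "linear schedule with warmup" means for the learning rate, the sketch below computes the multiplier applied to the base rate at a given step. It follows the same shape as `get_linear_schedule_with_warmup` in transformers, but it is written in plain Python for clarity. The dataset size and warmup ratio are assumptions; the card states neither.

```python
def linear_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier for the base learning rate: ramps 0 -> 1, then decays 1 -> 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


base_lr = 2e-5   # from the card
epochs = 50      # from the card
batch_size = 16  # from the card
num_train_examples = 800  # ASSUMPTION: dataset size is not stated in the card
warmup_ratio = 0.1        # ASSUMPTION: warmup length is not stated in the card

steps_per_epoch = -(-num_train_examples // batch_size)  # ceiling division
total_steps = steps_per_epoch * epochs
warmup_steps = int(warmup_ratio * total_steps)

for step in (0, warmup_steps // 2, warmup_steps, total_steps // 2, total_steps):
    lr = base_lr * linear_warmup_factor(step, warmup_steps, total_steps)
    print(f"step {step:>5}: lr = {lr:.2e}")
```

With these placeholder numbers the rate climbs to 2e-5 over the first 250 steps and then falls linearly to zero at step 2500; the real step counts depend on the actual training set size.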

Performance Metrics

Validation-set results for the following metrics have not yet been added to this card:

  • Accuracy: [Will be filled with actual values]
  • Precision: [Will be filled with actual values]
  • Recall: [Will be filled with actual values]
  • F1 Score: [Will be filled with actual values]
  • AUC: [Will be filled with actual values]
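
The card does not say whether these metrics are computed per token or per entity. As one possible reading, the sketch below computes token-level accuracy together with precision, recall, and F1 micro-averaged over the two disease tags. The function name `token_scores` and the example sequences are hypothetical, and AUC is left out because it needs predicted probabilities rather than hard labels.

```python
def token_scores(gold, pred):
    """Token-level accuracy plus micro-averaged precision/recall/F1 over disease tags."""
    if len(gold) != len(pred):
        raise ValueError("gold and pred must have the same length")
    pairs = list(zip(gold, pred))
    correct = sum(g == p for g, p in pairs)
    # A true positive is a disease tag predicted exactly right (B vs I matters).
    tp = sum(g == p and g != "O" for g, p in pairs)
    fp = sum(p != "O" and g != p for g, p in pairs)
    fn = sum(g != "O" and g != p for g, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = correct / len(pairs) if pairs else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hand-written gold and predicted tags for a six-word sentence.
gold = ["O", "B-DISEASE", "I-DISEASE", "O", "B-DISEASE", "O"]
pred = ["O", "B-DISEASE", "O", "O", "B-DISEASE", "B-DISEASE"]
print(token_scores(gold, pred))
```

Because most tokens in clinical text are tagged O, accuracy alone tends to look high; precision, recall, and F1 over the disease tags are usually the more informative numbers.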

Usage

Quick Start

from transformers import AutoTokenizer, AutoModelForTokenClassification
import torch
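import nltk

# word_tokenize needs NLTK's Punkt data (named "punkt_tab" in newer NLTK releases)
nltk.download("punkt", quiet=True)
nltk.download("punkt_tab", quiet=True)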

# Load model and tokenizer
model_name = "keanteng/bert-ner-wqd7005"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)

# Example text
text = "Patient has a history of hypertension and type 2 diabetes."

# Tokenize
tokens = nltk.word_tokenize(text)
inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt",
                   padding=True, truncation=True, max_length=128)

# Predict
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.argmax(outputs.logits, dim=2)

# Map predictions to labels
id2label = {0: 'O', 1: 'B-DISEASE', 2: 'I-DISEASE'}
predicted_labels = [id2label[pred.item()] for pred in predictions[0]]

# Extract diseases
diseases = []
current_disease = []
word_ids = inputs.word_ids()

# Read each word once, using the label of its first subword
# (assumes the usual first-subword labelling convention)
previous_word_idx = None
for word_idx, label in zip(word_ids, predicted_labels):
    if word_idx is None or word_idx == previous_word_idx:
        continue  # special token, or a later subword of the same word
    previous_word_idx = word_idx
    if label == 'B-DISEASE':
        if current_disease:
            diseases.append(' '.join(current_disease))
        current_disease = [tokens[word_idx]]
    elif label == 'I-DISEASE' and current_disease:
        current_disease.append(tokens[word_idx])
    elif current_disease:
        diseases.append(' '.join(current_disease))
        current_disease = []

if current_disease:
    diseases.append(' '.join(current_disease))

print(f"Extracted diseases: {diseases}")

Using the Prediction Function

def predict_diseases(text, model, tokenizer):
    """Return the disease mentions found in `text` as a list of strings."""
    import nltk
    import torch

    # Split the text into words
    tokens = nltk.word_tokenize(text)
    
    # Prepare BERT input
    inputs = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", 
                      padding=True, truncation=True, max_length=128)
    
    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)
        predictions = torch.argmax(outputs.logits, dim=2).squeeze(0).numpy()
    
    # Map predictions to labels
    id2tag = {0: 'O', 1: 'B-DISEASE', 2: 'I-DISEASE'}
    
    # Extract diseases
    diseases = []
    current_disease = []
    word_ids = inputs.word_ids()
    
    # Read each word once, using the label of its first subword
    # (assumes the usual first-subword labelling convention)
    previous_word_idx = None
    for i, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx == previous_word_idx:
            continue  # special token, or a later subword of the same word
        previous_word_idx = word_idx
        prediction = id2tag[int(predictions[i])]
        if prediction == 'B-DISEASE':
            if current_disease:
                diseases.append(' '.join(current_disease))
            current_disease = [tokens[word_idx]]
        elif prediction == 'I-DISEASE' and current_disease:
            current_disease.append(tokens[word_idx])
        elif current_disease:
            diseases.append(' '.join(current_disease))
            current_disease = []
    
    if current_disease:
        diseases.append(' '.join(current_disease))
    
    return diseases

# Example usage
text = "Patient diagnosed with hypertension, diabetes mellitus, and chronic kidney disease."
diseases = predict_diseases(text, model, tokenizer)
print(f"Extracted diseases: {diseases}")
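
Because inputs are truncated at 128 tokens, anything past that point in a long note is silently dropped. One possible workaround, sketched below, is to split the word list into chunks that each fit the budget and run the prediction function on every chunk. `chunk_words` is a hypothetical helper, and the length-based subword counter is only a stand-in so the example runs without the model.

```python
def chunk_words(words, max_subwords, count_subwords):
    """Group words into consecutive chunks whose subword total fits the budget."""
    chunks, current, used = [], [], 0
    for word in words:
        n = count_subwords(word)
        # Start a new chunk when this word would push the current one over budget
        if current and used + n > max_subwords:
            chunks.append(current)
            current, used = [], 0
        current.append(word)
        used += n
    if current:
        chunks.append(current)
    return chunks


def fake_count(word):
    """Stand-in counter (ASSUMPTION): one subword per four characters, rounded up."""
    return -(-len(word) // 4)


words = ("Patient diagnosed with hypertension , diabetes mellitus , "
         "and chronic kidney disease .").split()
chunks = chunk_words(words, max_subwords=8, count_subwords=fake_count)
for chunk in chunks:
    print(chunk)
```

With the real tokenizer, the counter could be `lambda w: len(tokenizer.tokenize(w))` and the budget 126, leaving room for the [CLS] and [SEP] tokens. A disease mention that straddles a chunk boundary would be split in two, so overlapping chunks or sentence-level splitting may work better in practice.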

Training Data

The model was trained on a custom dataset of medical patient records containing:

  • Medical history narratives
  • Manually extracted disease entities
  • BIO-tagged training examples
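
The card does not describe how the BIO-tagged examples were produced from the manually extracted disease entities. The sketch below shows one plausible approach under that assumption: match each extracted mention against the tokenized narrative and tag the matching words. The function name `bio_tags` and the example data are hypothetical.

```python
def bio_tags(tokens, disease_mentions):
    """Tag each token B-DISEASE / I-DISEASE / O by matching known mentions."""
    tags = ["O"] * len(tokens)
    lowered = [t.lower() for t in tokens]
    # Match longer mentions first so "diabetes" cannot claim part of
    # "diabetes mellitus" before the full phrase is tried.
    for mention in sorted(disease_mentions, key=lambda m: -len(m.split())):
        words = mention.lower().split()
        n = len(words)
        if n == 0:
            continue
        for start in range(len(tokens) - n + 1):
            window_free = all(tag == "O" for tag in tags[start:start + n])
            if window_free and lowered[start:start + n] == words:
                tags[start] = "B-DISEASE"
                tags[start + 1:start + n] = ["I-DISEASE"] * (n - 1)
    return tags


tokens = ["History", "of", "chronic", "kidney", "disease", "and", "hypertension", "."]
mentions = ["hypertension", "chronic kidney disease"]
for token, tag in zip(tokens, bio_tags(tokens, mentions)):
    print(f"{token:<14}{tag}")
```

Exact string matching like this misses abbreviations, misspellings, and mentions whose tokenization differs from the narrative, so real annotation pipelines usually add normalization or manual review.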

Limitations

  • The model is trained only for disease entity extraction and does not label other clinical entities such as medications, symptoms, or procedures
  • Performance may vary on medical texts from different domains or institutions
  • May not capture very rare or newly named diseases not seen during training
  • Limited to English language medical texts

Ethical Considerations

  • This model is intended for research and educational purposes
  • Should not be used as a substitute for professional medical diagnosis
  • Patient privacy and data protection must be ensured when using this model
  • Results should be validated by medical professionals