# Spaces: Sleeping
import os

import faiss
import numpy as np
import openai
import torch
from datasets import load_dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)
# Set up the OpenAI API key for GPT-4.
# Prefer the OPENAI_API_KEY environment variable so the secret does not have
# to live in source control; the original placeholder is kept as a fallback.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your_openai_api_key")
# PubMedBERT checkpoint shared by the tokenizer and the 2-label classifier.
_PUBMEDBERT_CHECKPOINT = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"

tokenizer = BertTokenizer.from_pretrained(_PUBMEDBERT_CHECKPOINT)
model = BertForSequenceClassification.from_pretrained(
    _PUBMEDBERT_CHECKPOINT,
    num_labels=2,
)

# FDA dataset hosted on Hugging Face (train split only).
dataset = load_dataset("pretzinger/cdx-cleared-approved", split="train")
# Tokenize the dataset
def tokenize_function(example):
    """Tokenize the ``"text"`` field of a dataset example or batch.

    Args:
        example: Mapping with a ``"text"`` entry, as passed by
            ``dataset.map``.

    Returns:
        The tokenizer output, padded and truncated to 512 tokens.
    """
    # An explicit max_length keeps padding/truncation from depending on the
    # tokenizer's configured model_max_length; 512 matches embed_text.
    return tokenizer(
        example["text"],
        padding="max_length",
        truncation=True,
        max_length=512,
    )
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# FAISS setup for vector search (embedding-based memory).
# Take the vector size from the loaded model's config rather than hard-coding
# it, so the index stays in sync if the checkpoint changes
# (768 for the PubMedBERT base checkpoint).
dimension = model.config.hidden_size
index = faiss.IndexFlatL2(dimension)
def embed_text(text):
    """Return a mean-pooled PubMedBERT embedding of *text* for FAISS.

    Args:
        text: The string to embed.

    Returns:
        A C-contiguous ``float32`` NumPy array with one row per input text,
        suitable for ``index.add`` / ``index.search``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # A sequence-classification output exposes logits, not
        # `last_hidden_state`; request the per-layer hidden states and use
        # the final layer instead.
        outputs = model(**inputs, output_hidden_states=True)
    last_hidden = outputs.hidden_states[-1]

    # Average over real tokens only. The input is padded to 512 tokens, so a
    # plain mean would be dominated by padding positions.
    mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    token_counts = mask.sum(dim=1).clamp(min=1)
    pooled = summed / token_counts

    # FAISS expects contiguous float32 input.
    return np.ascontiguousarray(pooled.detach().cpu().numpy(), dtype=np.float32)
# Example: seed the FAISS memory with one embedded past conversation.
past_conversation = (
    "FDA approval for companion diagnostics requires careful documentation."
)
past_embedding = embed_text(past_conversation)
index.add(past_embedding)
# Look up the stored memory entry closest to an incoming query.
def search_memory(query):
    """Return the FAISS ids of the single nearest stored embedding.

    Args:
        query: The query text to embed and search with.

    Returns:
        The id array produced by ``index.search`` with ``k=1``; the
        distances are discarded.
    """
    _distances, neighbor_ids = index.search(embed_text(query), k=1)
    return neighbor_ids
# Function to handle FDA-related queries with PubMedBERT
def handle_fda_query(query):
    """Run *query* through the PubMedBERT classifier and return a status string.

    Args:
        query: The FDA-related query text.

    Returns:
        A fixed confirmation string; the classifier output is not yet used
        to build the response.
    """
    # Explicit max_length (matching embed_text) keeps padding/truncation from
    # depending on the tokenizer's configured model_max_length.
    inputs = tokenizer(
        query,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )
    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # TODO: turn `outputs.logits` into a classification or a meaningful
    # response; for now the forward pass only validates the input.
    return "Processed FDA-related query via PubMedBERT"
# Function to handle general queries using GPT-4
def handle_openai_query(prompt):
    """Send *prompt* to GPT-4 and return the reply text with whitespace stripped.

    GPT-4 is a chat model, so it is called through the chat-completions
    endpoint of the legacy (pre-1.0) ``openai`` module API rather than
    ``openai.Completion``, which does not serve it.

    Args:
        prompt: The user message to send.

    Returns:
        The content of the first choice, stripped.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4",  # Ensuring GPT-4 usage
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
    )
    return response.choices[0].message["content"].strip()
# Main assistant function that delegates to either OpenAI or PubMedBERT
def assistant(query):
    """Answer *query*, routing FDA-related questions to memory or PubMedBERT.

    GPT-4 is first asked for a yes/no verdict on whether the query is
    FDA-related. FDA-related queries are answered from the FAISS memory when
    a neighbour is found, otherwise by ``handle_fda_query``. All other
    queries are answered directly by GPT-4.

    Args:
        query: The user's question.

    Returns:
        The response text.
    """
    # Ask for a bare yes/no: a substring test for "FDA" would also match
    # negative answers such as "No, this is not FDA-related".
    verdict = handle_openai_query(
        f"Answer with only 'yes' or 'no'. Is this query FDA-related: {query}"
    )
    if verdict.strip().lower().startswith("yes"):
        # Search past conversations/memory using FAISS.
        memory_ids = np.asarray(search_memory(query)).ravel()
        # Test the id explicitly instead of relying on array truthiness:
        # a hit on the first stored vector has id 0 (falsy), while FAISS
        # reports "no result" as -1 (truthy).
        if memory_ids.size > 0 and memory_ids[0] >= 0:
            # Return past context from memory.
            return f"Found relevant past memory: {past_conversation}"
        # If no memory match, proceed with PubMedBERT.
        return handle_fda_query(query)
    # General conversational handling: answer the query itself rather than
    # returning the yes/no classification verdict.
    return handle_openai_query(query)
# Example usage: send one sample question through the assistant and print the reply.
query = "What is required for PMA approval for companion diagnostics?"
response = assistant(query)
print(response)