# Ferris2dotOh / app.py
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
import openai
import faiss
import numpy as np
import torch
# Set up the OpenAI API key for GPT-4 (in practice, load this from an environment variable)
openai.api_key = "your_openai_api_key"
# Load PubMedBERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
# Load the FDA dataset from Hugging Face
dataset = load_dataset("pretzinger/cdx-cleared-approved", split="train")
# Tokenize the dataset
def tokenize_function(example):
    return tokenizer(example["text"], padding="max_length", truncation=True)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
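# TrainingArguments and Trainer are imported above but never used. The commented-out sketch
# below is one way they could fine-tune PubMedBERT on the tokenized FDA dataset; it assumes
# the dataset provides a "label" column, and the hyperparameters are illustrative only.
# training_args = TrainingArguments(
#     output_dir="./results",
#     num_train_epochs=1,
#     per_device_train_batch_size=8,
# )
# trainer = Trainer(model=model, args=training_args, train_dataset=tokenized_dataset)
# trainer.train()  # uncomment to fine-tune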
# FAISS setup for vector search (embedding-based memory)
dimension = 768 # PubMedBERT embedding size
index = faiss.IndexFlatL2(dimension)
def embed_text(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
    # Use the underlying BERT encoder; the classification model's output has no last_hidden_state
    with torch.no_grad():
        outputs = model.bert(**inputs)
    # Mean-pool the token embeddings into a single 768-dim vector for FAISS
    return outputs.last_hidden_state.mean(dim=1).numpy()
# Example: Embed past conversation and store in FAISS
past_conversation = "FDA approval for companion diagnostics requires careful documentation."
past_embedding = embed_text(past_conversation)
index.add(past_embedding)
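# Illustrative addition (not part of the original flow, and not wired into assistant() below):
# when more than one conversation is stored, a parallel list is needed to map FAISS row ids
# back to their source texts.
memory_texts = [past_conversation]
def remember(text):
    # Store a new snippet both as raw text and as an embedding in the FAISS index
    memory_texts.append(text)
    index.add(embed_text(text))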
# Embed the incoming query and search for related memory
def search_memory(query):
    query_embedding = embed_text(query)
    D, I = index.search(query_embedding, k=1)  # Retrieve the most similar past conversation
    return I[0][0] if I.size and I[0][0] != -1 else None  # None when the index has no match
# Function to handle FDA-related queries with PubMedBERT
def handle_fda_query(query):
    # If the query requires specific FDA info, process it with PubMedBERT
    inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # The logits would need to be mapped to a label to produce a real answer; see the sketch below
    response = "Processed FDA-related query via PubMedBERT"
    return response
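# Illustrative only: once the classification head is fine-tuned on the two labels above, the
# logits inside handle_fda_query could be mapped to a prediction (the label meanings here are
# assumptions, not defined by the dataset):
#     predicted = int(logits.argmax(dim=-1))
#     response = "cleared/approved" if predicted == 1 else "not cleared/approved"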
# Function to handle general queries using GPT-4
def handle_openai_query(prompt):
    # GPT-4 is a chat model, so it is called via the ChatCompletion endpoint rather than Completion
    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100
    )
    return response.choices[0].message["content"].strip()
# Main assistant function that delegates to either OpenAI or PubMedBERT
def assistant(query):
    # First, ask GPT-4 whether the query needs FDA-specific info
    openai_response = handle_openai_query(f"Is this query FDA-related: {query}")
    if "FDA" in openai_response or "regulatory" in openai_response:
        # Search past conversations/memory using FAISS
        memory_index = search_memory(query)
        if memory_index is not None:
            return f"Found relevant past memory: {past_conversation}"  # Return past context from memory
        # If no memory match, proceed with PubMedBERT
        return handle_fda_query(query)
    # General conversational handling with OpenAI (GPT-4)
    return openai_response
# Example Usage
query = "What is required for PMA approval for companion diagnostics?"
response = assistant(query)
print(response)