"""FDA-aware assistant that routes queries between GPT-4 and PubMedBERT.

The script loads a PubMedBERT sequence-classification model, builds a small
FAISS index that serves as an embedding-based "memory" of past conversations,
and exposes :func:`assistant`, which asks GPT-4 whether a query is
FDA-related and then answers it from memory, from PubMedBERT, or from GPT-4.
"""
import os

import faiss
import numpy as np
import openai
import torch
from datasets import load_dataset
from transformers import (  # noqa: F401  (Trainer imports kept for later fine-tuning code)
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

# Set up OpenAI API key for GPT-4. Prefer the environment so the secret does
# not live in source control; the placeholder is kept as the fallback value.
openai.api_key = os.environ.get("OPENAI_API_KEY", "your_openai_api_key")

# Load PubMedBERT tokenizer and model
PUBMEDBERT_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
tokenizer = BertTokenizer.from_pretrained(PUBMEDBERT_NAME)
model = BertForSequenceClassification.from_pretrained(PUBMEDBERT_NAME, num_labels=2)

# Load the FDA dataset from Hugging Face
dataset = load_dataset("pretzinger/cdx-cleared-approved", split="train")


def tokenize_function(example):
    """Tokenize a batch of dataset rows for PubMedBERT.

    Args:
        example: A batch mapping as passed by ``Dataset.map(batched=True)``.
            NOTE(review): assumes the dataset exposes a ``"text"`` column --
            confirm against the dataset schema.

    Returns:
        The tokenizer output, padded to the model's maximum length and
        truncated where necessary.
    """
    return tokenizer(example["text"], padding="max_length", truncation=True)


tokenized_dataset = dataset.map(tokenize_function, batched=True)

# FAISS setup for vector search (embedding-based memory)
dimension = 768  # PubMedBERT embedding size
index = faiss.IndexFlatL2(dimension)


def embed_text(text):
    """Embed text with PubMedBERT using attention-masked mean pooling.

    Args:
        text: A string (or list of strings) to embed.

    Returns:
        A C-contiguous ``float32`` NumPy array with one row per input text,
        suitable for ``index.add`` / ``index.search``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    with torch.no_grad():
        # The classification model's output carries logits only; the token
        # states must be requested explicitly via output_hidden_states.
        outputs = model(**inputs, output_hidden_states=True)
        last_hidden = outputs.hidden_states[-1]
        # Weight by the attention mask so padding tokens do not dilute the mean.
        mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)
        pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return np.ascontiguousarray(pooled.numpy(), dtype=np.float32)


# Example: Embed past conversation and store in FAISS
past_conversation = "FDA approval for companion diagnostics requires careful documentation."
# Texts are stored in insertion order so that a FAISS id maps back to its text.
memory_texts = [past_conversation]
past_embedding = embed_text(past_conversation)
index.add(past_embedding)


def search_memory(query):
    """Return the FAISS ids of the stored memory closest to ``query``.

    Args:
        query: The incoming query text.

    Returns:
        The id array ``I`` from ``index.search`` with ``k=1`` (one row per
        query). FAISS reports a missing neighbour as ``-1``.
    """
    query_embedding = embed_text(query)
    _distances, ids = index.search(query_embedding, k=1)  # most similar past conversation
    return ids


def handle_fda_query(query):
    """Run an FDA-related query through the PubMedBERT classifier.

    Args:
        query: The query text.

    Returns:
        A fixed status string; the logits are computed but not yet turned
        into an answer.
    """
    inputs = tokenizer(query, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits  # noqa: F841
    # TODO: map ``logits`` to a meaningful response once the classification
    # head has been fine-tuned; until then only a status message is returned.
    response = "Processed FDA-related query via PubMedBERT"
    return response


def handle_openai_query(prompt):
    """Send ``prompt`` to GPT-4 and return the stripped reply text.

    GPT-4 is a chat model, so the chat-completions endpoint is used rather
    than the plain completions endpoint.

    Args:
        prompt: The user prompt.

    Returns:
        The content of the first choice, with surrounding whitespace removed.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4",  # Ensuring GPT-4 usage
        messages=[{"role": "user", "content": prompt}],
        max_tokens=100,
    )
    return response["choices"][0]["message"]["content"].strip()


def assistant(query):
    """Answer ``query``, delegating to memory, PubMedBERT, or GPT-4.

    GPT-4 is first asked for a YES/NO verdict on whether the query is
    FDA-related. FDA-related queries are answered from the FAISS memory when a
    neighbour exists, otherwise by :func:`handle_fda_query`. All other queries
    are answered directly by GPT-4.

    Args:
        query: The user's question.

    Returns:
        The response string.
    """
    # A constrained YES/NO verdict avoids matching the word "FDA" inside a
    # negative answer such as "No, this is not FDA-related".
    verdict = handle_openai_query(
        "Answer with a single word, YES or NO. "
        f"Is this query FDA-related: {query}"
    )

    if verdict.upper().startswith("YES"):
        # Search past conversations/memory using FAISS
        memory_index = search_memory(query)
        # Id 0 is a valid hit and -1 means "no neighbour", so the id array
        # cannot be tested for truthiness; compare explicitly instead.
        best = int(memory_index[0][0]) if memory_index.size else -1
        if 0 <= best < len(memory_texts):
            # NOTE(review): a flat index returns the nearest stored vector
            # for any query; consider a distance threshold if unrelated
            # queries should fall through to PubMedBERT.
            return f"Found relevant past memory: {memory_texts[best]}"

        # If no memory match, proceed with PubMedBERT
        return handle_fda_query(query)

    # General conversational handling with OpenAI (GPT-4): answer the query
    # itself rather than returning the classification verdict.
    return handle_openai_query(query)


if __name__ == "__main__":
    # Example Usage
    query = "What is required for PMA approval for companion diagnostics?"
    response = assistant(query)
    print(response)