"""Gradio chatbot that serves horoscopes through a Hugging Face RAG model."""

import json

import gradio as gr
import numpy as np
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

# Load horoscope data. JSON text is UTF-8, so decode it explicitly instead of
# depending on the platform's default encoding.
with open("horoscope_data.json", "r", encoding="utf-8") as file:
    horoscope_data = json.load(file)


class CustomHoroscopeRetriever(RagRetriever):
    """Retriever that looks up a horoscope text by zodiac sign.

    The horoscope mapping passed to the constructor is stored as-is and
    queried by :meth:`retrieve`.

    NOTE(review): ``RagRetriever.__init__`` is not called here, and
    ``retrieve`` returns a list of strings, whereas the transformers
    documentation describes ``RagRetriever.retrieve`` as returning a
    ``(doc_embeds, doc_ids, doc_dicts)`` tuple. Confirm that the RAG model can
    actually consume this retriever during ``generate``.
    """

    _INVALID_REQUEST_MESSAGE = (
        "I couldn't process your request. Please try again with a valid zodiac sign."
    )
    _UNKNOWN_SIGN_MESSAGE = (
        "I couldn't find your zodiac sign. Please try again with a valid one."
    )

    def __init__(self, horoscope_data):
        """Store the mapping of zodiac sign to horoscope text."""
        self.horoscope_data = horoscope_data

    def retrieve(self, question_texts, n_docs=1):
        """Return a one-element list holding the horoscope for the given sign.

        Args:
            question_texts: A zodiac sign as a string, or a list / nested list
                / numpy array whose first element is that string.
            n_docs: Accepted for signature compatibility; not used.

        Returns:
            A list with exactly one string: the matching horoscope, or an
            explanatory message when the input cannot be interpreted or the
            sign is unknown.
        """
        # Convert numpy arrays to plain (possibly nested) lists.
        if isinstance(question_texts, np.ndarray):
            question_texts = question_texts.tolist()

        # Unwrap at most two levels of list nesting to reach the first string.
        # An empty list at either level cannot name a sign, so report it as an
        # invalid request rather than failing with an IndexError.
        for _ in range(2):
            if isinstance(question_texts, list):
                if not question_texts:
                    return [self._INVALID_REQUEST_MESSAGE]
                question_texts = question_texts[0]

        if not isinstance(question_texts, str):
            return [self._INVALID_REQUEST_MESSAGE]

        zodiac_sign = question_texts

        # An exact match wins, so existing callers get identical results.
        if zodiac_sign in self.horoscope_data:
            return [self.horoscope_data[zodiac_sign]]

        # Fall back to a whitespace- and case-insensitive match so that inputs
        # such as "aries" or " Aries " still resolve to an "Aries" entry.
        wanted = zodiac_sign.strip().casefold()
        for known_sign in self.horoscope_data:
            if isinstance(known_sign, str) and known_sign.strip().casefold() == wanted:
                return [self.horoscope_data[known_sign]]

        return [self._UNKNOWN_SIGN_MESSAGE]


# Initialize the custom retriever with the horoscope data.
retriever = CustomHoroscopeRetriever(horoscope_data)

# Initialize RAG components.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
model = RagTokenForGeneration.from_pretrained(
    "facebook/rag-token-base",
    retriever=retriever,
)


def horoscope_chatbot(input_text):
    """Generate a reply for ``input_text`` with the RAG model.

    Args:
        input_text: The user's message from the Gradio text box.

    Returns:
        The first decoded sequence produced by ``model.generate``, with
        special tokens removed.
    """
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    generated_ids = model.generate(input_ids=input_ids)
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return decoded[0]


# Set up the Gradio interface.
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Launch the interface with a public share link.
iface.launch(share=True)