# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import gradio as gr | |
from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration | |
import json | |
import numpy as np | |
# Load horoscope data: presumably a JSON object mapping zodiac-sign names to
# horoscope text (TODO confirm against horoscope_data.json).
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's locale encoding, which can mis-decode or reject
# non-ASCII characters (e.g. under cp1252).
with open("horoscope_data.json", "r", encoding="utf-8") as file:
    horoscope_data = json.load(file)
# Custom retriever that looks horoscopes up by zodiac sign in a mapping
# instead of searching a document index.
class CustomHoroscopeRetriever(RagRetriever):
    """Look up a horoscope by zodiac sign in an in-memory mapping.

    NOTE(review): ``RagRetriever.__init__`` is deliberately not called; it
    expects a RAG config, question-encoder/generator tokenizers and an index,
    none of which exist here. Only the ``retrieve`` method below is usable.
    In transformers' RAG implementation the base class's ``__call__`` (used
    during ``RagTokenForGeneration.generate``) passes question-encoder hidden
    states, not text, and expects a (doc_embeds, doc_ids, doc_dicts) tuple,
    which this class does not produce -- confirm for the installed version.
    """

    # Returned when the query cannot be reduced to a single string.
    _UNPROCESSABLE = (
        "I couldn't process your request. "
        "Please try again with a valid zodiac sign."
    )
    # Returned when the query string matches no known zodiac sign.
    _UNKNOWN_SIGN = (
        "I couldn't find your zodiac sign. "
        "Please try again with a valid one."
    )

    def __init__(self, horoscope_data):
        """Store the sign -> horoscope mapping.

        Args:
            horoscope_data: Mapping from zodiac-sign name to horoscope text;
                assumed to be the dict loaded from the JSON data file
                (TODO confirm key spelling/casing against the data).
        """
        self.horoscope_data = horoscope_data

    def retrieve(self, question_texts, n_docs=1):
        """Return a one-element list holding the horoscope for a sign.

        Args:
            question_texts: The zodiac sign, either as a plain string or
                wrapped in (possibly nested) lists/tuples or a numpy array;
                the first element is taken at each nesting level.
            n_docs: Accepted for signature compatibility; unused -- exactly
                one result is always returned.

        Returns:
            A list with one string: the horoscope text, or a user-facing
            error message if the query is unusable or the sign is unknown.
        """
        query = question_texts
        # numpy arrays become (possibly nested) Python lists.
        if isinstance(query, np.ndarray):
            query = query.tolist()
        # Unwrap nested sequences down to their first element. An empty
        # sequence has no first element, so it cannot be processed (the
        # previous code raised IndexError here).
        while isinstance(query, (list, tuple)):
            if not query:
                return [self._UNPROCESSABLE]
            query = query[0]
        if not isinstance(query, str):
            return [self._UNPROCESSABLE]

        # Exact match first, then a lenient match that ignores surrounding
        # whitespace and letter case so "aries " still finds "Aries".
        if query in self.horoscope_data:
            return [self.horoscope_data[query]]
        if isinstance(self.horoscope_data, dict):
            wanted = query.strip().casefold()
            for sign, horoscope in self.horoscope_data.items():
                if isinstance(sign, str) and sign.strip().casefold() == wanted:
                    return [horoscope]
        return [self._UNKNOWN_SIGN]
# Build the lookup-table retriever over the loaded horoscope mapping.
retriever = CustomHoroscopeRetriever(horoscope_data)

# Pretrained RAG checkpoint shared by the tokenizer and the generation model;
# the model is wired to the custom retriever above.
RAG_CHECKPOINT = "facebook/rag-token-base"
tokenizer = RagTokenizer.from_pretrained(RAG_CHECKPOINT)
model = RagTokenForGeneration.from_pretrained(
    RAG_CHECKPOINT,
    retriever=retriever,
)
# Chatbot callback used by the Gradio interface.
def horoscope_chatbot(input_text):
    """Answer a text query with the horoscope for the named zodiac sign.

    The sign is looked up directly through ``retriever.retrieve``.

    NOTE(review): this previously called ``model.generate(input_ids=...)``.
    In transformers' RAG models, generation hands the retriever
    question-encoder hidden states (a float array) rather than text and
    expects a (doc_embeds, doc_ids, doc_dicts) tuple back; the lookup-table
    retriever cannot satisfy that contract (nor does it run the base-class
    initializer), so every request failed. Reintroduce generation only with a
    retriever that implements the full ``RagRetriever`` interface.

    Args:
        input_text: The user's message, expected to be a zodiac sign name.
            ``None`` is treated as an empty string.

    Returns:
        The horoscope text, or the retriever's error message.
    """
    # Text boxes commonly carry stray whitespace/newlines around the sign.
    query = (input_text or "").strip()
    return retriever.retrieve(query)[0]
# Gradio UI: one text box in, the chatbot's text reply out.
iface = gr.Interface(
    fn=horoscope_chatbot,
    inputs="text",
    outputs="text",
    title="Horoscope RAG Chatbot",
)

# Launch only when run as a script, so importing this module (tests, tooling)
# does not start a web server. share=True requests a public link.
if __name__ == "__main__":
    iface.launch(share=True)