keerthi-balaji commited on
Commit
2483b3f
·
verified ·
1 Parent(s): 89a4d2e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -1
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import gradio as gr
2
  from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
3
  import json
 
4
 
5
  # Load horoscope data
6
  with open("horoscope_data.json", "r") as file:
@@ -12,13 +13,22 @@ class CustomHoroscopeRetriever(RagRetriever):
12
  self.horoscope_data = horoscope_data
13
 
14
  def retrieve(self, question_texts, n_docs=1):
 
 
 
 
15
  # Ensure question_texts is a list of strings
16
  if isinstance(question_texts, list):
17
  question_texts = question_texts[0] # Get the first element
18
  if isinstance(question_texts, list): # If it's still a list, get the first string
19
  question_texts = question_texts[0]
20
 
21
- zodiac_sign = question_texts.capitalize() # Now it should be a string
 
 
 
 
 
22
  if zodiac_sign in self.horoscope_data:
23
  return [self.horoscope_data[zodiac_sign]]
24
  else:
 
1
  import gradio as gr
2
  from transformers import RagTokenizer, RagRetriever, RagTokenForGeneration
3
  import json
4
+ import numpy as np
5
 
6
  # Load horoscope data
7
  with open("horoscope_data.json", "r") as file:
 
13
  self.horoscope_data = horoscope_data
14
 
15
  def retrieve(self, question_texts, n_docs=1):
16
+ # Convert numpy arrays to lists if needed
17
+ if isinstance(question_texts, np.ndarray):
18
+ question_texts = question_texts.tolist()
19
+
20
  # Ensure question_texts is a list of strings
21
  if isinstance(question_texts, list):
22
  question_texts = question_texts[0] # Get the first element
23
  if isinstance(question_texts, list): # If it's still a list, get the first string
24
  question_texts = question_texts[0]
25
 
26
+ # Ensure the text is a string
27
+ if isinstance(question_texts, str):
28
+ zodiac_sign = question_texts # Use as-is
29
+ else:
30
+ return ["I couldn't process your request. Please try again with a valid zodiac sign."]
31
+
32
  if zodiac_sign in self.horoscope_data:
33
  return [self.horoscope_data[zodiac_sign]]
34
  else: