VPCSinfo commited on
Commit
ff3806f
Β·
1 Parent(s): 1cdeb09

[IMP] added improvement

Browse files
Files changed (2) hide show
  1. .gitignore +1 -0
  2. retriever.py +70 -8
.gitignore CHANGED
@@ -51,6 +51,7 @@ Thumbs.db
51
  # Gradio specific
52
  gradio_cached_examples/
53
  flagged/
 
54
 
55
  # Environment variables
56
  .env
 
51
  # Gradio specific
52
  gradio_cached_examples/
53
  flagged/
54
+ .gradio/
55
 
56
  # Environment variables
57
  .env
retriever.py CHANGED
@@ -1,12 +1,67 @@
1
  from smolagents import Tool
2
- from langchain_community.retrievers import BM25Retriever
3
  from langchain.docstore.document import Document
 
 
4
  import datasets
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  class GuestInfoRetrieverTool(Tool):
8
  name = "guest_info_retriever"
9
- description = "Retrieves detailed information about gala guests based on their name or relation."
10
  inputs = {
11
  "query": {
12
  "type": "string",
@@ -15,10 +70,9 @@ class GuestInfoRetrieverTool(Tool):
15
  }
16
  output_type = "string"
17
 
18
- def __init__(self, docs):
19
  self.is_initialized = False
20
- self.retriever = BM25Retriever.from_documents(docs)
21
-
22
 
23
  def forward(self, query: str):
24
  results = self.retriever.get_relevant_documents(query)
@@ -28,7 +82,15 @@ class GuestInfoRetrieverTool(Tool):
28
  return "No matching guest information found."
29
 
30
 
31
- def load_guest_dataset():
 
 
 
 
 
 
 
 
32
  # Load the dataset
33
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
34
 
@@ -46,8 +108,8 @@ def load_guest_dataset():
46
  for guest in guest_dataset
47
  ]
48
 
49
- # Return the tool
50
- return GuestInfoRetrieverTool(docs)
51
 
52
 
53
 
 
1
  from smolagents import Tool
 
2
  from langchain.docstore.document import Document
3
+ from sentence_transformers import SentenceTransformer
4
+ import numpy as np
5
  import datasets
6
+ from typing import List
7
+
8
+
9
+ class SentenceTransformerRetriever:
10
+ """Retriever that uses SentenceTransformer embeddings for semantic search."""
11
+
12
+ def __init__(self, docs: List[Document], model_name: str = "all-MiniLM-L6-v2"):
13
+ """Initialize with documents and a SentenceTransformer model.
14
+
15
+ Args:
16
+ docs: List of Document objects
17
+ model_name: Name of the SentenceTransformer model to use
18
+ """
19
+ self.docs = docs
20
+ self.model = SentenceTransformer(model_name)
21
+
22
+ # Create embeddings for all documents
23
+ self.doc_texts = [doc.page_content for doc in self.docs]
24
+ # Ensure we get numpy arrays for document embeddings
25
+ self.doc_embeddings = self.model.encode(self.doc_texts, convert_to_numpy=True)
26
+
27
+ def get_relevant_documents(self, query: str, k: int = 3) -> List[Document]:
28
+ """Return documents relevant to the query.
29
+
30
+ Args:
31
+ query: Query string
32
+ k: Number of documents to return
33
+
34
+ Returns:
35
+ List of relevant Document objects
36
+ """
37
+ # Encode the query and ensure we get a numpy array
38
+ query_embedding = self.model.encode(query, convert_to_numpy=True)
39
+
40
+ # Calculate similarities
41
+ # Calculate cosine similarity manually to avoid tensor conversion issues
42
+ similarities = []
43
+ for doc_embedding in self.doc_embeddings:
44
+ # Calculate cosine similarity between query and document
45
+ dot_product = np.dot(query_embedding, doc_embedding)
46
+ query_norm = np.linalg.norm(query_embedding)
47
+ doc_norm = np.linalg.norm(doc_embedding)
48
+ similarity = dot_product / (query_norm * doc_norm)
49
+ similarities.append(similarity)
50
+
51
+ # Convert to numpy array
52
+ similarities = np.array(similarities)
53
+
54
+ # Get the top k most similar documents
55
+ # Sort indices by similarity in descending order and take the top k
56
+ top_k_indices = np.argsort(-similarities)[:k]
57
+
58
+ # Return the top k documents
59
+ return [self.docs[i] for i in top_k_indices]
60
 
61
 
62
  class GuestInfoRetrieverTool(Tool):
63
  name = "guest_info_retriever"
64
+ description = "Retrieves detailed information about gala guests based on their name or relation using semantic search."
65
  inputs = {
66
  "query": {
67
  "type": "string",
 
70
  }
71
  output_type = "string"
72
 
73
+ def __init__(self, docs, model_name: str = "all-MiniLM-L6-v2"):
74
  self.is_initialized = False
75
+ self.retriever = SentenceTransformerRetriever(docs, model_name)
 
76
 
77
  def forward(self, query: str):
78
  results = self.retriever.get_relevant_documents(query)
 
82
  return "No matching guest information found."
83
 
84
 
85
+ def load_guest_dataset(model_name: str = "all-MiniLM-L6-v2"):
86
+ """Load the guest dataset and create a retriever tool.
87
+
88
+ Args:
89
+ model_name: Name of the SentenceTransformer model to use
90
+
91
+ Returns:
92
+ GuestInfoRetrieverTool: A tool for retrieving guest information
93
+ """
94
  # Load the dataset
95
  guest_dataset = datasets.load_dataset("agents-course/unit3-invitees", split="train")
96
 
 
108
  for guest in guest_dataset
109
  ]
110
 
111
+ # Return the tool with the specified model
112
+ return GuestInfoRetrieverTool(docs, model_name=model_name)
113
 
114
 
115