Spaces:
Sleeping
Sleeping
added app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterable
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import torch
|
| 4 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 5 |
+
from langchain.vectorstores import Qdrant
|
| 6 |
+
from qdrant_client import QdrantClient
|
| 7 |
+
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
|
| 8 |
+
from config import DB_CONFIG
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@st.cache_resource
def load_embeddings():
    """Build and cache the multilingual E5 sentence-embedding model.

    Decorated with ``st.cache_resource`` so the (large) model is loaded
    only once per Streamlit server process. Runs on the first CUDA
    device when available, otherwise on CPU. Embeddings are left
    unnormalized (``normalize_embeddings=False``).
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    return HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-large",
        model_kwargs={"device": device},
        encode_kwargs={"normalize_embeddings": False},
    )


# Module-level singleton shared by every search request.
EMBEDDINGS = load_embeddings()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def make_filter_obj(options: list[dict[str, str]]) -> Filter:
    """Build a Qdrant payload filter from key/value match pairs.

    Args:
        options: dicts, each carrying a ``"key"`` (payload field path,
            e.g. ``"metadata.repo_name"``) and a ``"value"`` the field
            must exactly match.

    Returns:
        A ``Filter`` whose ``must`` clause ANDs one exact-match
        ``FieldCondition`` per option.
    """
    # The original annotation was the malformed ``list[dict[str]]``;
    # ``dict`` requires both key and value type parameters.
    conditions = [
        FieldCondition(key=option["key"], match=MatchValue(value=option["value"]))
        for option in options
    ]
    # Named ``conditions`` rather than ``filter`` to avoid shadowing the builtin.
    return Filter(must=conditions)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_similay(query: str, filter: Filter, k: int = 20):
    """Run a filtered similarity search against the Qdrant collection.

    (The name keeps the original spelling — callers depend on it.)

    Args:
        query: free-text query, embedded with the module-level ``EMBEDDINGS``.
        filter: Qdrant payload filter restricting candidate documents.
            (Parameter name shadows the builtin but is kept for
            keyword-call compatibility.)
        k: maximum number of results to return; defaults to 20, the
            previously hard-coded value.

    Returns:
        ``(Document, score)`` pairs from ``similarity_search_with_score``.
    """
    db_url, db_api_key, db_collection_name = DB_CONFIG
    client = QdrantClient(url=db_url, api_key=db_api_key)
    db = Qdrant(
        client=client, collection_name=db_collection_name, embeddings=EMBEDDINGS
    )
    return db.similarity_search_with_score(query, k=k, filter=filter)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main(
    query: str,
    repo_name: str,
) -> Iterable[tuple]:
    """Search issues of *repo_name* matching *query*, yielding result rows.

    Args:
        query: free-text similarity-search query.
        repo_name: repository name used to filter hits via the
            ``metadata.repo_name`` payload field.

    Yields:
        ``(title, url, id_, text, score, is_comment)`` per hit, where
        ``title``/``url``/``id_`` come from the document metadata (each
        may be ``None`` when absent), ``text`` is the page content,
        ``score`` the similarity score, and ``is_comment`` is True when
        the metadata ``type_`` equals ``"comment"``.
        (The original annotation claimed ``tuple[str, tuple[str, str]]``,
        which did not match the 6-tuples actually yielded.)
    """
    options = [{"key": "metadata.repo_name", "value": repo_name}]
    filter = make_filter_obj(options=options)
    docs = get_similay(query, filter)
    for doc, score in docs:
        metadata = doc.metadata
        yield (
            metadata.get("title"),
            metadata.get("url"),
            metadata.get("id"),
            doc.page_content,
            score,
            metadata.get("type_") == "comment",
        )
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# --- UI: the search form; on submit, render one card per result --------
with st.form("my_form"):
    st.title("GitHub Issue Search")
    query = st.text_input(label="query")
    repo_name = st.radio(
        options=["cocoa", "plone", "volto", "plone.restapi"], label="Repo name"
    )

    submitted = st.form_submit_button("Submit")
    if submitted:
        st.divider()
        st.header("Search Results")
        st.divider()
        with st.spinner("Searching..."):
            for title, url, doc_id, body, score, is_comment in main(query, repo_name):
                with st.container():
                    # Issues are headed "#<id> - <title>"; comments differently.
                    heading = (
                        f"comment with {title}"
                        if is_comment
                        else f"#{doc_id} - {title}"
                    )
                    st.subheader(heading)
                    st.write(url)
                    st.write(body)
                    st.write(score)
                    st.divider()
|