# NOTE: page-scrape residue from the hosting site ("Spaces: Runtime error"); not part of the program.
import time | |
from about import show_about_ask2democracy | |
import streamlit as st | |
from pinecone_quieries import PineconeProposalQueries | |
from config import Config | |
from samples import * | |
# Module-level query client shared by the search helpers in this file.
# It is constructed once, at import time, from the values held in Config.
# NOTE(review): api_key is read from Config.es_password -- confirm that this
# attribute really holds the Pinecone API key.
queries = PineconeProposalQueries(
    index_name=Config.index_name,
    api_key=Config.es_password,
    environment=Config.pinecone_environment,
    embedding_dim=Config.embedding_dim,
    reader_name_or_path=Config.reader_model_name_or_path,
    use_gpu=Config.use_gpu,
    # No OpenAI key at construction; search_and_generate_answer passes the
    # user-supplied key on each call instead.
    OPENAI_key=None,
)
def search(question, retriever_top_k, reader_top_k, selected_index=None):
    """Run an extractive query and flatten every hit into a display row.

    Args:
        question: Text passed as the ``query`` argument of
            ``queries.search_by_query``.
        retriever_top_k: Forwarded unchanged to ``queries.search_by_query``.
        reader_top_k: Forwarded unchanged to ``queries.search_by_query``.
        selected_index: Value used for the ``source_title`` filter.

    Returns:
        A list with one row per hit, in the order the query returned them.
        Each row is ``[[rank], answer, context, title, source_title, page,
        source_url]`` where ``rank`` is 1-based (wrapped in a one-element
        list), ``answer`` has its newlines removed, ``context`` is cut to its
        first 250 characters and ``page`` is converted to ``int``.
    """
    hits = queries.search_by_query(
        query=question,
        retriever_top_k=retriever_top_k,
        reader_top_k=reader_top_k,
        filters={"source_title": selected_index},
    )
    return [
        [
            [rank],
            hit.answer.replace("\n", ""),
            hit.context[:250],
            hit.meta['title'],
            hit.meta['source_title'],
            int(hit.meta['page']),
            hit.meta['source_url'],
        ]
        for rank, hit in enumerate(hits, start=1)
    ]
def search_and_show_results(query: str, retriever_top_k=5, reader_top_k=3, selected_index=None):
    """Run ``search`` for *query* and render the rows with Streamlit.

    Args:
        query: Question text; also echoed in the results header.
        retriever_top_k: Forwarded to ``search``.
        reader_top_k: Forwarded to ``search``.
        selected_index: Forwarded to ``search``.

    Writes a header containing the query and the elapsed wall-clock time
    (rounded to two decimals), then, for each row, a subheader with the
    answer, a markdown excerpt followed by a "read more" link, and a caption
    naming the source, article and page.
    """
    started = time.time()
    rows = search(
        query,
        retriever_top_k=retriever_top_k,
        reader_top_k=reader_top_k,
        selected_index=selected_index,
    )
    elapsed_time = round(time.time() - started, 2)

    st.write(f"**Resultados encontrados para la pregunta** \"{query}\" ({elapsed_time} sec.):")
    for row in rows:
        # Row layout (see ``search``): 1=answer, 2=context, 3=title,
        # 4=source_title, 5=page, 6=source_url.
        st.subheader(f"{row[1]}")
        excerpt = row[2][:250] + "..."
        st.markdown(f"{excerpt}[Lee más aquí]({row[6]})", unsafe_allow_html=True)
        st.caption(f"Fuente: {row[4]} - Artículo: {row[3]} - Página: {row[5]}")
def search_and_generate_answer(question, retriever_top_k, generator_top_k,
                               openai_api_key, openai_model_name="text-davinci-003",
                               temperature=.5, max_tokens=30, selected_index=None):
    """Ask the generative pipeline and flatten each answer into a display row.

    Args:
        question: Text passed as the ``query`` argument.
        retriever_top_k: Forwarded unchanged to the query client.
        generator_top_k: Forwarded unchanged to the query client.
        openai_api_key: Passed to the query client as ``OPENAI_key``.
        openai_model_name: Forwarded unchanged to the query client.
        temperature: Forwarded unchanged to the query client.
        max_tokens: Forwarded unchanged to the query client.
        selected_index: Value used for the ``source_title`` filter.

    Returns:
        A list with one row per generated answer:
        ``[[rank], answer, source_title, source_url, chapter_titles]``.
        ``rank`` is 1-based (wrapped in a one-element list), ``answer`` has
        its newlines removed, ``source_title`` and ``source_url`` come from
        the first entry of the answer's ``doc_metas`` and ``chapter_titles``
        is the ``str()`` of the list of every ``doc_metas`` title.
    """
    generated = queries.genenerate_answer_OpenAI(
        query=question,
        retriever_top_k=retriever_top_k,
        generator_top_k=generator_top_k,
        filters={"source_title": selected_index},
        OPENAI_key=openai_api_key,
        openai_model_name=openai_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    rows = []
    for rank, generated_item in enumerate(generated, start=1):
        # The first supporting document supplies the source shown to the user.
        lead_meta = generated_item.meta['doc_metas'][0]
        source_title = lead_meta['source_title']
        source_url = lead_meta['source_url']
        chapter_titles = [doc_meta['title'] for doc_meta in generated_item.meta['doc_metas']]
        rows.append([
            [rank],
            generated_item.answer.replace("\n", ""),
            source_title,
            source_url,
            str(chapter_titles),
        ])
    return rows
def search_and_show_generative_results(query: str, retriever_top_k=5, generator_top_k=1,
                                       openai_api_key=None,
                                       openai_model_name="text-davinci-003",
                                       temperature=.5, max_tokens=30,
                                       selected_index=None):
    """Generate an answer for *query* and render it with Streamlit.

    Args:
        query: Question text; also echoed in the results header.
        retriever_top_k: Forwarded to ``search_and_generate_answer``.
        generator_top_k: Forwarded to ``search_and_generate_answer``.
        openai_api_key: Forwarded to ``search_and_generate_answer``.
        openai_model_name: Forwarded to ``search_and_generate_answer``.
        temperature: Forwarded to ``search_and_generate_answer``.
        max_tokens: Forwarded to ``search_and_generate_answer``.
        selected_index: Forwarded to ``search_and_generate_answer``.

    Writes a header containing the query and the elapsed wall-clock time
    (rounded to two decimals), then, for each generated row, a subheader with
    the answer, a caption listing the sources and a "read more" link.
    """
    # Time the whole retrieval + generation round trip.
    stt = time.time()
    results = search_and_generate_answer(query, retriever_top_k=retriever_top_k,
                                         generator_top_k=generator_top_k,
                                         openai_api_key=openai_api_key,
                                         openai_model_name=openai_model_name,
                                         temperature=temperature, max_tokens=max_tokens,
                                         selected_index=selected_index)
    ent = time.time()
    elapsed_time = round(ent - stt, 2)
    st.write(f"**Respuesta generada para la pregunta** \"{query}\" ({elapsed_time} sec.):")
    # Identity comparison is the correct way to test for None (PEP 8).
    if results is not None:
        for answer in results:
            # Row layout: 1=answer text, 2=source title, 3=source URL,
            # 4=stringified list of chapter titles.
            st.subheader(f"{answer[1]}")
            st.caption(f"Fuentes: {answer[2]} - {answer[4]}")
            st.markdown(f"[Lee más aquí]({answer[3]})")
# Catalogue of explorable documents.
#   "title":   label offered in the document select box (see index_titles).
#   "name":    value main() passes on as selected_index, which the search
#              helpers use for the "source_title" filter.
#   "samples": sample-question text; main() splits it into lines.
indexes = [
    {
        "title": "Propuesta reforma a la salud 13 de febrero de 2023",
        "name": "Reforma de la salud 13 Febrero 2023",
        "samples": samples_reforma_salud,
    },
    {
        "title": "Propuesta reforma pensional marzo 22 de 2023",
        "name": "Reforma pensional Marzo 2023",
        "samples": samples_reforma_pensional,
    },
    {
        "title": "Hallazgos de la comisión de la verdad",
        "name": "Hallazgos y recomendaciones - 28 de Junio 2022",
        "samples": samples_hallazgos_paz,
    },
]

# Display titles only, in catalogue order.
index_titles = [entry["title"] for entry in indexes]
def get_selected_index_by_title(title):
    """Return the ``"name"`` of the first ``indexes`` entry titled *title*.

    Returns ``None`` when no entry has that title.
    """
    matching_names = (entry["name"] for entry in indexes if entry["title"] == title)
    return next(matching_names, None)
def get_samples_for_index(title):
    """Return the ``"samples"`` of the first ``indexes`` entry titled *title*.

    Returns ``None`` when no entry has that title.
    """
    matching_samples = (entry["samples"] for entry in indexes if entry["title"] == title)
    return next(matching_samples, None)
def main():
    """Build the Streamlit page: header, sidebar settings, question form, results.

    When the form is submitted, the generative answer is shown first (only if
    an OpenAI API key was entered), followed by the extractive search results.
    Otherwise the "about" section is shown.

    NOTE(review): the nesting below was reconstructed from a copy of this file
    that had lost its indentation. In particular, it is assumed that the
    result rendering lives inside the form and that the final ``else`` pairs
    with ``if submited`` -- confirm against the deployed app.
    """
    st.title("Ask2Democracy 🇨🇴")
    # Author credits, right-aligned; raw HTML, hence unsafe_allow_html.
    st.markdown("""
<div align="right">
Creado por Jorge Henao 🇨🇴 <a href="https://twitter.com/jhenaotw" target='_blank'>Twitter</a> <a href="https://www.linkedin.com/in/henaojorge" target='_blank'>LinkedIn</a> <a href="https://linktr.ee/jorgehenao" target='_blank'>Linktree</a>
</div>""", unsafe_allow_html=True)
    # session_state = st.session_state
    # if "api_key" not in session_state:
    # session_state.api_key = ""
    with st.form("my_form"):
        st.sidebar.title("Configuración de búsqueda")
        # Retrieval settings: which document to query and the two top-k values.
        with st.sidebar.expander("Parámetros de recuperación", expanded= True):
            index = st.selectbox("Selecciona el documento que deseas explorar", index_titles)
            top_k_retriever = st.slider("Retriever Top K", 1, 10, 5)
            top_k_reader = st.slider("Reader Top K", 1, 10, 3)
        # Optional OpenAI settings; a non-empty key enables the generative answer.
        with st.sidebar.expander("Configuración OpenAI"):
            openai_api_key = st.text_input("API Key", type="password", placeholder="Copia aquí tu OpenAI API key (no será guardada)",
                                           help="puedes obtener tu api key de OpenAI en https://platform.openai.com/account/api-keys.")
            openai_api_model = st.text_input("Modelo", value= "text-davinci-003")
            openai_api_temp = st.slider("Temperatura", 0.1, 1.0, 0.5, step=0.1)
            openai_api_max_tokens = st.slider("Max tokens", 10, 100, 60, step=10)
        # if openai_api_key:
        # session_state.password = openai_api_key
        # One sample question per line of the selected document's sample text.
        sample_questions = get_samples_for_index(index).splitlines()
        query = st.text_area("",placeholder="Escribe aquí tu pregunta, cuanto más contexto le des, mejor serán las respuestas")
        with st.expander("Algunas preguntas de ejemplo", expanded= False):
            for sample in sample_questions:
                st.markdown(f"- {sample}")
        submited = st.form_submit_button("Buscar")
        if submited:
            # Map the displayed title to the name used for the source_title filter.
            selected_index = get_selected_index_by_title(index)
            if openai_api_key:
                # Generative answer; always requests a single generated answer.
                with st.expander("", expanded= True):
                    search_and_show_generative_results(query = query,retriever_top_k= top_k_retriever,
                                                       generator_top_k= 1, openai_api_key = openai_api_key,
                                                       openai_model_name = openai_api_model,
                                                       temperature= openai_api_temp,
                                                       max_tokens= openai_api_max_tokens,
                                                       selected_index = selected_index)
            # Extractive results are shown whether or not a key was provided.
            with st.expander("", expanded= True):
                search_and_show_results(query, retriever_top_k=top_k_retriever,
                                        reader_top_k=top_k_reader,
                                        selected_index=selected_index)
        else:
            # Nothing submitted on this run: show the project description.
            show_about_ask2democracy()
# Build the page only when this file is run as the main script, not on import.
if __name__ == "__main__":
    main()