import logging
import sys
import os
import re
import base64
import nest_asyncio
nest_asyncio.apply()
import pandas as pd
from pathlib import Path
from typing import Any, Dict, List, Optional
from PIL import Image
import streamlit as st
import torch
# LlamaIndex imports
from llama_index.core import (
    Settings,
    SimpleDirectoryReader,
    StorageContext,
    Document,
)
from llama_index.core.storage.docstore import SimpleDocumentStore
from llama_index.core.storage.docstore.types import BaseDocumentStore
from llama_index.core.node_parser import LangchainNodeParser
from langchain.text_splitter import RecursiveCharacterTextSplitter
from llama_index.core.storage.chat_store import SimpleChatStore
from llama_index.core.memory import ChatMemoryBuffer
from llama_index.core.chat_engine import CondensePlusContextChatEngine
from llama_index.core.retrievers import QueryFusionRetriever
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.core import VectorStoreIndex
import chromadb
###############################################################################
#                          MONKEY PATCH FOR bm25s                             #
###############################################################################
import bm25s

###############################################################################
#              BM25Retriever CLASS (ADJUSTED FOR ENCODING)                    #
###############################################################################
import json
import Stemmer

from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.schema import (
    BaseNode,
    IndexNode,
    NodeWithScore,
    QueryBundle,
    MetadataMode,
)
from llama_index.core.vector_stores.utils import (
    node_to_metadata_dict,
    metadata_dict_to_node,
)
from typing import cast

logger = logging.getLogger(__name__)

DEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
DEFAULT_PERSIST_FILENAME = "retriever.json"
class BM25Retriever(BaseRetriever):
    """
    Custom implementation of the BM25 algorithm on top of the bm25s library.
    Character-decoding problems are worked around by reading and writing the
    persisted files with errors="ignore".
    """

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.stemmer = stemmer or Stemmer.Stemmer("english")
        self.similarity_top_k = similarity_top_k

        if existing_bm25 is not None:
            # Reuse an existing BM25 instance
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            # Build a new BM25 instance from 'nodes'
            if nodes is None:
                raise ValueError("Either 'nodes' or an 'existing_bm25' must be provided.")
            self.corpus = [node_to_metadata_dict(node) for node in nodes]
            corpus_tokens = bm25s.tokenize(
                [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes],
                stopwords=language,
                stemmer=self.stemmer,
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)

        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )
    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        stemmer: Optional[Stemmer.Stemmer] = None,
        language: str = "en",
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
        tokenizer: Optional[Any] = None,
    ) -> "BM25Retriever":
        if tokenizer is not None:
            logger.warning(
                "The 'tokenizer' parameter is deprecated and will be removed "
                "in the future. Use a PyStemmer stemmer for finer control."
            )

        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Pass exactly one of 'index', 'nodes' or 'docstore'.")

        if index is not None:
            docstore = index.docstore
        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        assert nodes is not None, (
            "Could not determine the nodes. Check your parameters."
        )

        return cls(
            nodes=nodes,
            stemmer=stemmer,
            language=language,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )
    def get_persist_args(self) -> Dict[str, Any]:
        """Dictionary of persistence parameters to be saved."""
        return {
            DEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in DEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """
        Persist the retriever to a directory, including the BM25 structure
        and the corpus as JSON.
        """
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        with open(
            os.path.join(path, DEFAULT_PERSIST_FILENAME),
            "wt",
            encoding="utf-8",
            errors="ignore",
        ) as f:
            json.dump(self.get_persist_args(), f, indent=2, ensure_ascii=False)
    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "BM25Retriever":
        """
        Load the retriever from a directory, including the BM25 structure and
        the corpus. Because the files are opened with errors="ignore", any
        decoding errors that show up are silently skipped.
        """
        bm25_obj = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        with open(
            os.path.join(path, DEFAULT_PERSIST_FILENAME),
            "rt",
            encoding="utf-8",
            errors="ignore",
        ) as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25_obj, **retriever_data)
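    # Persist/load round-trip sketch for the two methods above (illustrative
    # only; the directory name is hypothetical):
    #
    #   retriever.persist("bm25_dir")
    #   retriever = BM25Retriever.from_persist_dir("bm25_dir")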
    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Retrieve relevant nodes via BM25."""
        query = query_bundle.query_str
        tokenized_query = bm25s.tokenize(
            query, stemmer=self.stemmer, show_progress=self._verbose
        )
        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
        )

        # bm25s returns a list of lists, since it supports batched queries
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            if isinstance(idx, dict):
                node = metadata_dict_to_node(idx)
            else:
                node_dict = self.corpus[int(idx)]
                node = metadata_dict_to_node(node_dict)
            nodes.append(NodeWithScore(node=node, score=float(score)))
        return nodes
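# Minimal usage sketch of the class above (illustrative only, not executed
# here; the node text is hypothetical):
#
#   from llama_index.core.schema import TextNode
#   retriever = BM25Retriever(
#       nodes=[TextNode(text="exemplo de documento")],
#       language="portuguese",
#       similarity_top_k=2,
#   )
#   results = retriever.retrieve("documento")  # -> List[NodeWithScore]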
###############################################################################
#               STREAMLIT CONFIGURATION AND PIPELINE TWEAKS                   #
###############################################################################
# Avoid re-indexing or re-downloading data by keeping state in the session.
im = Image.open("pngegg.png")
st.set_page_config(page_title="Chatbot Carômetro", page_icon=im, layout="wide")

# Sidebar sections
st.sidebar.title("Configuração de LLM")
sidebar_option = st.sidebar.radio("Selecione o LLM", ["gpt-3.5-turbo"])

with open("sicoob-logo.png", "rb") as f:
    data = base64.b64encode(f.read()).decode("utf-8")
st.sidebar.markdown(
    f"""
    <div style="display:table;margin-top:-80%;margin-left:0%;">
        <img src="data:image/png;base64,{data}" width="250" height="70">
    </div>
    """,
    unsafe_allow_html=True,
)
if sidebar_option == "gpt-3.5-turbo":
    from llama_index.llms.openai import OpenAI
    from llama_index.embeddings.openai import OpenAIEmbedding

    Settings.llm = OpenAI(model="gpt-3.5-turbo")
    Settings.embed_model = OpenAIEmbedding(model_name="text-embedding-ada-002")
else:
    raise Exception("Invalid LLM option!")
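# Note: the OpenAI and OpenAIEmbedding classes above read the OPENAI_API_KEY
# environment variable, so the key must be available (e.g. as a Space secret).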
# Log to stdout (basicConfig already attaches a stream handler; adding a
# second handler would print every message twice)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Main paths
chat_store_path = os.path.join("chat_store", "chat_store.json")
documents_path = "documentos"
chroma_storage_path = "chroma_db"
bm25_persist_path = "bm25_retriever"
# Custom CSV reader class
class CustomPandasCSVReader:
    """PandasCSVReader modified to include the header row in the documents."""

    def __init__(
        self,
        *args: Any,
        concat_rows: bool = True,
        col_joiner: str = ", ",
        row_joiner: str = "\n",
        pandas_config: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        self._concat_rows = concat_rows
        self._col_joiner = col_joiner
        self._row_joiner = row_joiner
        # Normalize None to {} to avoid a mutable default argument
        self._pandas_config = pandas_config or {}

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
    ) -> List[Document]:
        df = pd.read_csv(file, **self._pandas_config)
        text_list = [" ".join(df.columns.astype(str))]
        text_list += (
            df.astype(str)
            .apply(lambda row: self._col_joiner.join(row.values), axis=1)
            .tolist()
        )
        metadata = {"filename": file.name, "extension": file.suffix}
        if extra_info:
            metadata.update(extra_info)

        if self._concat_rows:
            return [Document(text=self._row_joiner.join(text_list), metadata=metadata)]
        else:
            return [
                Document(text=text, metadata=metadata)
                for text in text_list
            ]
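# Illustrative sketch of what load_data returns (hypothetical file and
# contents): for a CSV with header "nome,cargo" and one row "Ana,Analista",
# concat_rows=True yields a single Document whose text is
# "nome cargo\nAna, Analista".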
def clean_documents(documents: List[Document]) -> List[Document]:
    """Strip unwanted characters from the texts, keeping letters, digits,
    accented characters, spaces and line breaks."""
    cleaned_docs = []
    for doc in documents:
        cleaned_text = re.sub(r"[^0-9A-Za-zÀ-ÿ \n]", "", doc.get_content())
        doc.text = cleaned_text
        cleaned_docs.append(doc)
    return cleaned_docs
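# For example (illustrative): "Ana; (Analista)" becomes "Ana Analista".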
def are_docs_downloaded(directory_path: str) -> bool:
    """Check whether the directory contains any files."""
    return os.path.isdir(directory_path) and any(os.scandir(directory_path))
# Read the source files from Google Drive
from llama_index.readers.google import GoogleDriveReader

credentials_json = os.getenv('GOOGLE_CREDENTIALS')
token_json = os.getenv('GOOGLE_TOKEN')
if credentials_json is None:
    raise ValueError("The GOOGLE_CREDENTIALS environment variable is not set.")
if token_json is None:
    raise ValueError("The GOOGLE_TOKEN environment variable is not set.")

# Write the credentials to files
credentials_path = "credentials.json"
token_path = "token.json"
with open(credentials_path, 'w') as credentials_file:
    credentials_file.write(credentials_json)
with open(token_path, 'w') as token_file:
    token_file.write(token_json)

google_drive_reader = GoogleDriveReader(credentials_path=credentials_path)
google_drive_reader._creds = google_drive_reader._get_credentials()
def download_original_files_from_folder(
    greader: GoogleDriveReader,
    pasta_documentos_drive: str,
    local_path: str,
):
    """Download files only if they do not already exist locally."""
    os.makedirs(local_path, exist_ok=True)
    files_meta = greader._get_fileids_meta(folder_id=pasta_documentos_drive)
    if not files_meta:
        logging.info("No files found in the given folder.")
        return
    for fmeta in files_meta:
        file_id = fmeta[0]
        file_name = os.path.basename(fmeta[2])
        local_file_path = os.path.join(local_path, file_name)

        if os.path.exists(local_file_path):
            logging.info(f"File '{file_name}' already exists locally, skipping download.")
            continue

        downloaded_file_path = greader._download_file(file_id, local_file_path)
        if downloaded_file_path:
            logging.info(f"File '{file_name}' downloaded successfully to: {downloaded_file_path}")
        else:
            logging.warning(f"Could not download '{file_name}'")
# Drive folder holding the source documents
pasta_documentos_drive = "1s0UUANcU1B0D2eyRweb1W5idUn1V5JEh"
###############################################################################
#            RESOURCE CREATION/LOADING (avoids repeating steps)               #
###############################################################################

# 1. Make sure we do not download the data again if it already exists.
if not are_docs_downloaded(documents_path):
    logging.info("Downloading original files from Drive into 'documentos'...")
    download_original_files_from_folder(
        google_drive_reader,
        pasta_documentos_drive,
        documents_path,
    )
else:
    logging.info("'documentos' already contains files, skipping download.")
# 2. If the docstore and index are not in the session state yet, create them.
#    Otherwise, simply reuse what already exists.
if "docstore" not in st.session_state:
    # Load documents from the local directory
    file_extractor = {".csv": CustomPandasCSVReader()}
    documents = SimpleDirectoryReader(
        input_dir=documents_path,
        file_extractor=file_extractor,
        filename_as_id=True,
        recursive=True,
    ).load_data()
    documents = clean_documents(documents)

    # Create the docstore
    docstore = SimpleDocumentStore()
    docstore.add_documents(documents)
    st.session_state["docstore"] = docstore
else:
    docstore = st.session_state["docstore"]
# 3. Set up the VectorStore + Chroma without recreating them if already ready.
if "vector_store" not in st.session_state:
    db = chromadb.PersistentClient(path=chroma_storage_path)
    chroma_collection = db.get_or_create_collection("dense_vectors")
    vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
    st.session_state["vector_store"] = vector_store
else:
    vector_store = st.session_state["vector_store"]

storage_context = StorageContext.from_defaults(
    docstore=docstore,
    vector_store=vector_store,
)
# 4. Load or create the index. If the Chroma store already holds data, the
#    index is assumed to have been persisted; otherwise it is created.
if "index" not in st.session_state:
    if os.path.exists(chroma_storage_path) and os.listdir(chroma_storage_path):
        # Saved data exists, so build the index from the vector_store
        index = VectorStoreIndex.from_vector_store(vector_store)
    else:
        # Create the index (chunk_size can be tuned as needed)
        splitter = LangchainNodeParser(
            RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=128)
        )
        index = VectorStoreIndex.from_documents(
            list(docstore.docs.values()),
            storage_context=storage_context,
            transformations=[splitter],
        )
        # chromadb.PersistentClient persists writes automatically, so no
        # explicit persist call is needed here.
    st.session_state["index"] = index
else:
    index = st.session_state["index"]
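# Sketch of what the splitter does when the index is first built (illustrative;
# long_text is hypothetical):
#
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1024, chunk_overlap=128
#   ).split_text(long_text)
#   # each chunk holds at most ~1024 characters, overlapping its neighbor by 128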
# 5. Create or load the custom BM25Retriever
if "bm25_retriever" not in st.session_state:
    if (
        os.path.exists(bm25_persist_path)
        # Look for the retriever.json file that persist() writes
        and os.path.exists(os.path.join(bm25_persist_path, DEFAULT_PERSIST_FILENAME))
    ):
        bm25_retriever = BM25Retriever.from_persist_dir(bm25_persist_path)
    else:
        bm25_retriever = BM25Retriever.from_defaults(
            docstore=docstore,
            similarity_top_k=2,
            language="portuguese",
            verbose=True,
        )
        os.makedirs(bm25_persist_path, exist_ok=True)
        bm25_retriever.persist(bm25_persist_path)
    st.session_state["bm25_retriever"] = bm25_retriever
else:
    bm25_retriever = st.session_state["bm25_retriever"]
# 6. Create or retrieve the Query Fusion retriever (BM25 + dense vectors)
if "fusion_retriever" not in st.session_state:
    vector_retriever = index.as_retriever(similarity_top_k=2)
    fusion_retriever = QueryFusionRetriever(
        [bm25_retriever, vector_retriever],
        similarity_top_k=2,
        num_queries=1,  # 1 disables LLM-based query generation
        mode="reciprocal_rerank",
        use_async=True,
        verbose=True,
        query_gen_prompt=(
            "Gere {num_queries} perguntas de busca relacionadas à seguinte pergunta. "
            "Priorize o significado da pergunta sobre qualquer histórico de conversa. "
            "Se o histórico não for relevante, ignore-o. "
            "Não adicione explicações ou introduções. Apenas escreva as perguntas. "
            "Pergunta: {query}\n\nPerguntas:\n"
        ),
    )
    st.session_state["fusion_retriever"] = fusion_retriever
else:
    fusion_retriever = st.session_state["fusion_retriever"]
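# The "reciprocal_rerank" mode fuses the BM25 and vector result lists with
# reciprocal-rank fusion: each node scores sum(1 / (k + rank)) over the lists
# that contain it (llama-index uses k = 60 at the time of writing). A sketch
# of the idea:
#
#   def rrf_score(ranks, k=60):
#       # ranks: 1-based positions of one node across the retrievers
#       return sum(1.0 / (k + r) for r in ranks)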
# 7. Configure the Chat Engine if it is not in the session yet
if "chat_engine" not in st.session_state:
    memory = ChatMemoryBuffer.from_defaults(token_limit=3900)
    # CondensePlusContextChatEngine expects a retriever as its first argument
    chat_engine = CondensePlusContextChatEngine.from_defaults(
        fusion_retriever,
        memory=memory,
        context_prompt=(
            "Você é um assistente virtual capaz de interagir normalmente, além de "
            "fornecer informações sobre organogramas e listar funcionários. "
            "Aqui estão os documentos relevantes para o contexto:\n"
            "{context_str}\n"
            "Use o histórico anterior ou o contexto acima para responder."
        ),
        verbose=True,
    )
    st.session_state["chat_engine"] = chat_engine
else:
    chat_engine = st.session_state["chat_engine"]
# 8. Chat storage
if "chat_store" not in st.session_state:
    if os.path.exists(chat_store_path):
        chat_store = SimpleChatStore.from_persist_path(persist_path=chat_store_path)
    else:
        chat_store = SimpleChatStore()
        # Make sure the parent directory exists before persisting
        os.makedirs(os.path.dirname(chat_store_path), exist_ok=True)
        chat_store.persist(persist_path=chat_store_path)
    st.session_state["chat_store"] = chat_store
else:
    chat_store = st.session_state["chat_store"]
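# As written, the ChatMemoryBuffer in step 7 is not backed by chat_store; a
# sketch of how the two could be wired together (the "user1" key is
# hypothetical):
#
#   memory = ChatMemoryBuffer.from_defaults(
#       token_limit=3900,
#       chat_store=chat_store,
#       chat_store_key="user1",
#   )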
###############################################################################
#                        STREAMLIT CHAT INTERFACE                             #
###############################################################################
st.title("Chatbot Carômetro")
st.write("Este assistente virtual pode te ajudar a encontrar informações relevantes sobre os carômetros da Sicoob.")

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Replay previous messages, stored as "role: text" strings
for message in st.session_state.chat_history:
    role, text = message.split(":", 1)
    with st.chat_message(role.strip().lower()):
        st.write(text.strip())
user_input = st.chat_input("Digite sua pergunta")
if user_input:
    with st.chat_message('user'):
        st.write(user_input)
    st.session_state.chat_history.append(f"user: {user_input}")

    with st.chat_message('assistant'):
        message_placeholder = st.empty()
        assistant_message = ''
        response = chat_engine.stream_chat(user_input)
        for token in response.response_gen:
            assistant_message += token
            message_placeholder.markdown(assistant_message + "▌")
        message_placeholder.markdown(assistant_message)
    st.session_state.chat_history.append(f"assistant: {assistant_message}")