import streamlit as st
from langchain.embeddings import HuggingFaceEmbeddings  # LangChain wrapper around sentence-transformers
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import pipeline
from datasets import load_dataset
import torch
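# Assumed dependencies (package names are not pinned in the original):
#   pip install streamlit langchain sentence-transformers faiss-cpu transformers datasets torch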
# Set up the Streamlit page configuration
st.set_page_config(page_title="Gen AI Lawyers Guide", layout="centered", page_icon="⚖️")

# Load the summarization pipeline model
def load_summarization_pipeline():
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")  # BART model fine-tuned for summarization
    return summarizer
summarizer = load_summarization_pipeline()
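# Streamlit reruns this script top to bottom on every interaction, so heavyweight
# steps (this pipeline, the dataset load, the FAISS index below) are rebuilt each
# time. Wrapping the loader functions in Streamlit's @st.cache_resource decorator
# is the usual remedy; it is not applied here to keep the code as written.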
# Load the CaseHOLD dataset from Hugging Face
def load_casehold_dataset():
    dataset = load_dataset("lex_glue", "case_hold", split="train")  # CaseHOLD subset of LexGLUE
    texts = [item["context"] for item in dataset]
    return " ".join(texts)
# Split text into manageable chunks
def get_text_chunks(text):
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_text(text)
    return chunks
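# RecursiveCharacterTextSplitter counts characters, not tokens: a 1000-character
# chunk is roughly 200-250 English tokens, close to the 256-token input limit of
# all-MiniLM-L6-v2 (the encoder silently truncates anything longer).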
# Initialize the embedding model (a LangChain Embeddings object, so FAISS can
# embed both the chunks at index time and the query at search time)
def load_embedding_model():
    model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return model

embedding_model = load_embedding_model()
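# HuggingFaceEmbeddings exposes embed_documents() for batches of texts and
# embed_query() for single strings; FAISS.from_texts and similarity_search call
# these internally, so no manual encode step is needed.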
# Create a FAISS vector store with embeddings
def load_or_create_vector_store(text_chunks):
    vector_store = FAISS.from_texts(text_chunks, embedding_model)  # encodes each chunk and builds the index
    return vector_store
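# The "load_or_create" name hints at persistence: LangChain's FAISS wrapper
# supports vector_store.save_local(path) and FAISS.load_local(path, embeddings)
# for reusing an index across runs, though that is not implemented here.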
# Generate a summary for the query based on the retrieved text
def generate_summary_with_huggingface(query, retrieved_text):
    summarization_input = f"{query}\n\nRelated information:\n{retrieved_text}"
    max_input_length = 1024
    summarization_input = summarization_input[:max_input_length]  # crude character cap; BART's real limit is 1024 tokens
    summary = summarizer(summarization_input, max_length=500, min_length=50, do_sample=False)
    return summary[0]["summary_text"]
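# The character slice above is only a proxy for BART's 1024-token input window.
# The pipeline can truncate by tokens itself; a sketch of the same call with
# token-level truncation, behavior otherwise unchanged:
#
#     summary = summarizer(summarization_input, max_length=500, min_length=50,
#                          do_sample=False, truncation=True)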
# Generate a response for the user query
def user_input(user_question, vector_store):
    docs = vector_store.similarity_search(user_question)  # returns the nearest chunks as Documents
    context_text = " ".join([doc.page_content for doc in docs])
    return generate_summary_with_huggingface(user_question, context_text)
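# similarity_search defaults to the 4 nearest chunks; passing k explicitly, e.g.
# vector_store.similarity_search(user_question, k=6), widens the retrieved
# context (k=6 is an illustrative value, not from the original).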
# Main function to run the Streamlit app
def main():
    st.title("⚖️ Gen AI Lawyers Guide with CaseHOLD Dataset")

    # Load the CaseHOLD dataset and build the vector store
    st.write("Loading the CaseHOLD dataset from Hugging Face's datasets library...")
    raw_text = load_casehold_dataset()
    text_chunks = get_text_chunks(raw_text)
    vector_store = load_or_create_vector_store(text_chunks)

    # User question input
    user_question = st.text_input("Ask a Question:", placeholder="Type your question here...")

    if st.button("Get Response"):
        if not user_question:
            st.warning("Please enter a question before submitting.")
        else:
            with st.spinner("Generating response..."):
                answer = user_input(user_question, vector_store)
                st.markdown(f"**🤖 AI:** {answer}")
if __name__ == "__main__":
    main()
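# To launch locally (assuming this file is saved as app.py):
#   streamlit run app.py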