|
import gradio as gr |
|
|
|
from langdetect import detect |
|
import pycountry |
|
from googletrans import Translator |
|
|
|
from langchain_core.output_parsers import StrOutputParser |
|
from langchain_core.prompts import ChatPromptTemplate |
|
from langchain_openai import ChatOpenAI |
|
|
|
|
|
from langchain_community.document_loaders import PyPDFLoader |
|
from langchain_text_splitters import CharacterTextSplitter |
|
import glob |
|
import base64 |
|
import os |
|
from os.path import split |
|
|
|
from langchain_core.messages import HumanMessage |
|
|
from unstructured.partition.pdf import partition_pdf |
|
import uuid |
|
|
|
from langchain.retrievers.multi_vector import MultiVectorRetriever |
|
from langchain.storage import InMemoryStore |
|
from langchain_chroma import Chroma |
|
from langchain_core.documents import Document |
|
from langchain_openai import OpenAIEmbeddings |
|
|
|
import io |
|
import re |
|
|
|
|
from IPython.display import HTML, display |
|
from langchain_core.runnables import RunnableLambda, RunnablePassthrough |
|
from PIL import Image |
|
|
|
|
|
class CNC_QA: |
|
def __init__(self): |
|
print("Initialing CLASS:CNC_QA ") |
|
self.bot=self.load_QAAI() |
|
|
|
def load_QAAI(self): |
|
|
|
|
|
|
|
text_summaries = [] |
|
texts = [] |
|
table_summaries = [] |
|
tables = [] |
|
|
|
|
|
img_base64_list = [] |
|
|
|
image_summaries = [] |
|
|
|
print("Start to load documents") |
|
        fullpaths = glob.glob('./Doc/*')

        for i, fullpath in enumerate(fullpaths):

            print(f'{i+1}/{len(fullpaths)}: {fullpath}')
|
            text_summary, text, table_summary, table, image_summary, img_base64 = self.load_documents(fullpath)

            text_summaries += text_summary

            texts += text

            table_summaries += table_summary

            tables += table

            image_summaries += image_summary

            img_base64_list += img_base64
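        # Vector store that will index the summaries; the raw elements live in a separate docstore.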
|
|
|
vectorstore = Chroma( |
|
collection_name="mm_rag_cj_blog", embedding_function=OpenAIEmbeddings() |
|
) |
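        # Build the multi-vector retriever: summaries are embedded for similarity search,
        # but the retriever returns the original texts, tables, and base64-encoded images.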
|
|
|
|
|
self.retriever_multi_vector_img = self.create_multi_vector_retriever( |
|
vectorstore, |
|
text_summaries, |
|
texts, |
|
table_summaries, |
|
tables, |
|
image_summaries, |
|
img_base64_list, |
|
) |
|
|
|
chain_multimodal_rag = self.multi_modal_rag_chain(self.retriever_multi_vector_img) |
|
return chain_multimodal_rag |
|
|
|
def load_documents(self,fullpath): |
|
fpath, fname = split(fullpath) |
|
fpath += '/' |
|
|
|
print('Get elements') |
|
raw_pdf_elements = self.extract_pdf_elements(fpath, fname) |
|
|
|
|
|
print('Get text, tables') |
|
texts, tables = self.categorize_elements(raw_pdf_elements) |
|
|
|
|
|
print('Optional: Enforce a specific token size for texts') |
|
text_splitter = CharacterTextSplitter.from_tiktoken_encoder( |
|
chunk_size=4000, chunk_overlap=0 |
|
) |
|
joined_texts = " ".join(texts) |
|
texts_4k_token = text_splitter.split_text(joined_texts) |
|
|
|
|
|
print('Get text, table summaries') |
|
text_summaries, table_summaries = self.generate_text_summaries( |
|
texts_4k_token, tables, summarize_texts=True |
|
) |
|
|
|
print('Image summaries') |
|
img_base64_list, image_summaries = self.generate_img_summaries(fpath) |
|
return text_summaries,texts,table_summaries,tables,image_summaries,img_base64_list |
|
|
|
|
|
|
|
|
|
def extract_pdf_elements(self,path, fname): |
|
""" |
|
Extract images, tables, and chunk text from a PDF file. |
|
path: File path, which is used to dump images (.jpg) |
|
fname: File name |
|
""" |
|
return partition_pdf( |
|
filename=path + fname, |
|
|
|
extract_images_in_pdf=True, |
|
infer_table_structure=True, |
|
chunking_strategy="by_title", |
|
max_characters=4000, |
|
new_after_n_chars=3800, |
|
combine_text_under_n_chars=2000, |
|
image_output_dir_path=path, |
|
) |
|
|
|
|
|
|
|
def categorize_elements(self,raw_pdf_elements): |
|
""" |
|
Categorize extracted elements from a PDF into tables and texts. |
|
raw_pdf_elements: List of unstructured.documents.elements |
|
""" |
|
tables = [] |
|
texts = [] |
|
for element in raw_pdf_elements: |
|
if "unstructured.documents.elements.Table" in str(type(element)): |
|
tables.append(str(element)) |
|
elif "unstructured.documents.elements.CompositeElement" in str(type(element)): |
|
texts.append(str(element)) |
|
return texts, tables |
|
|
|
|
|
def generate_text_summaries(self,texts, tables, summarize_texts=False): |
|
""" |
|
Summarize text elements |
|
texts: List of str |
|
tables: List of str |
|
summarize_texts: Bool to summarize texts |
|
""" |
|
|
|
|
|
prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \ |
|
These summaries will be embedded and used to retrieve the raw text or table elements. \ |
|
Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """ |
|
prompt = ChatPromptTemplate.from_template(prompt_text) |
|
|
|
|
|
model = ChatOpenAI(temperature=0, model="gpt-4o-mini") |
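        # LCEL summarization chain: pass the element through, fill the prompt, call the model,
        # and parse the response to a plain string.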
|
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser() |
|
|
|
|
|
text_summaries = [] |
|
table_summaries = [] |
|
|
|
|
|
if texts and summarize_texts: |
|
text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5}) |
|
elif texts: |
|
text_summaries = texts |
|
|
|
|
|
if tables: |
|
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5}) |
|
|
|
return text_summaries, table_summaries |
|
|
|
def encode_image(self,image_path): |
|
"""Getting the base64 string""" |
|
with open(image_path, "rb") as image_file: |
|
return base64.b64encode(image_file.read()).decode("utf-8") |
|
|
|
|
|
def image_summarize(self,img_base64, prompt): |
|
"""Make image summary""" |
|
        chat = ChatOpenAI(model="gpt-4o-mini", max_tokens=1024)
|
|
|
msg = chat.invoke( |
|
[ |
|
HumanMessage( |
|
content=[ |
|
{"type": "text", "text": prompt}, |
|
{ |
|
"type": "image_url", |
|
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}, |
|
}, |
|
] |
|
) |
|
] |
|
) |
|
return msg.content |
|
|
|
def generate_img_summaries(self,path): |
|
""" |
|
Generate summaries and base64 encoded strings for images |
|
path: Path to list of .jpg files extracted by Unstructured |
|
""" |
|
|
|
|
|
img_base64_list = [] |
|
|
|
|
|
image_summaries = [] |
|
|
|
|
|
prompt = """You are an assistant tasked with summarizing images for retrieval. \ |
|
These summaries will be embedded and used to retrieve the raw image. \ |
|
Give a concise summary of the image that is well optimized for retrieval.""" |
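        # Scan the directory for .jpg files produced by partition_pdf, encode each image to
        # base64, and generate a retrieval-oriented summary with the vision-capable model.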
|
|
|
|
|
|
|
for img_file in sorted(os.listdir(path)): |
|
if img_file.endswith(".jpg"): |
|
img_path = os.path.join(path, img_file) |
|
base64_image = self.encode_image(img_path) |
|
img_base64_list.append(base64_image) |
|
image_summaries.append(self.image_summarize(base64_image, prompt)) |
|
|
|
return img_base64_list, image_summaries |
|
|
|
def create_multi_vector_retriever( |
|
self,vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images |
|
): |
|
""" |
|
Create retriever that indexes summaries, but returns raw images or texts |
|
""" |
|
|
|
|
|
store = InMemoryStore() |
|
id_key = "doc_id" |
|
|
|
|
|
retriever = MultiVectorRetriever( |
|
vectorstore=vectorstore, |
|
docstore=store, |
|
id_key=id_key, |
|
) |
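        # Helper: add summary documents to the vectorstore and store the matching raw
        # contents in the docstore under the same generated doc_id.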
|
|
|
|
|
def add_documents(retriever, doc_summaries, doc_contents): |
|
doc_ids = [str(uuid.uuid4()) for _ in doc_contents] |
|
for text in doc_summaries: |
|
print(text) |
|
summary_docs = [ |
|
Document(page_content=s, metadata={id_key: doc_ids[i]}) |
|
for i, s in enumerate(doc_summaries) |
|
] |
|
retriever.vectorstore.add_documents(summary_docs) |
|
retriever.docstore.mset(list(zip(doc_ids, doc_contents))) |
|
|
|
|
|
|
|
if text_summaries: |
|
add_documents(retriever, text_summaries, texts) |
|
|
|
if table_summaries: |
|
add_documents(retriever, table_summaries, tables) |
|
|
|
if image_summaries: |
|
add_documents(retriever, image_summaries, images) |
|
|
|
return retriever |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def looks_like_base64(self,sb): |
|
"""Check if the string looks like base64""" |
|
return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None |
|
|
|
|
|
def is_image_data(self,b64data): |
|
""" |
|
Check if the base64 data is an image by looking at the start of the data |
|
""" |
|
image_signatures = { |
|
b"\xff\xd8\xff": "jpg", |
|
b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png", |
|
b"\x47\x49\x46\x38": "gif", |
|
b"\x52\x49\x46\x46": "webp", |
|
} |
|
try: |
|
header = base64.b64decode(b64data)[:8] |
|
for sig, format in image_signatures.items(): |
|
if header.startswith(sig): |
|
return True |
|
return False |
|
except Exception: |
|
return False |
|
|
|
|
|
def resize_base64_image(self,base64_string, size=(128, 128)): |
|
""" |
|
Resize an image encoded as a Base64 string |
|
""" |
|
|
|
img_data = base64.b64decode(base64_string) |
|
img = Image.open(io.BytesIO(img_data)) |
|
|
|
|
|
resized_img = img.resize(size, Image.LANCZOS) |
|
|
|
|
|
buffered = io.BytesIO() |
|
resized_img.save(buffered, format=img.format) |
|
|
|
|
|
return base64.b64encode(buffered.getvalue()).decode("utf-8") |
|
|
|
|
|
def split_image_text_types(self,docs): |
|
""" |
|
Split base64-encoded images and texts |
|
""" |
|
b64_images = [] |
|
texts = [] |
|
for doc in docs: |
|
|
|
if isinstance(doc, Document): |
|
doc = doc.page_content |
|
if self.looks_like_base64(doc) and self.is_image_data(doc): |
|
doc = self.resize_base64_image(doc, size=(1300, 600)) |
|
b64_images.append(doc) |
|
else: |
|
texts.append(doc) |
|
return {"images": b64_images, "texts": texts} |
|
|
|
|
|
def img_prompt_func(self,data_dict): |
|
""" |
|
        Build a multimodal prompt (text plus images) from the retrieved context
|
""" |
|
formatted_texts = "\n".join(data_dict["context"]["texts"]) |
|
messages = [] |
|
|
|
|
|
if data_dict["context"]["images"]: |
|
for image in data_dict["context"]["images"]: |
|
image_message = { |
|
"type": "image_url", |
|
"image_url": {"url": f"data:image/jpeg;base64,{image}"}, |
|
} |
|
messages.append(image_message) |
|
|
|
|
|
text_message = { |
|
"type": "text", |
|
"text": ( |
|
"You are CNC machine engineer who answer the question.\n" |
|
"You will be given a mixed of text, tables, and image(s) usually of charts or graphs.\n" |
|
"Use this information to provide investment advice related to the user question. \n" |
|
f"User-provided question: {data_dict['question']}\n\n" |
|
"Text and / or tables:\n" |
|
f"{formatted_texts}" |
|
), |
|
} |
|
messages.append(text_message) |
|
return [HumanMessage(content=messages)] |
|
|
|
|
|
def multi_modal_rag_chain(self,retriever): |
|
""" |
|
Multi-modal RAG chain |
|
""" |
|
|
|
|
|
model = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024) |
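        # RAG pipeline: retrieve relevant docs, separate images from texts, build the
        # multimodal prompt, call the model, and parse the answer to a plain string.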
|
|
|
|
|
chain = ( |
|
{ |
|
"context": retriever | RunnableLambda(self.split_image_text_types), |
|
"question": RunnablePassthrough(), |
|
} |
|
| RunnableLambda(self.img_prompt_func) |
|
| model |
|
| StrOutputParser() |
|
) |
|
|
|
return chain |
|
def echo(self,message,history): |
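        # Gradio ChatInterface callback; the conversation history argument is currently unused.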
|
|
|
ans = self.bot.invoke(message) |
|
|
|
|
|
return ans |
|
|
|
|
|
def convert_lang(self,message,lang_dest): |
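        # Detect the language of `message` and translate it into `lang_dest` with googletrans;
        # returns the translated text along with the detected source language.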
|
lang = detect(message) |
|
|
|
translator = Translator() |
|
|
|
        print(f'Source language: {lang} -> Target language: {lang_dest}')
|
if lang == lang_dest: |
|
text = message |
|
else: |
|
text = translator.translate(message, src=lang, dest=lang_dest).text |
|
print(message) |
|
print(text) |
|
|
|
return text, lang |
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
print("start") |
|
os.environ["OPENAI_API_KEY"] = "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" |
|
|
|
meldas = CNC_QA() |
|
|
|
    demo = gr.ChatInterface(fn=meldas.echo, examples=["What is 3D machining simulation?", "Is there some limit (program step or scan time) at the time of communication in the bus coupling of M3?"], title="MELDAS AI")
|
|
|
demo.launch(debug=True,share=True) |
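
    # A minimal sketch of querying the RAG chain directly, without the Gradio UI
    # (assumes ./Doc contains at least one PDF and OPENAI_API_KEY is set; the question is hypothetical):
    #
    #   answer = meldas.bot.invoke("What is the spindle override range?")
    #   print(answer)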
|
|