Keisuke Yamanaka - CNC committed
Commit 8d45acd · 1 Parent(s): 92ea1dc

update app.py

Files changed (1)
  1. app_multimodal_AI.py +440 -0
app_multimodal_AI.py ADDED
@@ -0,0 +1,440 @@
+ import gradio as gr
+ # from huggingface_hub import InferenceClient
+ from langdetect import detect
+ import pycountry
+ from googletrans import Translator
+
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_openai import ChatOpenAI
+
+ # from langchain.document_loaders import UnstructuredExcelLoader
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain_text_splitters import CharacterTextSplitter
+ import glob
+ import base64
+ import os
+ from os.path import split
+
+ from langchain_core.messages import HumanMessage
+ from unstructured.partition.pdf import partition_pdf
+ import uuid
+
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
+ from langchain.storage import InMemoryStore
+ from langchain_chroma import Chroma
+ from langchain_core.documents import Document
+ from langchain_openai import OpenAIEmbeddings
+
+ import io
+ import re
+
+ from IPython.display import HTML, display
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+ from PIL import Image
+
+
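+ # CNC_QA wires up a multimodal RAG pipeline over the PDFs in ./Doc: each
+ # document is partitioned into texts, tables, and images; every element is
+ # summarized with gpt-4o-mini; the summaries are embedded into Chroma while
+ # the raw elements live in an in-memory docstore; and a LangChain runnable
+ # chain answers questions against the retrieved raw content.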
+ class CNC_QA:
+     def __init__(self):
+         print("Initializing CLASS:CNC_QA")
+         self.bot = self.load_QAAI()
+
+     def load_QAAI(self):
+         # Initialize empty summaries and raw elements
+         text_summaries = []
+         texts = []
+         table_summaries = []
+         tables = []
+
+         # Store base64 encoded images
+         img_base64_list = []
+         # Store image summaries
+         image_summaries = []
+
+         print("Start to load documents")
+         fullpathes = glob.glob('./Doc/*')
+         for i, fullpath in enumerate(fullpathes):
+             print(f'{i+1}/{len(fullpathes)}:{fullpath}')
+             text_summary, text, table_summary, table, image_summary, img_base64 = self.load_documents(fullpath)
+             text_summaries += text_summary
+             texts += text
+             table_summaries += table_summary
+             tables += table
+             # Keep summaries and base64 images in their matching lists
+             # (the original accumulation had these two swapped)
+             image_summaries += image_summary
+             img_base64_list += img_base64
+
+         # The vectorstore to use to index the summaries
+         vectorstore = Chroma(
+             collection_name="mm_rag_cj_blog", embedding_function=OpenAIEmbeddings()
+         )
+
+         # Create retriever
+         self.retriever_multi_vector_img = self.create_multi_vector_retriever(
+             vectorstore,
+             text_summaries,
+             texts,
+             table_summaries,
+             tables,
+             image_summaries,
+             img_base64_list,
+         )
+
+         chain_multimodal_rag = self.multi_modal_rag_chain(self.retriever_multi_vector_img)
+         return chain_multimodal_rag
+
+     def load_documents(self, fullpath):
+         fpath, fname = split(fullpath)
+         fpath += '/'
+         # Get elements
+         print('Get elements')
+         raw_pdf_elements = self.extract_pdf_elements(fpath, fname)
+
+         # Get text, tables
+         print('Get text, tables')
+         texts, tables = self.categorize_elements(raw_pdf_elements)
+
+         # Optional: Enforce a specific token size for texts
+         print('Optional: Enforce a specific token size for texts')
+         text_splitter = CharacterTextSplitter.from_tiktoken_encoder(
+             chunk_size=4000, chunk_overlap=0
+         )
+         joined_texts = " ".join(texts)
+         texts_4k_token = text_splitter.split_text(joined_texts)
+
+         # Get text, table summaries
+         print('Get text, table summaries')
+         text_summaries, table_summaries = self.generate_text_summaries(
+             texts_4k_token, tables, summarize_texts=True
+         )
+
+         print('Image summaries')
+         img_base64_list, image_summaries = self.generate_img_summaries(fpath)
+         return text_summaries, texts, table_summaries, tables, image_summaries, img_base64_list
+
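+     # Note on the partitioning parameters below (per unstructured's chunking
+     # behavior): chunking_strategy="by_title" groups narrative text into
+     # CompositeElement chunks of at most max_characters, starts a new chunk
+     # after new_after_n_chars, and merges small sections below
+     # combine_text_under_n_chars. With extract_images_in_pdf=True the
+     # embedded images are dumped as .jpg files, which generate_img_summaries()
+     # later picks up from the same directory.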
+     # Extract elements from PDF
+     def extract_pdf_elements(self, path, fname):
+         """
+         Extract images, tables, and chunk text from a PDF file.
+         path: File path, which is used to dump images (.jpg)
+         fname: File name
+         """
+         return partition_pdf(
+             filename=path + fname,
+             # filename=r'/content/drive/My Drive/huggingface_transformers_demo/transformers/Doc/ResconReg.pdf',
+             extract_images_in_pdf=True,
+             infer_table_structure=True,
+             chunking_strategy="by_title",
+             max_characters=4000,
+             new_after_n_chars=3800,
+             combine_text_under_n_chars=2000,
+             image_output_dir_path=path,
+         )
+
+     # Categorize elements by type
+     def categorize_elements(self, raw_pdf_elements):
+         """
+         Categorize extracted elements from a PDF into tables and texts.
+         raw_pdf_elements: List of unstructured.documents.elements
+         """
+         tables = []
+         texts = []
+         for element in raw_pdf_elements:
+             if "unstructured.documents.elements.Table" in str(type(element)):
+                 tables.append(str(element))
+             elif "unstructured.documents.elements.CompositeElement" in str(type(element)):
+                 texts.append(str(element))
+         return texts, tables
+
+     # Generate summaries of text elements
+     def generate_text_summaries(self, texts, tables, summarize_texts=False):
+         """
+         Summarize text elements
+         texts: List of str
+         tables: List of str
+         summarize_texts: Bool to summarize texts
+         """
+
+         # Prompt
+         prompt_text = """You are an assistant tasked with summarizing tables and text for retrieval. \
+ These summaries will be embedded and used to retrieve the raw text or table elements. \
+ Give a concise summary of the table or text that is well optimized for retrieval. Table or text: {element} """
+         prompt = ChatPromptTemplate.from_template(prompt_text)
+
+         # Text summary chain
+         model = ChatOpenAI(temperature=0, model="gpt-4o-mini")
+         summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
+
+         # Initialize empty summaries
+         text_summaries = []
+         table_summaries = []
+
+         # Apply to text if texts are provided and summarization is requested
+         if texts and summarize_texts:
+             text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
+         elif texts:
+             text_summaries = texts
+
+         # Apply to tables if tables are provided
+         if tables:
+             table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
+
+         return text_summaries, table_summaries
+
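+     # A minimal usage sketch (the input string is hypothetical): the chain
+     # feeds each raw element straight into the prompt, so
+     #   summarize_chain.batch(["<table markup>"], {"max_concurrency": 5})
+     # returns one retrieval-optimized summary string per input, with up to
+     # five requests in flight at once.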
+     def encode_image(self, image_path):
+         """Getting the base64 string"""
+         with open(image_path, "rb") as image_file:
+             return base64.b64encode(image_file.read()).decode("utf-8")
+
+     def image_summarize(self, img_base64, prompt):
+         """Make image summary"""
+         # Note: the original passed a stray `self` as the first positional
+         # argument to ChatOpenAI, which raises a TypeError at runtime.
+         chat = ChatOpenAI(model="gpt-4o-mini", max_tokens=1024)
+
+         msg = chat.invoke(
+             [
+                 HumanMessage(
+                     content=[
+                         {"type": "text", "text": prompt},
+                         {
+                             "type": "image_url",
+                             "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
+                         },
+                     ]
+                 )
+             ]
+         )
+         return msg.content
+
+     def generate_img_summaries(self, path):
+         """
+         Generate summaries and base64 encoded strings for images
+         path: Path to list of .jpg files extracted by Unstructured
+         """
+
+         # Store base64 encoded images
+         img_base64_list = []
+
+         # Store image summaries
+         image_summaries = []
+
+         # Prompt
+         prompt = """You are an assistant tasked with summarizing images for retrieval. \
+ These summaries will be embedded and used to retrieve the raw image. \
+ Give a concise summary of the image that is well optimized for retrieval."""
+
+         # Apply to images
+         for img_file in sorted(os.listdir(path)):
+             if img_file.endswith(".jpg"):
+                 img_path = os.path.join(path, img_file)
+                 base64_image = self.encode_image(img_path)
+                 img_base64_list.append(base64_image)
+                 image_summaries.append(self.image_summarize(base64_image, prompt))
+
+         return img_base64_list, image_summaries
+
+     def create_multi_vector_retriever(
+         self, vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images
+     ):
+         """
+         Create retriever that indexes summaries, but returns raw images or texts
+         """
+
+         # Initialize the storage layer
+         store = InMemoryStore()
+         id_key = "doc_id"
+
+         # Create the multi-vector retriever
+         retriever = MultiVectorRetriever(
+             vectorstore=vectorstore,
+             docstore=store,
+             id_key=id_key,
+         )
+
+         # Helper function to add documents to the vectorstore and docstore
+         def add_documents(retriever, doc_summaries, doc_contents):
+             doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
+             summary_docs = [
+                 Document(page_content=s, metadata={id_key: doc_ids[i]})
+                 for i, s in enumerate(doc_summaries)
+             ]
+             retriever.vectorstore.add_documents(summary_docs)
+             retriever.docstore.mset(list(zip(doc_ids, doc_contents)))
+
+         # Add texts, tables, and images
+         # Check that text_summaries is not empty before adding
+         if text_summaries:
+             add_documents(retriever, text_summaries, texts)
+         # Check that table_summaries is not empty before adding
+         if table_summaries:
+             add_documents(retriever, table_summaries, tables)
+         # Check that image_summaries is not empty before adding
+         if image_summaries:
+             add_documents(retriever, image_summaries, images)
+
+         return retriever
+
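+     # The multi-vector pattern above stores two parallel representations:
+     # the vectorstore indexes the short summaries for similarity search,
+     # while the docstore keeps the raw text/table strings and base64 images
+     # under the same doc_id. At query time the retriever matches against a
+     # summary, then returns the raw content it points to, so the LLM sees the
+     # original table or image rather than the lossy summary.
+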
+     # def plt_img_base64(self, img_base64):
+     #     """Display base64 encoded string as image"""
+     #     # Create an HTML img tag with the base64 string as the source
+     #     image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
+     #     # Display the image by rendering the HTML
+     #     display(HTML(image_html))
+
+     def looks_like_base64(self, sb):
+         """Check if the string looks like base64"""
+         return re.match("^[A-Za-z0-9+/]+[=]{0,2}$", sb) is not None
+
+     def is_image_data(self, b64data):
+         """
+         Check if the base64 data is an image by looking at the start of the data
+         """
+         image_signatures = {
+             b"\xff\xd8\xff": "jpg",
+             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
+             b"\x47\x49\x46\x38": "gif",
+             b"\x52\x49\x46\x46": "webp",
+         }
+         try:
+             header = base64.b64decode(b64data)[:8]  # Decode and get the first 8 bytes
+             for sig, format in image_signatures.items():
+                 if header.startswith(sig):
+                     return True
+             return False
+         except Exception:
+             return False
+
+     def resize_base64_image(self, base64_string, size=(128, 128)):
+         """
+         Resize an image encoded as a Base64 string
+         """
+         # Decode the Base64 string
+         img_data = base64.b64decode(base64_string)
+         img = Image.open(io.BytesIO(img_data))
+
+         # Resize the image
+         resized_img = img.resize(size, Image.LANCZOS)
+
+         # Save the resized image to a bytes buffer
+         buffered = io.BytesIO()
+         resized_img.save(buffered, format=img.format)
+
+         # Encode the resized image to Base64
+         return base64.b64encode(buffered.getvalue()).decode("utf-8")
+
+     def split_image_text_types(self, docs):
+         """
+         Split base64-encoded images and texts
+         """
+         b64_images = []
+         texts = []
+         for doc in docs:
+             # Check if the document is of type Document and extract page_content if so
+             if isinstance(doc, Document):
+                 doc = doc.page_content
+             if self.looks_like_base64(doc) and self.is_image_data(doc):
+                 doc = self.resize_base64_image(doc, size=(1300, 600))
+                 b64_images.append(doc)
+             else:
+                 texts.append(doc)
+         return {"images": b64_images, "texts": texts}
+
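+     # Retrieved docs come back as opaque strings, so the two checks above are
+     # how raw images are told apart from text: looks_like_base64 filters on
+     # the base64 alphabet, and is_image_data confirms a real image via its
+     # decoded magic bytes (JPEG/PNG/GIF/RIFF). Images are also downscaled to
+     # 1300x600 here, presumably to keep the multimodal prompt within size
+     # limits.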
+     def img_prompt_func(self, data_dict):
+         """
+         Join the context into a single string
+         """
+         formatted_texts = "\n".join(data_dict["context"]["texts"])
+         messages = []
+
+         # Adding image(s) to the messages if present
+         if data_dict["context"]["images"]:
+             for image in data_dict["context"]["images"]:
+                 image_message = {
+                     "type": "image_url",
+                     "image_url": {"url": f"data:image/jpeg;base64,{image}"},
+                 }
+                 messages.append(image_message)
+
+         # Adding the text for analysis
+         text_message = {
+             "type": "text",
+             "text": (
+                 "You are a CNC machine engineer answering the user's question.\n"
+                 "You will be given a mix of text, tables, and image(s), usually of charts or graphs.\n"
+                 "Use this information to answer the user question.\n"
+                 f"User-provided question: {data_dict['question']}\n\n"
+                 "Text and / or tables:\n"
+                 f"{formatted_texts}"
+             ),
+         }
+         messages.append(text_message)
+         return [HumanMessage(content=messages)]
+
+     def multi_modal_rag_chain(self, retriever):
+         """
+         Multi-modal RAG chain
+         """
+
+         # Multi-modal LLM
+         model = ChatOpenAI(temperature=0, model="gpt-4o-mini", max_tokens=1024)
+
+         # RAG pipeline
+         chain = (
+             {
+                 "context": retriever | RunnableLambda(self.split_image_text_types),
+                 "question": RunnablePassthrough(),
+             }
+             | RunnableLambda(self.img_prompt_func)
+             | model
+             | StrOutputParser()
+         )
+
+         return chain
+
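+     # The Gradio callback below drives the whole pipeline with one call:
+     # retrieval, image/text splitting, prompt assembly, and generation, e.g.
+     # self.bot.invoke("What is 3D machining simulation?") returns the answer
+     # string directly.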
+     def echo(self, message, history):
+         # message = text_en
+         ans = self.bot.invoke(message)
+         return ans
+
+     def convert_lang(self, message, lang_dest):
+         lang = detect(message)
+
+         translator = Translator()
+
+         print(f'Source language: {lang} -> Target language: {lang_dest}')
+         if lang == lang_dest:
+             text = message
+         else:
+             text = translator.translate(message, src=lang, dest=lang_dest).text
+         print(message)
+         print(text)
+
+         return text, lang
+
+
+ if __name__ == "__main__":
+     print("start")
+     # The OpenAI API key must be supplied via the environment (e.g. as a
+     # Hugging Face Space secret); never hardcode a live key in the source.
+     assert os.environ.get("OPENAI_API_KEY"), "Set OPENAI_API_KEY before launching"
+
+     meldas = CNC_QA()
+
+     demo = gr.ChatInterface(fn=meldas.echo, examples=["What is 3D machining simulation?", "Is there some limit (program step or scan time) at the time of communication in the bus coupling of M3?"], title="MELDAS AI")
+     # demo = gr.Interface(fn=chat_func)
+     demo.launch(debug=True, share=True)