Daryl Lim committed
Commit 80a5f54 · 1 Parent(s): f096bc8

Update app.py

Files changed (1)
app.py +133 -46
app.py CHANGED
@@ -123,6 +123,57 @@ def convert_document_to_markdown(doc_path) -> str:
     except Exception as e:
         return f"Error converting document: {str(e)}"
 
+# Improved text processing function
+def clean_and_prepare_text(markdown_path):
+    """Load, clean and prepare document text for better processing"""
+    try:
+        # Load the document
+        loader = UnstructuredMarkdownLoader(str(markdown_path))
+        documents = loader.load()
+
+        if not documents:
+            return None, "No content could be extracted from the document."
+
+        # Combine all document content for pre-processing
+        raw_text = " ".join([doc.page_content for doc in documents])
+
+        # Clean up the text
+        # 1. Normalize whitespace
+        text = " ".join(raw_text.split())
+        # 2. Fix common OCR and conversion artifacts
+        text = text.replace(" .", ".").replace(" ,", ",")
+        # 3. Ensure proper spacing after punctuation
+        for punct in ['.', '!', '?']:
+            text = text.replace(f"{punct}", f"{punct} ")
+
+        # Split into improved documents
+        # Use a sensible paragraph size
+        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
+
+        # Create structured documents for better processing
+        processed_docs = []
+        for i, para in enumerate(paragraphs):
+            if len(para) > 10:  # Skip very short paragraphs
+                processed_docs.append(Document(
+                    page_content=para,
+                    metadata={"source": markdown_path, "paragraph": i}
+                ))
+
+        return processed_docs, None
+
+    except Exception as e:
+        return None, f"Error processing document text: {str(e)}"
+
+# Improved text splitting configuration
+def create_optimized_text_splitter():
+    """Create an optimized text splitter for document processing"""
+    return RecursiveCharacterTextSplitter(
+        chunk_size=800,  # Slightly smaller for more focused chunks
+        chunk_overlap=150,  # Increased overlap to maintain context
+        length_function=len,
+        separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""]  # More comprehensive separators
+    )
+
 # Function to generate a summary using the IBM Granite model
 def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
     """Generate a summary from document chunks using the IBM Granite model
@@ -132,14 +183,13 @@ def generate_summary(chunks: List[Document], length_type="sentences", length_cou
         length_type: Either "sentences" or "paragraphs"
         length_count: Number of sentences (1-10) or paragraphs (1-3)
     """
-    # Print debug information to track what parameters are being used
+    # Print debug information
     print(f"Generating summary with length_type={length_type}, length_count={length_count}")
 
     # Ensure length_count is an integer
     try:
         length_count = int(length_count)
     except (ValueError, TypeError):
-        # Default to 3 if conversion fails
         print(f"Failed to convert length_count to int: {length_count}, using default 3")
         length_count = 3
 
@@ -149,18 +199,36 @@ def generate_summary(chunks: List[Document], length_type="sentences", length_cou
     else:  # paragraphs
         length_count = max(1, min(3, length_count))  # Limit to 1-3 paragraphs
 
-    # Concatenate the retrieved chunks
-    combined_text = " ".join([chunk.page_content for chunk in chunks])
+    # Clean and concatenate the text from chunks
+    # Remove any excessive whitespace and normalize
+    cleaned_chunks = []
+    for chunk in chunks:
+        text = chunk.page_content
+        # Remove excessive newlines and whitespace
+        text = ' '.join(text.split())
+        cleaned_chunks.append(text)
+
+    combined_text = " ".join(cleaned_chunks)
 
-    # Use a more direct instruction to enforce the length constraint
+    # More explicit and forceful prompt structure
     if length_type == "sentences":
-        length_instruction = f"Your summary must be EXACTLY {length_count} sentence{'s' if length_count > 1 else ''}. Not more, not less."
+        length_instruction = f"Create a concise summary that is EXACTLY {length_count} complete sentences. Not {length_count-1} sentences. Not {length_count+1} sentences. EXACTLY {length_count} sentences."
     else:  # paragraphs
-        length_instruction = f"Your summary must be EXACTLY {length_count} paragraph{'s' if length_count > 1 else ''}. Not more, not less."
+        length_instruction = f"Create a concise summary that is EXACTLY {length_count} paragraphs. Each paragraph should be 2-4 sentences long. Not {length_count-1} paragraphs. Not {length_count+1} paragraphs. EXACTLY {length_count} paragraphs."
 
-    # Construct the prompt with clearer instructions
+    # More detailed prompt with examples of what constitutes a sentence
     prompt = f"""<instruction>
-Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. Summarize the following text. {length_instruction} Your response should only include the summary. Do not provide any further explanation.
+You are an expert document summarizer. Your task is to create a high-quality summary of the following text.
+
+{length_instruction}
+
+Remember:
+- Your summary must capture the main points of the document
+- Your summary must be in your own words (not copied text)
+- Your summary must be clearly written and well-structured
+- Do not include any explanations, headings, bullet points, or additional formatting
+- Respond ONLY with the summary text itself
+
 </instruction>
 
 <text>
@@ -168,28 +236,30 @@ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a
 </text>
 """
 
-    # Calculate appropriate max_new_tokens based on length requirements
-    # Approximate tokens: ~15 tokens per sentence, ~75 tokens per paragraph
+    # Calculate appropriate max_new_tokens but with stricter limits
     if length_type == "sentences":
-        max_tokens = length_count * 30  # Increased slightly for flexibility
+        # Approximately 20 tokens per sentence
+        max_tokens = length_count * 40
    else:  # paragraphs
-        max_tokens = length_count * 120  # Increased slightly for flexibility
+        # Approximately 100 tokens per paragraph
+        max_tokens = length_count * 150
 
     # Ensure minimum tokens and add buffer
-    max_tokens = max(100, min(1500, max_tokens + 50))
-
-    # Generate the summary using the IBM Granite model
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    max_tokens = max(100, min(1500, max_tokens))
 
     print(f"Using max_new_tokens={max_tokens}")
 
+    # Generate with lower temperature for more consistent results
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         output = model.generate(
             **inputs,
             max_new_tokens=max_tokens,
-            temperature=0.7,
+            temperature=0.3,  # Lower temperature for more deterministic output
             top_p=0.9,
-            do_sample=True
+            do_sample=True,
+            repetition_penalty=1.2  # Discourage repetition
         )
 
     # Decode and return the generated summary
@@ -197,6 +267,18 @@ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a
 
     # Extract just the generated response (after the prompt)
     summary = summary[len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)):]
+    summary = summary.strip()
+
+    # Post-process the summary to ensure it meets the length constraints
+    if length_type == "sentences":
+        # Simple sentence counting based on periods
+        sentences = [s.strip() for s in summary.split('.') if s.strip()]
+        if len(sentences) > length_count:
+            # Take only the requested number of sentences
+            summary = '. '.join(sentences[:length_count]) + '.'
+        elif len(sentences) < length_count:
+            # If we have too few sentences, log this issue
+            print(f"Warning: Generated only {len(sentences)} sentences instead of {length_count}")
 
     return summary.strip()
 
@@ -255,19 +337,15 @@ def process_document(
         if markdown_path.startswith("Error"):
             return markdown_path
 
-        # Load and split the document
-        progress(0.4, "Loading and splitting document...")
-        loader = UnstructuredMarkdownLoader(str(markdown_path))
-        documents = loader.load()
+        # Clean and prepare the text
+        progress(0.4, "Processing document text...")
+        processed_docs, error = clean_and_prepare_text(markdown_path)
+        if error:
+            return error
 
-        # Optimize text splitting for better chunks
-        text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=1000,  # Larger chunk size for better context
-            chunk_overlap=100,
-            length_function=len,
-            separators=["\n\n", "\n", ".", " ", ""]  # Prioritize splitting at paragraph/sentence boundaries
-        )
-        texts = text_splitter.split_documents(documents)
+        # Split the documents with optimized splitter
+        text_splitter = create_optimized_text_splitter()
+        texts = text_splitter.split_documents(processed_docs)
 
         if not texts:
             return "No text could be extracted from the document."
@@ -302,36 +380,45 @@
             # Sleep briefly to allow memory cleanup
             time.sleep(0.1)
 
-        # Generate summary from chunks
-        if len(all_chunks) > 8:
-            # If we have many chunks, process in batches
+        # Case 1: Very small documents - use all chunks directly
+        if len(all_chunks) <= 8:
+            return generate_summary(
+                all_chunks,
+                length_type=length_type.lower(),
+                length_count=length_count
+            )
+
+        # Case 2: Medium-sized documents - process in one batch
+        elif len(all_chunks) <= 16:
+            return generate_summary(
+                all_chunks[:8],  # Use first 8 chunks (usually contains most important info)
+                length_type=length_type.lower(),
+                length_count=length_count
+            )
+
+        # Case 3: Large documents - process in multiple batches
+        else:
+            # First pass: Generate summaries for each batch
             summaries = []
             for i in range(0, len(all_chunks), batch_size):
                 batch = all_chunks[i:i+batch_size]
                 summary = generate_summary(
                     batch,
-                    length_type=length_type,
-                    length_count=max(1, length_count // 2)  # Use smaller count for partial summaries
+                    length_type="paragraphs",  # Use paragraphs for intermediate summaries
+                    length_count=1  # One paragraph per batch
                )
                summaries.append(summary)
 
                # Force garbage collection
                gc.collect()
 
-            # Create final summary from batch summaries
+            # Second pass: Generate final summary from batch summaries
            final_summary = generate_summary(
                [Document(page_content=s) for s in summaries],
-                length_type=length_type,
+                length_type=length_type.lower(),
                length_count=length_count
            )
            return final_summary
-        else:
-            # If we have few chunks, generate summary directly
-            return generate_summary(
-                all_chunks,
-                length_type=length_type,
-                length_count=length_count
-            )
 
     except Exception as e:
         return f"Error processing document: {str(e)}"
@@ -458,4 +545,4 @@ def create_gradio_interface():
 # Launch the application
 if __name__ == "__main__":
     app = create_gradio_interface()
-    app.launch()
+    app.launch()
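
A note on the new pipeline: below is a minimal sketch of how the helpers introduced in this commit fit together, assuming model and tokenizer are loaded earlier in app.py as the diff implies. The path "converted.md" is a hypothetical stand-in for the output of convert_document_to_markdown.

# Sketch only: wires up clean_and_prepare_text, create_optimized_text_splitter,
# and generate_summary from this commit. "converted.md" is a hypothetical path.
processed_docs, error = clean_and_prepare_text("converted.md")
if error:
    print(error)
else:
    splitter = create_optimized_text_splitter()
    chunks = splitter.split_documents(processed_docs)
    # Mirrors the small-document branch of process_document():
    # summarize up to 8 chunks directly with the default 3-sentence target.
    print(generate_summary(chunks[:8], length_type="sentences", length_count=3))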
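One detail of clean_and_prepare_text worth flagging: the whitespace pass has already collapsed runs of spaces, so the loop that rewrites each '.', '!' and '?' as the character plus a space doubles the spacing after most sentence ends and turns decimals such as 3.5 into "3. 5". A regex that inserts a space only where punctuation directly abuts a letter would be idempotent; a sketch of that refinement, not part of this commit:

import re

def space_after_punctuation(text: str) -> str:
    # Sketch: add a space only where '.', '!' or '?' is immediately
    # followed by a letter, leaving decimals and existing spacing intact.
    return re.sub(r"([.!?])(?=[A-Za-z])", r"\1 ", text)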
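Similarly, the new length post-processing counts sentences by splitting on '.', which treats every period (including abbreviations and any final period) as a boundary and ignores '!' and '?'. Splitting on terminal punctuation followed by whitespace is one sturdier variant; a sketch with the same trim semantics, not what the commit ships:

import re

def trim_to_sentences(summary: str, length_count: int) -> str:
    # Sketch: treat '.', '!' or '?' followed by whitespace as a sentence
    # boundary, so decimals like 3.5 no longer inflate the count.
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", summary) if s.strip()]
    return " ".join(sentences[:length_count])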
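On the reworked routing in process_document: the medium tier (9-16 chunks) summarizes only all_chunks[:8], so chunks 9 through 16 never reach the model; the inline comment's assumption that the opening chunks carry the key information holds for many documents but not all. One hypothetical alternative folds the tail into a single intermediate summary before the final pass (a sketch meant to sit in that elif branch, not what this commit does):

# Sketch for the medium-document branch: cover chunks 9-16 with one
# intermediate summary instead of dropping them.
head = all_chunks[:8]
tail = all_chunks[8:16]
if tail:
    tail_summary = generate_summary(tail, length_type="paragraphs", length_count=1)
    head = head + [Document(page_content=tail_summary)]
return generate_summary(head, length_type=length_type.lower(), length_count=length_count)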
 
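Finally, on the generation settings: temperature=0.3 with do_sample=True is more stable than the previous 0.7 but still stochastic, so repeated runs on the same document can yield different summaries. If run-to-run reproducibility matters more than variety, greedy decoding is the fully deterministic option (sampling parameters are then ignored); a sketch of that variant inside generate_summary, assuming the same inputs and max_tokens as in the diff:

with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=max_tokens,
        do_sample=False,  # Greedy decoding: deterministic output
        repetition_penalty=1.2  # Still useful to discourage loops
    )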