Spaces: Running on Zero
Daryl Lim committed · commit 80a5f54 · Parent(s): f096bc8
Update app.py
app.py CHANGED
@@ -123,6 +123,57 @@ def convert_document_to_markdown(doc_path) -> str:
     except Exception as e:
         return f"Error converting document: {str(e)}"

+# Improved text processing function
+def clean_and_prepare_text(markdown_path):
+    """Load, clean and prepare document text for better processing"""
+    try:
+        # Load the document
+        loader = UnstructuredMarkdownLoader(str(markdown_path))
+        documents = loader.load()
+
+        if not documents:
+            return None, "No content could be extracted from the document."
+
+        # Combine all document content for pre-processing
+        raw_text = " ".join([doc.page_content for doc in documents])
+
+        # Clean up the text
+        # 1. Normalize whitespace
+        text = " ".join(raw_text.split())
+        # 2. Fix common OCR and conversion artifacts
+        text = text.replace(" .", ".").replace(" ,", ",")
+        # 3. Ensure proper spacing after punctuation
+        for punct in ['.', '!', '?']:
+            text = text.replace(f"{punct}", f"{punct} ")
+
+        # Split into improved documents
+        # Use a sensible paragraph size
+        paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
+
+        # Create structured documents for better processing
+        processed_docs = []
+        for i, para in enumerate(paragraphs):
+            if len(para) > 10:  # Skip very short paragraphs
+                processed_docs.append(Document(
+                    page_content=para,
+                    metadata={"source": markdown_path, "paragraph": i}
+                ))
+
+        return processed_docs, None
+
+    except Exception as e:
+        return None, f"Error processing document text: {str(e)}"
+
+# Improved text splitting configuration
+def create_optimized_text_splitter():
+    """Create an optimized text splitter for document processing"""
+    return RecursiveCharacterTextSplitter(
+        chunk_size=800,  # Slightly smaller for more focused chunks
+        chunk_overlap=150,  # Increased overlap to maintain context
+        length_function=len,
+        separators=["\n\n", "\n", ".", "!", "?", ";", ":", " ", ""]  # More comprehensive separators
+    )
+
 # Function to generate a summary using the IBM Granite model
 def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
     """Generate a summary from document chunks using the IBM Granite model
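For context, a minimal sketch of how the two new helpers compose, assuming the imports app.py already uses (UnstructuredMarkdownLoader, RecursiveCharacterTextSplitter, Document); the file path here is hypothetical:

# Hypothetical usage; "converted/report.md" is an illustrative path.
docs, err = clean_and_prepare_text("converted/report.md")
if err:
    raise RuntimeError(err)

splitter = create_optimized_text_splitter()
chunks = splitter.split_documents(docs)
print(f"{len(docs)} paragraphs -> {len(chunks)} chunks")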
@@ -132,14 +183,13 @@ def generate_summary(chunks: List[Document], length_type="sentences", length_cou
         length_type: Either "sentences" or "paragraphs"
         length_count: Number of sentences (1-10) or paragraphs (1-3)
     """
-    # Print debug information
+    # Print debug information
     print(f"Generating summary with length_type={length_type}, length_count={length_count}")

     # Ensure length_count is an integer
     try:
         length_count = int(length_count)
     except (ValueError, TypeError):
-        # Default to 3 if conversion fails
         print(f"Failed to convert length_count to int: {length_count}, using default 3")
         length_count = 3

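The coercion above is a small total operation; restated in isolation as a sketch (coerce_count is an illustrative name, not in app.py):

def coerce_count(value, default=3):
    # Mirrors generate_summary's int() coercion and fallback.
    try:
        return int(value)
    except (ValueError, TypeError):
        return default

assert coerce_count("5") == 5
assert coerce_count(None) == 3    # TypeError -> default
assert coerce_count("many") == 3  # ValueError -> default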
@@ -149,18 +199,36 @@ def generate_summary(chunks: List[Document], length_type="sentences", length_cou
     else:  # paragraphs
         length_count = max(1, min(3, length_count))  # Limit to 1-3 paragraphs

-    # …
-    …
+    # Clean and concatenate the text from chunks
+    # Remove any excessive whitespace and normalize
+    cleaned_chunks = []
+    for chunk in chunks:
+        text = chunk.page_content
+        # Remove excessive newlines and whitespace
+        text = ' '.join(text.split())
+        cleaned_chunks.append(text)
+
+    combined_text = " ".join(cleaned_chunks)

-    # …
+    # More explicit and forceful prompt structure
     if length_type == "sentences":
-        length_instruction = f"…
+        length_instruction = f"Create a concise summary that is EXACTLY {length_count} complete sentences. Not {length_count-1} sentences. Not {length_count+1} sentences. EXACTLY {length_count} sentences."
     else:  # paragraphs
-        length_instruction = f"…
+        length_instruction = f"Create a concise summary that is EXACTLY {length_count} paragraphs. Each paragraph should be 2-4 sentences long. Not {length_count-1} paragraphs. Not {length_count+1} paragraphs. EXACTLY {length_count} paragraphs."

-    # …
+    # More detailed prompt with examples of what constitutes a sentence
     prompt = f"""<instruction>
-…
+You are an expert document summarizer. Your task is to create a high-quality summary of the following text.
+
+{length_instruction}
+
+Remember:
+- Your summary must capture the main points of the document
+- Your summary must be in your own words (not copied text)
+- Your summary must be clearly written and well-structured
+- Do not include any explanations, headings, bullet points, or additional formatting
+- Respond ONLY with the summary text itself
+
 </instruction>

 <text>
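For illustration, with length_type="sentences" and length_count=3 the added f-string renders as:

# length_instruction, rendered for length_count=3:
# "Create a concise summary that is EXACTLY 3 complete sentences.
#  Not 2 sentences. Not 4 sentences. EXACTLY 3 sentences."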
@@ -168,28 +236,30 @@ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a
 </text>
 """

-    # Calculate appropriate max_new_tokens
-    # Approximate tokens: ~15 tokens per sentence, ~75 tokens per paragraph
+    # Calculate appropriate max_new_tokens but with stricter limits
     if length_type == "sentences":
-        …
+        # Approximately 20 tokens per sentence
+        max_tokens = length_count * 40
     else:  # paragraphs
-        …
+        # Approximately 100 tokens per paragraph
+        max_tokens = length_count * 150

     # Ensure minimum tokens and add buffer
-    max_tokens = max(100, min(1500, max_tokens…
-    …
-    # Generate the summary using the IBM Granite model
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    max_tokens = max(100, min(1500, max_tokens))

     print(f"Using max_new_tokens={max_tokens}")

+    # Generate with lower temperature for more consistent results
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
     with torch.no_grad():
         output = model.generate(
             **inputs,
             max_new_tokens=max_tokens,
-            temperature=0.…
+            temperature=0.3,  # Lower temperature for more deterministic output
             top_p=0.9,
-            do_sample=True
+            do_sample=True,
+            repetition_penalty=1.2  # Discourage repetition
         )

     # Decode and return the generated summary
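The token budget above reduces to a pure function; a standalone sketch (budget is an illustrative name, not in app.py):

def budget(length_type, length_count):
    # Mirrors the max_new_tokens calculation: 40 tokens per sentence,
    # 150 per paragraph, clamped to the 100-1500 range.
    per_unit = 40 if length_type == "sentences" else 150
    return max(100, min(1500, length_count * per_unit))

assert budget("sentences", 3) == 120
assert budget("sentences", 1) == 100   # 40 raised to the 100-token floor
assert budget("paragraphs", 3) == 450  # after the earlier clamps, the 1500 ceiling is never hit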
@@ -197,6 +267,18 @@ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a

     # Extract just the generated response (after the prompt)
     summary = summary[len(tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)):]
+    summary = summary.strip()
+
+    # Post-process the summary to ensure it meets the length constraints
+    if length_type == "sentences":
+        # Simple sentence counting based on periods
+        sentences = [s.strip() for s in summary.split('.') if s.strip()]
+        if len(sentences) > length_count:
+            # Take only the requested number of sentences
+            summary = '. '.join(sentences[:length_count]) + '.'
+        elif len(sentences) < length_count:
+            # If we have too few sentences, log this issue
+            print(f"Warning: Generated only {len(sentences)} sentences instead of {length_count}")

     return summary.strip()

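The period-based trim above can be exercised on its own; a sketch with an illustrative over-long model output:

summary = "Granite is an IBM model family. It handles long inputs. It is open source. It also does more."
length_count = 3
sentences = [s.strip() for s in summary.split('.') if s.strip()]
if len(sentences) > length_count:
    summary = '. '.join(sentences[:length_count]) + '.'
# summary == "Granite is an IBM model family. It handles long inputs. It is open source."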
@@ -255,19 +337,15 @@ def process_document(
     if markdown_path.startswith("Error"):
         return markdown_path

-    # …
-    progress(0.4, "…
-    …
-    …
+    # Clean and prepare the text
+    progress(0.4, "Processing document text...")
+    processed_docs, error = clean_and_prepare_text(markdown_path)
+    if error:
+        return error

-    # …
-    text_splitter = …
-    …
-        chunk_overlap=100,
-        length_function=len,
-        separators=["\n\n", "\n", ".", " ", ""]  # Prioritize splitting at paragraph/sentence boundaries
-    )
-    texts = text_splitter.split_documents(documents)
+    # Split the documents with optimized splitter
+    text_splitter = create_optimized_text_splitter()
+    texts = text_splitter.split_documents(processed_docs)

     if not texts:
         return "No text could be extracted from the document."
@@ -302,36 +380,45 @@ def process_document(
         # Sleep briefly to allow memory cleanup
         time.sleep(0.1)

-        # …
-        if len(all_chunks) …
-        …
+        # Case 1: Very small documents - use all chunks directly
+        if len(all_chunks) <= 8:
+            return generate_summary(
+                all_chunks,
+                length_type=length_type.lower(),
+                length_count=length_count
+            )
+
+        # Case 2: Medium-sized documents - process in one batch
+        elif len(all_chunks) <= 16:
+            return generate_summary(
+                all_chunks[:8],  # Use first 8 chunks (usually contains most important info)
+                length_type=length_type.lower(),
+                length_count=length_count
+            )
+
+        # Case 3: Large documents - process in multiple batches
+        else:
+            # First pass: Generate summaries for each batch
             summaries = []
             for i in range(0, len(all_chunks), batch_size):
                 batch = all_chunks[i:i+batch_size]
                 summary = generate_summary(
                     batch,
-                    length_type=…
-                    length_count=…
+                    length_type="paragraphs",  # Use paragraphs for intermediate summaries
+                    length_count=1  # One paragraph per batch
                 )
                 summaries.append(summary)

                 # Force garbage collection
                 gc.collect()

-            # …
+            # Second pass: Generate final summary from batch summaries
             final_summary = generate_summary(
                 [Document(page_content=s) for s in summaries],
-                length_type=length_type,
+                length_type=length_type.lower(),
                 length_count=length_count
             )
             return final_summary
-        else:
-            # If we have few chunks, generate summary directly
-            return generate_summary(
-                all_chunks,
-                length_type=length_type,
-                length_count=length_count
-            )

     except Exception as e:
         return f"Error processing document: {str(e)}"
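The three-way routing above, restated as a standalone sketch (route is an illustrative name; batch_size comes from process_document's surrounding scope):

def route(n_chunks):
    # Mirrors the chunk-count thresholds in process_document.
    if n_chunks <= 8:
        return "summarize all chunks directly"
    elif n_chunks <= 16:
        return "summarize the first 8 chunks"
    return "one paragraph per batch, then a final summary pass"

assert route(5) == "summarize all chunks directly"
assert route(12) == "summarize the first 8 chunks"
assert route(40) == "one paragraph per batch, then a final summary pass"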
@@ -458,4 +545,4 @@ def create_gradio_interface():
 # Launch the application
 if __name__ == "__main__":
     app = create_gradio_interface()
-    app.launch()
+    app.launch()