Daryl Lim committed
Commit ca55264 · Parent(s): 15864df
Update app.py
app.py CHANGED
@@ -4,13 +4,13 @@ import shutil
 import torch
 import gradio as gr
 from pathlib import Path
-from typing import Optional, List,
-import
-
+from typing import Optional, List, Union
+import gc
+import time
 
 # Docling imports
 from docling.datamodel.base_models import InputFormat
-from docling.datamodel.pipeline_options import PdfPipelineOptions
+from docling.datamodel.pipeline_options import PdfPipelineOptions
 from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
 
 # LangChain imports
@@ -26,16 +26,22 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
 
 # Initialize IBM Granite model and tokenizer
 print("Loading Granite model and tokenizer...")
-
+model_name = "ibm-granite/granite-3.3-8b-instruct"
+
+# Load tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+# Load model with optimization for GPU
 model = AutoModelForCausalLM.from_pretrained(
-
+    model_name,
     device_map="auto",
-    torch_dtype=torch.bfloat16
+    torch_dtype=torch.bfloat16,
+    load_in_8bit=True # Use 8-bit quantization for memory efficiency
 )
 print("Model loaded successfully!")
 
 # Helper function to detect document format
-def get_document_format(file_path) -> InputFormat:
+def get_document_format(file_path) -> Optional[InputFormat]:
     """Determine the document format based on file extension"""
     try:
         file_path = str(file_path)
@@ -48,9 +54,10 @@ def get_document_format(file_path) -> InputFormat:
             '.html': InputFormat.HTML,
             '.htm': InputFormat.HTML
         }
-        return format_map.get(extension
+        return format_map.get(extension)
     except Exception as e:
-
+        print(f"Error in get_document_format: {str(e)}")
+        return None
 
 # Function to convert documents to markdown
 def convert_document_to_markdown(doc_path) -> str:
@@ -59,16 +66,19 @@ def convert_document_to_markdown(doc_path) -> str:
         # Convert to absolute path string
         input_path = os.path.abspath(str(doc_path))
         print(f"Converting document: {doc_path}")
+
         # Create temporary directory for processing
         with tempfile.TemporaryDirectory() as temp_dir:
             # Copy input file to temp directory
             temp_input = os.path.join(temp_dir, os.path.basename(input_path))
             shutil.copy2(input_path, temp_input)
+
             # Configure pipeline options
             pipeline_options = PdfPipelineOptions()
-            pipeline_options.do_ocr = False # Disable OCR
+            pipeline_options.do_ocr = False # Disable OCR for performance
             pipeline_options.do_table_structure = True
-
+
+            # Create converter with optimized options
             converter = DocumentConverter(
                 allowed_formats=[
                     InputFormat.PDF,
@@ -85,104 +95,66 @@ def convert_document_to_markdown(doc_path) -> str:
                     )
                 }
             )
+
            # Convert document
            print("Starting conversion...")
            conv_result = converter.convert(temp_input)
            if not conv_result or not conv_result.document:
                raise ValueError(f"Failed to convert document: {doc_path}")
+
            # Export to markdown
            print("Exporting to markdown...")
            md = conv_result.document.export_to_markdown()
+
            # Create output path
            output_dir = os.path.dirname(input_path)
            base_name = os.path.splitext(os.path.basename(input_path))[0]
            md_path = os.path.join(output_dir, f"{base_name}_converted.md")
+
            # Write markdown file
-           print(f"Writing markdown to: {base_name}_converted.md")
            with open(md_path, "w", encoding="utf-8") as fp:
                fp.write(md)
            return md_path
     except Exception as e:
         return f"Error converting document: {str(e)}"
 
-# Function to download file from URL
-def download_file_from_url(url: str) -> Optional[str]:
-    """Download a file from a URL and save it temporarily"""
-    try:
-        # Parse URL to get filename
-        parsed_url = urlparse(url)
-        filename = os.path.basename(parsed_url.path)
-
-        if not filename:
-            filename = "downloaded_document"
-
-        # Add extension based on Content-Type if needed
-        response = requests.get(url, stream=True)
-        response.raise_for_status()
-
-        content_type = response.headers.get('Content-Type', '')
-        if 'pdf' in content_type:
-            if not filename.lower().endswith('.pdf'):
-                filename += ".pdf"
-        elif 'word' in content_type or 'docx' in content_type:
-            if not filename.lower().endswith(('.doc', '.docx')):
-                filename += ".docx"
-        elif 'powerpoint' in content_type or 'pptx' in content_type:
-            if not filename.lower().endswith(('.ppt', '.pptx')):
-                filename += ".pptx"
-        elif 'html' in content_type:
-            if not filename.lower().endswith(('.html', '.htm')):
-                filename += ".html"
-
-        # Create a temporary file
-        temp_dir = tempfile.gettempdir()
-        file_path = os.path.join(temp_dir, filename)
-
-        # Save the file
-        with open(file_path, 'wb') as f:
-            for chunk in response.iter_content(chunk_size=8192):
-                f.write(chunk)
-
-        return file_path
-    except Exception as e:
-        print(f"Error downloading file: {str(e)}")
-        return None
-
 # Function to generate a summary using the IBM Granite model
-def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abs
-    """Generate a summary from document chunks using the IBM Granite model
+def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
+    """Generate a summary from document chunks using the IBM Granite model
+
+    Args:
+        chunks: List of document chunks to summarize
+        length_type: Either "sentences" or "paragraphs"
+        length_count: Number of sentences (1-10) or paragraphs (1-3)
+    """
     # Concatenate the retrieved chunks
     combined_text = " ".join([chunk.page_content for chunk in chunks])
 
-    #
-    if
-
-    else: #
-
-
-    if detail_level == "high":
-        detail_instruction = "Include specific details and examples."
-    elif detail_level == "medium":
-        detail_instruction = "Balance key points with some supporting details."
-    else: # low
-        detail_instruction = "Focus only on the main points and key takeaways."
-
-    if length == "short":
-        length_instruction = "Keep the summary concise and brief."
-    elif length == "medium":
-        length_instruction = "Create a moderate-length summary."
-    else: # long
-        length_instruction = "Provide a comprehensive, detailed summary."
+    # Construct length instruction based on type and count
+    if length_type == "sentences":
+        length_instruction = f"Summarize the following text in {length_count} sentence{'s' if length_count > 1 else ''}."
+    else: # paragraphs
+        length_instruction = f"Summarize the following text in {length_count} paragraph{'s' if length_count > 1 else ''}."
 
-    # Construct the
+    # Construct the prompt
     prompt = f"""<instruction>
-
-
+Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. {length_instruction} Your response should only include the answer. Do not provide any further explanation.
+</instruction>
+
+<text>
+{combined_text}
+</text>
+"""
 
-
-
-
-
+    # Calculate appropriate max_new_tokens based on length requirements
+    # Approximate tokens: ~15 tokens per sentence, ~75 tokens per paragraph
+    if length_type == "sentences":
+        max_tokens = length_count * 20 # Slightly more than needed for flexibility
+    else: # paragraphs
+        max_tokens = length_count * 100 # Slightly more than needed for flexibility
+
+    # Ensure minimum tokens and add buffer
+    max_tokens = max(100, min(1000, max_tokens + 50))
 
     # Generate the summary using the IBM Granite model
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -190,7 +162,7 @@ def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abs
     with torch.no_grad():
         output = model.generate(
             **inputs,
-            max_new_tokens=
+            max_new_tokens=max_tokens,
             temperature=0.7,
             top_p=0.9,
             do_sample=True
@@ -204,68 +176,50 @@ def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abs
 
     return summary.strip()
 
-# Function to
-def
-    """
-
-
-
-
-        batch_chunks = [retriever.vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
-        all_chunks.extend(batch_chunks)
-
-    # Process chunks in manageable batches if needed
-    summaries = []
-    for i in range(0, len(all_chunks), chunk_size):
-        batch = all_chunks[i:i+chunk_size]
-        summary = generate_summary(
-            batch,
-            model,
-            tokenizer,
-            summary_type=summary_params.get("summary_type", "abstractive"),
-            detail_level=summary_params.get("detail_level", "medium"),
-            length=summary_params.get("length", "medium")
-        )
-
-
-
-
-
-
-            tokenizer,
-            summary_type=summary_params.get("summary_type", "abstractive"),
-            detail_level=summary_params.get("detail_level", "medium"),
-            length=summary_params.get("length", "medium")
-        )
-
-
-
+# Function to process document chunks efficiently
+def process_document_chunks(texts, batch_size=8):
+    """Process document chunks in efficient batches"""
+    try:
+        # Create embeddings with optimized settings
+        embeddings = HuggingFaceEmbeddings(
+            model_name="nomic-ai/nomic-embed-text-v1",
+            model_kwargs={'trust_remote_code': True}
+        )
+
+        # Create vector store more efficiently
+        vectorstore = FAISS.from_documents(
+            texts,
+            embeddings,
+            # Add distance function for better retrieval
+            distance_strategy="cosine"
+        )
+
+        return vectorstore
+    except Exception as e:
+        print(f"Error in document processing: {str(e)}")
+        # Fallback to basic processing if optimization fails
+        return FAISS.from_documents(texts, embeddings)
 
 # Main function to process document and generate summary
 @spaces.GPU
 def process_document(
     file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
-
-
-    detail_level: str = "medium",
-    length: str = "medium",
+    length_type: str = "sentences",
+    length_count: int = 3,
     progress=gr.Progress()
 ):
-    """Process a document file
+    """Process a document file and generate a summary"""
     try:
-        # Process input
-
-
-
-
-
-
-
-
-            return "Please provide either a file or a URL to summarize."
+        # Process input file
+        if not file_obj:
+            return "Please provide a file to summarize."
+
+        document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
+
+        # Validate document format
+        format_type = get_document_format(document_path)
+        if not format_type:
+            return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."
 
         # Convert document to markdown
         progress(0.3, "Converting document to markdown...")
@@ -278,41 +232,78 @@ def process_document(
         loader = UnstructuredMarkdownLoader(str(markdown_path))
         documents = loader.load()
 
+        # Optimize text splitting for better chunks
         text_splitter = RecursiveCharacterTextSplitter(
-            chunk_size=
-            chunk_overlap=
-            length_function=len
+            chunk_size=1000, # Larger chunk size for better context
+            chunk_overlap=100,
+            length_function=len,
+            separators=["\n\n", "\n", ".", " ", ""] # Prioritize splitting at paragraph/sentence boundaries
         )
         texts = text_splitter.split_documents(documents)
 
         if not texts:
             return "No text could be extracted from the document."
 
-        # Create
-        progress(0.6, "
-
-            model_name="nomic-ai/nomic-embed-text-v1",
-            model_kwargs={'trust_remote_code': True}
-        )
-        vectorstore = FAISS.from_documents(texts, embeddings)
+        # Create vector store with efficient processing
+        progress(0.6, "Processing document content...")
+        vectorstore = process_document_chunks(texts)
 
-        # Create retriever
+        # Create retriever with optimized settings
         retriever = vectorstore.as_retriever(
             search_type="similarity",
-            search_kwargs={"k": 4}
+            search_kwargs={"k": 4} # Number of chunks to retrieve
         )
 
-        #
+        # Process chunks in smaller batches for memory efficiency
         progress(0.8, "Generating summary...")
-
-
-
-
-
-        summary = summarize_full_document(retriever, model, tokenizer, summary_params)
-
-
-
+        all_chunks = []
+        batch_size = 4 # Smaller batch size for memory efficiency
+
+        # Get all document chunks
+        doc_ids = list(vectorstore.index_to_docstore_id.values())
+
+        # Process in smaller batches
+        for i in range(0, len(doc_ids), batch_size):
+            batch_ids = doc_ids[i:i+batch_size]
+            batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
+            all_chunks.extend(batch_chunks)
+
+            # Force garbage collection to free memory
+            gc.collect()
+
+            # Sleep briefly to allow memory cleanup
+            time.sleep(0.1)
+
+        # Generate summary from chunks
+        if len(all_chunks) > 8:
+            # If we have many chunks, process in batches
+            summaries = []
+            for i in range(0, len(all_chunks), batch_size):
+                batch = all_chunks[i:i+batch_size]
+                summary = generate_summary(
+                    batch,
+                    length_type=length_type,
+                    length_count=max(1, length_count // 2) # Use smaller count for partial summaries
+                )
+                summaries.append(summary)
+
+                # Force garbage collection
+                gc.collect()
+
+            # Create final summary from batch summaries
+            final_summary = generate_summary(
+                [Document(page_content=s) for s in summaries],
+                length_type=length_type,
+                length_count=length_count
+            )
+            return final_summary
+        else:
+            # If we have few chunks, generate summary directly
+            return generate_summary(
+                all_chunks,
+                length_type=length_type,
+                length_count=length_count
+            )
 
     except Exception as e:
         return f"Error processing document: {str(e)}"
@@ -320,61 +311,90 @@ def process_document(
 # Create Gradio interface
 def create_gradio_interface():
     """Create and launch the Gradio interface"""
-    with gr.Blocks(title="Document
-    gr.Markdown("# Document
-    gr.Markdown("Upload a document
+    with gr.Blocks(title="Granite Document Summarization") as app:
+        gr.Markdown("# Granite Document Summarization")
+        gr.Markdown("Upload a document to generate a summary.")
 
         with gr.Row():
-            with gr.Column():
-                file_input = gr.File(
-
+            with gr.Column(scale=1):
+                file_input = gr.File(
+                    label="Upload Document (PDF, DOCX, PPTX, HTML)",
+                    file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
+                )
 
                 with gr.Row():
-
-
-
-
-
-                )
+                    length_type = gr.Radio(
+                        choices=["Sentences", "Paragraphs"],
+                        value="Sentences",
+                        label="Summary Length Type"
+                    )
 
                 with gr.Row():
-
-
-
-
-
-
+                    # Use slider for sentence count (1-10)
+                    sentence_count = gr.Slider(
+                        minimum=1,
+                        maximum=10,
+                        value=3,
+                        step=1,
+                        label="Number of Sentences",
+                        visible=True
+                    )
 
-
-
-
-
-
-
+                    # Use radio for paragraph count (1-3)
+                    paragraph_count = gr.Radio(
+                        choices=["1", "2", "3"],
+                        value="1",
+                        label="Number of Paragraphs",
+                        visible=False
+                    )
 
-                submit_btn = gr.Button("
+                submit_btn = gr.Button("Summarize", variant="primary")
 
-            with gr.Column():
-                output = gr.
-                    label="Summary
+            with gr.Column(scale=2):
+                output = gr.TextArea(
+                    label="Summary",
                     lines=15,
                     max_lines=30
                 )
 
+        # Add interactivity to show/hide appropriate count selector
+        def update_count_visibility(length_type):
+            return {
+                sentence_count: length_type == "Sentences",
+                paragraph_count: length_type == "Paragraphs"
+            }
+
+        length_type.change(
+            fn=update_count_visibility,
+            inputs=[length_type],
+            outputs=[sentence_count, paragraph_count]
+        )
+
+        # Function to convert paragraph count from string to int and handle capitalized length types
+        def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
+            # Convert capitalized length_type to lowercase for processing
+            length_type_lower = length_type.lower()
+
+            if length_type_lower == "sentences":
+                return process_document(file, length_type_lower, int(sentence_count))
+            else:
+                return process_document(file, length_type_lower, int(paragraph_count))
+
         submit_btn.click(
-            fn=
-            inputs=[file_input,
+            fn=process_document_wrapper,
+            inputs=[file_input, length_type, sentence_count, paragraph_count],
             outputs=output
         )
 
         gr.Markdown("""
        ## How to use:
-       1. Upload a document (PDF, DOCX, PPTX, HTML)
-       2. Choose your
-       -
-       -
-
-
+       1. Upload a document (PDF, DOCX, PPTX, HTML)
+       2. Choose your summary length preference:
+          - Number of Sentences (1-10)
+          - Number of Paragraphs (1-3)
+       3. Click "Summarize" to process the document
+
+       *This application uses the IBM Granite 3.3-8b model to generate summaries.*
        """)
 
     return app
@@ -382,4 +402,4 @@ def create_gradio_interface():
 # Launch the application
 if __name__ == "__main__":
     app = create_gradio_interface()
-    app.launch()
+    app.launch()
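For reviewers who want to sanity-check the new length controls without loading the 8B model, here is a minimal, self-contained sketch that mirrors the prompt construction this commit adds to generate_summary. The build_prompt helper is hypothetical, introduced here for illustration only; in app.py the same f-string is built inline. The instruction text and the sentence/paragraph pluralization follow the diff above.

# Illustrative sketch only: build_prompt is not part of app.py.
def build_prompt(combined_text: str, length_type: str = "sentences", length_count: int = 3) -> str:
    # Mirror of the length instruction introduced in generate_summary
    unit = "sentence" if length_type == "sentences" else "paragraph"
    plural = "s" if length_count > 1 else ""
    length_instruction = f"Summarize the following text in {length_count} {unit}{plural}."
    # Same <instruction>/<text> layout as the prompt in the diff
    return (
        "<instruction>\n"
        "Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. "
        "You are a helpful AI assistant. "
        f"{length_instruction} "
        "Your response should only include the answer. Do not provide any further explanation.\n"
        "</instruction>\n\n"
        f"<text>\n{combined_text}\n</text>\n"
    )

if __name__ == "__main__":
    print(build_prompt("Docling converts the uploaded file to markdown before chunking.", "paragraphs", 2))

For long documents the commit applies this prompt twice: once per batch of chunks with a halved length_count, then once more over the partial summaries to produce the final answer, which keeps each generate() call within the max_tokens budget computed in generate_summary.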