Daryl Lim commited on
Commit
ca55264
·
1 Parent(s): 15864df

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -195
app.py CHANGED
@@ -4,13 +4,13 @@ import shutil
4
  import torch
5
  import gradio as gr
6
  from pathlib import Path
7
- from typing import Optional, List, Dict, Any, Union
8
- import requests
9
- from urllib.parse import urlparse
10
 
11
  # Docling imports
12
  from docling.datamodel.base_models import InputFormat
13
- from docling.datamodel.pipeline_options import PdfPipelineOptions, TesseractCliOcrOptions
14
  from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
15
 
16
  # LangChain imports
@@ -26,16 +26,22 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
26
 
27
  # Initialize IBM Granite model and tokenizer
28
  print("Loading Granite model and tokenizer...")
29
- tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granite-3.2-8b-instruct")
 
 
 
 
 
30
  model = AutoModelForCausalLM.from_pretrained(
31
- "ibm-granite/granite-3.2-8b-instruct",
32
  device_map="auto",
33
- torch_dtype=torch.bfloat16
 
34
  )
35
  print("Model loaded successfully!")
36
 
37
  # Helper function to detect document format
38
- def get_document_format(file_path) -> InputFormat:
39
  """Determine the document format based on file extension"""
40
  try:
41
  file_path = str(file_path)
@@ -48,9 +54,10 @@ def get_document_format(file_path) -> InputFormat:
48
  '.html': InputFormat.HTML,
49
  '.htm': InputFormat.HTML
50
  }
51
- return format_map.get(extension, None)
52
  except Exception as e:
53
- return f"Error in get_document_format: {str(e)}"
 
54
 
55
  # Function to convert documents to markdown
56
  def convert_document_to_markdown(doc_path) -> str:
@@ -59,16 +66,19 @@ def convert_document_to_markdown(doc_path) -> str:
59
  # Convert to absolute path string
60
  input_path = os.path.abspath(str(doc_path))
61
  print(f"Converting document: {doc_path}")
 
62
  # Create temporary directory for processing
63
  with tempfile.TemporaryDirectory() as temp_dir:
64
  # Copy input file to temp directory
65
  temp_input = os.path.join(temp_dir, os.path.basename(input_path))
66
  shutil.copy2(input_path, temp_input)
 
67
  # Configure pipeline options
68
  pipeline_options = PdfPipelineOptions()
69
- pipeline_options.do_ocr = False # Disable OCR temporarily
70
  pipeline_options.do_table_structure = True
71
- # Create converter with minimal options
 
72
  converter = DocumentConverter(
73
  allowed_formats=[
74
  InputFormat.PDF,
@@ -85,104 +95,66 @@ def convert_document_to_markdown(doc_path) -> str:
85
  )
86
  }
87
  )
 
88
  # Convert document
89
  print("Starting conversion...")
90
  conv_result = converter.convert(temp_input)
91
  if not conv_result or not conv_result.document:
92
  raise ValueError(f"Failed to convert document: {doc_path}")
 
93
  # Export to markdown
94
  print("Exporting to markdown...")
95
  md = conv_result.document.export_to_markdown()
 
96
  # Create output path
97
  output_dir = os.path.dirname(input_path)
98
  base_name = os.path.splitext(os.path.basename(input_path))[0]
99
  md_path = os.path.join(output_dir, f"{base_name}_converted.md")
 
100
  # Write markdown file
101
- print(f"Writing markdown to: {base_name}_converted.md")
102
  with open(md_path, "w", encoding="utf-8") as fp:
103
  fp.write(md)
104
  return md_path
105
  except Exception as e:
106
  return f"Error converting document: {str(e)}"
107
 
108
- # Function to download file from URL
109
- def download_file_from_url(url: str) -> Optional[str]:
110
- """Download a file from a URL and save it temporarily"""
111
- try:
112
- # Parse URL to get filename
113
- parsed_url = urlparse(url)
114
- filename = os.path.basename(parsed_url.path)
115
-
116
- if not filename:
117
- filename = "downloaded_document"
118
-
119
- # Add extension based on Content-Type if needed
120
- response = requests.get(url, stream=True)
121
- response.raise_for_status()
122
-
123
- content_type = response.headers.get('Content-Type', '')
124
- if 'pdf' in content_type:
125
- if not filename.lower().endswith('.pdf'):
126
- filename += ".pdf"
127
- elif 'word' in content_type or 'docx' in content_type:
128
- if not filename.lower().endswith(('.doc', '.docx')):
129
- filename += ".docx"
130
- elif 'powerpoint' in content_type or 'pptx' in content_type:
131
- if not filename.lower().endswith(('.ppt', '.pptx')):
132
- filename += ".pptx"
133
- elif 'html' in content_type:
134
- if not filename.lower().endswith(('.html', '.htm')):
135
- filename += ".html"
136
-
137
- # Create a temporary file
138
- temp_dir = tempfile.gettempdir()
139
- file_path = os.path.join(temp_dir, filename)
140
-
141
- # Save the file
142
- with open(file_path, 'wb') as f:
143
- for chunk in response.iter_content(chunk_size=8192):
144
- f.write(chunk)
145
-
146
- return file_path
147
- except Exception as e:
148
- print(f"Error downloading file: {str(e)}")
149
- return None
150
-
151
  # Function to generate a summary using the IBM Granite model
152
- def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abstractive", detail_level="medium", length="medium"):
153
- """Generate a summary from document chunks using the IBM Granite model"""
 
 
 
 
 
 
154
  # Concatenate the retrieved chunks
155
  combined_text = " ".join([chunk.page_content for chunk in chunks])
156
 
157
- # Create a prompt based on the summary parameters
158
- if summary_type == "extractive":
159
- summary_instruction = "Extract the key sentences from the text to create a summary."
160
- else: # abstractive
161
- summary_instruction = "Generate a comprehensive summary in your own words."
162
-
163
- if detail_level == "high":
164
- detail_instruction = "Include specific details and examples."
165
- elif detail_level == "medium":
166
- detail_instruction = "Balance key points with some supporting details."
167
- else: # low
168
- detail_instruction = "Focus only on the main points and key takeaways."
169
-
170
- if length == "short":
171
- length_instruction = "Keep the summary concise and brief."
172
- elif length == "medium":
173
- length_instruction = "Create a moderate-length summary."
174
- else: # long
175
- length_instruction = "Provide a comprehensive, detailed summary."
176
 
177
- # Construct the full prompt
178
  prompt = f"""<instruction>
179
- You are a document summarization assistant. Based on the following text, {summary_instruction} {detail_instruction} {length_instruction}
180
- </instruction>
 
 
 
 
 
181
 
182
- <text>
183
- {combined_text}
184
- </text>
185
- """
 
 
 
 
 
186
 
187
  # Generate the summary using the IBM Granite model
188
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
@@ -190,7 +162,7 @@ def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abs
190
  with torch.no_grad():
191
  output = model.generate(
192
  **inputs,
193
- max_new_tokens=1024,
194
  temperature=0.7,
195
  top_p=0.9,
196
  do_sample=True
@@ -204,68 +176,50 @@ def generate_summary(chunks: List[Document], model, tokenizer, summary_type="abs
204
 
205
  return summary.strip()
206
 
207
- # Function to summarize a full document
208
- def summarize_full_document(retriever, model, tokenizer, summary_params, chunk_size=8):
209
- """Summarize an entire document by processing all chunks"""
210
- all_chunks = []
211
-
212
- # Get all documents from the vector store
213
- for i in range(0, len(retriever.vectorstore.index_to_docstore_id), chunk_size):
214
- batch_ids = list(retriever.vectorstore.index_to_docstore_id.values())[i:i+chunk_size]
215
- batch_chunks = [retriever.vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
216
- all_chunks.extend(batch_chunks)
217
-
218
- # Process chunks in manageable batches if needed
219
- summaries = []
220
- for i in range(0, len(all_chunks), chunk_size):
221
- batch = all_chunks[i:i+chunk_size]
222
- summary = generate_summary(
223
- batch,
224
- model,
225
- tokenizer,
226
- summary_type=summary_params.get("summary_type", "abstractive"),
227
- detail_level=summary_params.get("detail_level", "medium"),
228
- length=summary_params.get("length", "medium")
229
  )
230
- summaries.append(summary)
231
-
232
- # Create final summary from batch summaries if needed
233
- if len(summaries) > 1:
234
- final_summary = generate_summary(
235
- [Document(page_content=s) for s in summaries],
236
- model,
237
- tokenizer,
238
- summary_type=summary_params.get("summary_type", "abstractive"),
239
- detail_level=summary_params.get("detail_level", "medium"),
240
- length=summary_params.get("length", "medium")
241
  )
242
- return final_summary
243
- else:
244
- return summaries[0] if summaries else "No content to summarize"
 
 
 
245
 
246
  # Main function to process document and generate summary
247
  @spaces.GPU
248
  def process_document(
249
  file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
250
- url: Optional[str] = None,
251
- summary_type: str = "abstractive",
252
- detail_level: str = "medium",
253
- length: str = "medium",
254
  progress=gr.Progress()
255
  ):
256
- """Process a document file or URL and generate a summary"""
257
  try:
258
- # Process input source (file or URL)
259
- document_path = None
260
- if file_obj:
261
- document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
262
- elif url and url.strip():
263
- progress(0.2, "Downloading document from URL...")
264
- document_path = download_file_from_url(url.strip())
265
- if not document_path:
266
- return "Failed to download document from URL. Please check the URL and try again."
267
- else:
268
- return "Please provide either a file or a URL to summarize."
269
 
270
  # Convert document to markdown
271
  progress(0.3, "Converting document to markdown...")
@@ -278,41 +232,78 @@ def process_document(
278
  loader = UnstructuredMarkdownLoader(str(markdown_path))
279
  documents = loader.load()
280
 
 
281
  text_splitter = RecursiveCharacterTextSplitter(
282
- chunk_size=500,
283
- chunk_overlap=50,
284
- length_function=len
 
285
  )
286
  texts = text_splitter.split_documents(documents)
287
 
288
  if not texts:
289
  return "No text could be extracted from the document."
290
 
291
- # Create embeddings and vector store
292
- progress(0.6, "Creating document embeddings...")
293
- embeddings = HuggingFaceEmbeddings(
294
- model_name="nomic-ai/nomic-embed-text-v1",
295
- model_kwargs={'trust_remote_code': True}
296
- )
297
- vectorstore = FAISS.from_documents(texts, embeddings)
298
 
299
- # Create retriever
300
  retriever = vectorstore.as_retriever(
301
  search_type="similarity",
302
- search_kwargs={"k": 4}
303
  )
304
 
305
- # Generate summary
306
  progress(0.8, "Generating summary...")
307
- summary_params = {
308
- "summary_type": summary_type,
309
- "detail_level": detail_level,
310
- "length": length
311
- }
312
- summary = summarize_full_document(retriever, model, tokenizer, summary_params)
313
 
314
- progress(1.0, "Summary complete!")
315
- return summary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
316
 
317
  except Exception as e:
318
  return f"Error processing document: {str(e)}"
@@ -320,61 +311,90 @@ def process_document(
320
  # Create Gradio interface
321
  def create_gradio_interface():
322
  """Create and launch the Gradio interface"""
323
- with gr.Blocks(title="Document Summarizer") as app:
324
- gr.Markdown("# Document Summarizer")
325
- gr.Markdown("Upload a document or provide a URL to generate a summary.")
326
 
327
  with gr.Row():
328
- with gr.Column():
329
- file_input = gr.File(label="Upload Document (PDF, DOCX, PPTX, HTML)")
330
- url_input = gr.Textbox(label="Or enter document URL")
 
 
331
 
332
  with gr.Row():
333
- with gr.Column():
334
- summary_type = gr.Radio(
335
- choices=["extractive", "abstractive"],
336
- value="abstractive",
337
- label="Summary Type"
338
- )
339
 
340
  with gr.Row():
341
- with gr.Column():
342
- detail_level = gr.Radio(
343
- choices=["low", "medium", "high"],
344
- value="medium",
345
- label="Level of Detail"
346
- )
 
 
 
347
 
348
- with gr.Column():
349
- length = gr.Radio(
350
- choices=["short", "medium", "long"],
351
- value="medium",
352
- label="Summary Length"
353
- )
 
354
 
355
- submit_btn = gr.Button("Generate Summary", variant="primary")
356
 
357
- with gr.Column():
358
- output = gr.Textbox(
359
- label="Summary Result",
360
  lines=15,
361
  max_lines=30
362
  )
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  submit_btn.click(
365
- fn=process_document,
366
- inputs=[file_input, url_input, summary_type, detail_level, length],
367
  outputs=output
368
  )
369
 
370
  gr.Markdown("""
371
  ## How to use:
372
- 1. Upload a document (PDF, DOCX, PPTX, HTML) or provide a URL
373
- 2. Choose your preferred summary parameters:
374
- - Summary Type: Extractive (pulls key sentences) or Abstractive (generates new text)
375
- - Level of Detail: Low, Medium, or High
376
- - Summary Length: Short, Medium, or Long
377
- 3. Click "Generate Summary" to process the document
 
378
  """)
379
 
380
  return app
@@ -382,4 +402,4 @@ def create_gradio_interface():
382
  # Launch the application
383
  if __name__ == "__main__":
384
  app = create_gradio_interface()
385
- app.launch()
 
4
  import torch
5
  import gradio as gr
6
  from pathlib import Path
7
+ from typing import Optional, List, Union
8
+ import gc
9
+ import time
10
 
11
  # Docling imports
12
  from docling.datamodel.base_models import InputFormat
13
+ from docling.datamodel.pipeline_options import PdfPipelineOptions
14
  from docling.document_converter import DocumentConverter, PdfFormatOption, WordFormatOption, SimplePipeline
15
 
16
  # LangChain imports
 
26
 
27
  # Initialize IBM Granite model and tokenizer
28
  print("Loading Granite model and tokenizer...")
29
+ model_name = "ibm-granite/granite-3.3-8b-instruct"
30
+
31
+ # Load tokenizer
32
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
33
+
34
+ # Load model with optimization for GPU
35
  model = AutoModelForCausalLM.from_pretrained(
36
+ model_name,
37
  device_map="auto",
38
+ torch_dtype=torch.bfloat16,
39
+ load_in_8bit=True # Use 8-bit quantization for memory efficiency
40
  )
41
  print("Model loaded successfully!")
42
 
43
  # Helper function to detect document format
44
+ def get_document_format(file_path) -> Optional[InputFormat]:
45
  """Determine the document format based on file extension"""
46
  try:
47
  file_path = str(file_path)
 
54
  '.html': InputFormat.HTML,
55
  '.htm': InputFormat.HTML
56
  }
57
+ return format_map.get(extension)
58
  except Exception as e:
59
+ print(f"Error in get_document_format: {str(e)}")
60
+ return None
61
 
62
  # Function to convert documents to markdown
63
  def convert_document_to_markdown(doc_path) -> str:
 
66
  # Convert to absolute path string
67
  input_path = os.path.abspath(str(doc_path))
68
  print(f"Converting document: {doc_path}")
69
+
70
  # Create temporary directory for processing
71
  with tempfile.TemporaryDirectory() as temp_dir:
72
  # Copy input file to temp directory
73
  temp_input = os.path.join(temp_dir, os.path.basename(input_path))
74
  shutil.copy2(input_path, temp_input)
75
+
76
  # Configure pipeline options
77
  pipeline_options = PdfPipelineOptions()
78
+ pipeline_options.do_ocr = False # Disable OCR for performance
79
  pipeline_options.do_table_structure = True
80
+
81
+ # Create converter with optimized options
82
  converter = DocumentConverter(
83
  allowed_formats=[
84
  InputFormat.PDF,
 
95
  )
96
  }
97
  )
98
+
99
  # Convert document
100
  print("Starting conversion...")
101
  conv_result = converter.convert(temp_input)
102
  if not conv_result or not conv_result.document:
103
  raise ValueError(f"Failed to convert document: {doc_path}")
104
+
105
  # Export to markdown
106
  print("Exporting to markdown...")
107
  md = conv_result.document.export_to_markdown()
108
+
109
  # Create output path
110
  output_dir = os.path.dirname(input_path)
111
  base_name = os.path.splitext(os.path.basename(input_path))[0]
112
  md_path = os.path.join(output_dir, f"{base_name}_converted.md")
113
+
114
  # Write markdown file
 
115
  with open(md_path, "w", encoding="utf-8") as fp:
116
  fp.write(md)
117
  return md_path
118
  except Exception as e:
119
  return f"Error converting document: {str(e)}"
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # Function to generate a summary using the IBM Granite model
122
+ def generate_summary(chunks: List[Document], length_type="sentences", length_count=3):
123
+ """Generate a summary from document chunks using the IBM Granite model
124
+
125
+ Args:
126
+ chunks: List of document chunks to summarize
127
+ length_type: Either "sentences" or "paragraphs"
128
+ length_count: Number of sentences (1-10) or paragraphs (1-3)
129
+ """
130
  # Concatenate the retrieved chunks
131
  combined_text = " ".join([chunk.page_content for chunk in chunks])
132
 
133
+ # Construct length instruction based on type and count
134
+ if length_type == "sentences":
135
+ length_instruction = f"Summarize the following text in {length_count} sentence{'s' if length_count > 1 else ''}."
136
+ else: # paragraphs
137
+ length_instruction = f"Summarize the following text in {length_count} paragraph{'s' if length_count > 1 else ''}."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
 
139
+ # Construct the prompt
140
  prompt = f"""<instruction>
141
+ Knowledge Cutoff Date: April 2024. You are Granite, developed by IBM. You are a helpful AI assistant. {length_instruction} Your response should only include the answer. Do not provide any further explanation.
142
+ </instruction>
143
+
144
+ <text>
145
+ {combined_text}
146
+ </text>
147
+ """
148
 
149
+ # Calculate appropriate max_new_tokens based on length requirements
150
+ # Approximate tokens: ~15 tokens per sentence, ~75 tokens per paragraph
151
+ if length_type == "sentences":
152
+ max_tokens = length_count * 20 # Slightly more than needed for flexibility
153
+ else: # paragraphs
154
+ max_tokens = length_count * 100 # Slightly more than needed for flexibility
155
+
156
+ # Ensure minimum tokens and add buffer
157
+ max_tokens = max(100, min(1000, max_tokens + 50))
158
 
159
  # Generate the summary using the IBM Granite model
160
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
162
  with torch.no_grad():
163
  output = model.generate(
164
  **inputs,
165
+ max_new_tokens=max_tokens,
166
  temperature=0.7,
167
  top_p=0.9,
168
  do_sample=True
 
176
 
177
  return summary.strip()
178
 
179
+ # Function to process document chunks efficiently
180
+ def process_document_chunks(texts, batch_size=8):
181
+ """Process document chunks in efficient batches"""
182
+ try:
183
+ # Create embeddings with optimized settings
184
+ embeddings = HuggingFaceEmbeddings(
185
+ model_name="nomic-ai/nomic-embed-text-v1",
186
+ model_kwargs={'trust_remote_code': True}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  )
188
+
189
+ # Create vector store more efficiently
190
+ vectorstore = FAISS.from_documents(
191
+ texts,
192
+ embeddings,
193
+ # Add distance function for better retrieval
194
+ distance_strategy="cosine"
 
 
 
 
195
  )
196
+
197
+ return vectorstore
198
+ except Exception as e:
199
+ print(f"Error in document processing: {str(e)}")
200
+ # Fallback to basic processing if optimization fails
201
+ return FAISS.from_documents(texts, embeddings)
202
 
203
  # Main function to process document and generate summary
204
  @spaces.GPU
205
  def process_document(
206
  file_obj: Optional[Union[str, tempfile._TemporaryFileWrapper]] = None,
207
+ length_type: str = "sentences",
208
+ length_count: int = 3,
 
 
209
  progress=gr.Progress()
210
  ):
211
+ """Process a document file and generate a summary"""
212
  try:
213
+ # Process input file
214
+ if not file_obj:
215
+ return "Please provide a file to summarize."
216
+
217
+ document_path = file_obj.name if hasattr(file_obj, 'name') else str(file_obj)
218
+
219
+ # Validate document format
220
+ format_type = get_document_format(document_path)
221
+ if not format_type:
222
+ return "Unsupported file format. Please upload a PDF, DOCX, PPTX, or HTML file."
 
223
 
224
  # Convert document to markdown
225
  progress(0.3, "Converting document to markdown...")
 
232
  loader = UnstructuredMarkdownLoader(str(markdown_path))
233
  documents = loader.load()
234
 
235
+ # Optimize text splitting for better chunks
236
  text_splitter = RecursiveCharacterTextSplitter(
237
+ chunk_size=1000, # Larger chunk size for better context
238
+ chunk_overlap=100,
239
+ length_function=len,
240
+ separators=["\n\n", "\n", ".", " ", ""] # Prioritize splitting at paragraph/sentence boundaries
241
  )
242
  texts = text_splitter.split_documents(documents)
243
 
244
  if not texts:
245
  return "No text could be extracted from the document."
246
 
247
+ # Create vector store with efficient processing
248
+ progress(0.6, "Processing document content...")
249
+ vectorstore = process_document_chunks(texts)
 
 
 
 
250
 
251
+ # Create retriever with optimized settings
252
  retriever = vectorstore.as_retriever(
253
  search_type="similarity",
254
+ search_kwargs={"k": 4} # Number of chunks to retrieve
255
  )
256
 
257
+ # Process chunks in smaller batches for memory efficiency
258
  progress(0.8, "Generating summary...")
259
+ all_chunks = []
260
+ batch_size = 4 # Smaller batch size for memory efficiency
261
+
262
+ # Get all document chunks
263
+ doc_ids = list(vectorstore.index_to_docstore_id.values())
 
264
 
265
+ # Process in smaller batches
266
+ for i in range(0, len(doc_ids), batch_size):
267
+ batch_ids = doc_ids[i:i+batch_size]
268
+ batch_chunks = [vectorstore.docstore.search(doc_id) for doc_id in batch_ids]
269
+ all_chunks.extend(batch_chunks)
270
+
271
+ # Force garbage collection to free memory
272
+ gc.collect()
273
+
274
+ # Sleep briefly to allow memory cleanup
275
+ time.sleep(0.1)
276
+
277
+ # Generate summary from chunks
278
+ if len(all_chunks) > 8:
279
+ # If we have many chunks, process in batches
280
+ summaries = []
281
+ for i in range(0, len(all_chunks), batch_size):
282
+ batch = all_chunks[i:i+batch_size]
283
+ summary = generate_summary(
284
+ batch,
285
+ length_type=length_type,
286
+ length_count=max(1, length_count // 2) # Use smaller count for partial summaries
287
+ )
288
+ summaries.append(summary)
289
+
290
+ # Force garbage collection
291
+ gc.collect()
292
+
293
+ # Create final summary from batch summaries
294
+ final_summary = generate_summary(
295
+ [Document(page_content=s) for s in summaries],
296
+ length_type=length_type,
297
+ length_count=length_count
298
+ )
299
+ return final_summary
300
+ else:
301
+ # If we have few chunks, generate summary directly
302
+ return generate_summary(
303
+ all_chunks,
304
+ length_type=length_type,
305
+ length_count=length_count
306
+ )
307
 
308
  except Exception as e:
309
  return f"Error processing document: {str(e)}"
 
311
  # Create Gradio interface
312
  def create_gradio_interface():
313
  """Create and launch the Gradio interface"""
314
+ with gr.Blocks(title="Granite Document Summarization") as app:
315
+ gr.Markdown("# Granite Document Summarization")
316
+ gr.Markdown("Upload a document to generate a summary.")
317
 
318
  with gr.Row():
319
+ with gr.Column(scale=1):
320
+ file_input = gr.File(
321
+ label="Upload Document (PDF, DOCX, PPTX, HTML)",
322
+ file_types=[".pdf", ".docx", ".doc", ".pptx", ".html", ".htm"]
323
+ )
324
 
325
  with gr.Row():
326
+ length_type = gr.Radio(
327
+ choices=["Sentences", "Paragraphs"],
328
+ value="Sentences",
329
+ label="Summary Length Type"
330
+ )
 
331
 
332
  with gr.Row():
333
+ # Use slider for sentence count (1-10)
334
+ sentence_count = gr.Slider(
335
+ minimum=1,
336
+ maximum=10,
337
+ value=3,
338
+ step=1,
339
+ label="Number of Sentences",
340
+ visible=True
341
+ )
342
 
343
+ # Use radio for paragraph count (1-3)
344
+ paragraph_count = gr.Radio(
345
+ choices=["1", "2", "3"],
346
+ value="1",
347
+ label="Number of Paragraphs",
348
+ visible=False
349
+ )
350
 
351
+ submit_btn = gr.Button("Summarize", variant="primary")
352
 
353
+ with gr.Column(scale=2):
354
+ output = gr.TextArea(
355
+ label="Summary",
356
  lines=15,
357
  max_lines=30
358
  )
359
 
360
+ # Add interactivity to show/hide appropriate count selector
361
+ def update_count_visibility(length_type):
362
+ return {
363
+ sentence_count: length_type == "Sentences",
364
+ paragraph_count: length_type == "Paragraphs"
365
+ }
366
+
367
+ length_type.change(
368
+ fn=update_count_visibility,
369
+ inputs=[length_type],
370
+ outputs=[sentence_count, paragraph_count]
371
+ )
372
+
373
+ # Function to convert paragraph count from string to int and handle capitalized length types
374
+ def process_document_wrapper(file, length_type, sentence_count, paragraph_count):
375
+ # Convert capitalized length_type to lowercase for processing
376
+ length_type_lower = length_type.lower()
377
+
378
+ if length_type_lower == "sentences":
379
+ return process_document(file, length_type_lower, int(sentence_count))
380
+ else:
381
+ return process_document(file, length_type_lower, int(paragraph_count))
382
+
383
  submit_btn.click(
384
+ fn=process_document_wrapper,
385
+ inputs=[file_input, length_type, sentence_count, paragraph_count],
386
  outputs=output
387
  )
388
 
389
  gr.Markdown("""
390
  ## How to use:
391
+ 1. Upload a document (PDF, DOCX, PPTX, HTML)
392
+ 2. Choose your summary length preference:
393
+ - Number of Sentences (1-10)
394
+ - Number of Paragraphs (1-3)
395
+ 3. Click "Summarize" to process the document
396
+
397
+ *This application uses the IBM Granite 3.3-8b model to generate summaries.*
398
  """)
399
 
400
  return app
 
402
  # Launch the application
403
  if __name__ == "__main__":
404
  app = create_gradio_interface()
405
+ app.launch()