ikraamkb commited on
Commit
990a952
·
verified ·
1 Parent(s): a94c049

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -56
app.py CHANGED
@@ -139,8 +139,9 @@ async def summarize_document(file, length="medium"):
139
  if pdf_path:
140
  result["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
141
  return result"""
142
- from fastapi import FastAPI, UploadFile, File, Form
143
- from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
 
144
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
145
  import fitz # PyMuPDF
146
  import docx
@@ -158,23 +159,23 @@ import easyocr
158
  import datetime
159
  import hashlib
160
 
161
- # Initialize
162
  nltk.download('punkt', quiet=True)
163
- app = FastAPI()
164
 
165
- # Load Summarizer Model
166
  MODEL_NAME = "facebook/bart-large-cnn"
167
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
168
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
169
  model.eval()
170
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1, batch_size=4)
171
 
172
- # Load OCR Reader
173
  reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
174
 
175
- # Cache
176
  summary_cache = {}
177
 
 
 
 
178
  # --- Helper Functions ---
179
 
180
  def clean_text(text: str) -> str:
@@ -184,9 +185,9 @@ def clean_text(text: str) -> str:
184
  text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
185
  return text.strip()
186
 
187
- def extract_text(file_path: str, file_extension: str):
188
  try:
189
- if file_extension == "pdf":
190
  with fitz.open(file_path) as doc:
191
  text = "\n".join(page.get_text("text") for page in doc)
192
  if len(text.strip()) < 50:
@@ -196,29 +197,24 @@ def extract_text(file_path: str, file_extension: str):
196
  ocr_result = reader.readtext(temp_img.name, detail=0)
197
  os.unlink(temp_img.name)
198
  text = "\n".join(ocr_result) if ocr_result else text
199
- return clean_text(text), ""
200
-
201
- elif file_extension == "docx":
202
  doc = docx.Document(file_path)
203
- return clean_text("\n".join(p.text for p in doc.paragraphs)), ""
204
-
205
- elif file_extension == "pptx":
206
  prs = pptx.Presentation(file_path)
207
- text = [shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]
208
- return clean_text("\n".join(text)), ""
209
-
210
- elif file_extension == "xlsx":
211
  wb = openpyxl.load_workbook(file_path, read_only=True)
212
- text = [" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
213
- return clean_text("\n".join(text)), ""
214
-
215
- elif file_extension in ["jpg", "jpeg", "png"]:
216
- ocr_result = reader.readtext(file_path, detail=0)
217
- return clean_text("\n".join(ocr_result)), ""
 
218
 
219
- return "", "Unsupported file format"
220
  except Exception as e:
221
- return "", f"Error reading {file_extension.upper()} file: {str(e)}"
222
 
223
  def chunk_text(text: str, max_tokens: int = 950):
224
  try:
@@ -243,7 +239,7 @@ def chunk_text(text: str, max_tokens: int = 950):
243
 
244
  return chunks
245
 
246
- def generate_summary(text: str, length: str = "medium") -> str:
247
  cache_key = hashlib.md5((text + length).encode()).hexdigest()
248
  if cache_key in summary_cache:
249
  return summary_cache[cache_key]
@@ -283,14 +279,14 @@ def text_to_speech(text: str):
283
  except Exception:
284
  return ""
285
 
286
- def create_pdf(summary: str, original_filename: str):
287
  try:
288
  pdf = FPDF()
289
  pdf.add_page()
290
  pdf.set_font("Arial", 'B', 16)
291
  pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
292
  pdf.set_font("Arial", size=12)
293
- pdf.cell(200, 10, txt=f"Original file: {original_filename}", ln=1)
294
  pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
295
  pdf.ln(10)
296
  pdf.multi_cell(0, 10, txt=summary)
@@ -300,49 +296,50 @@ def create_pdf(summary: str, original_filename: str):
300
  except Exception:
301
  return ""
302
 
303
- # --- API Endpoints ---
304
 
305
- @app.post("/summarize/")
306
- async def summarize_api(file: UploadFile = File(...), length: str = Form("medium")):
307
  try:
308
- contents = await file.read()
309
- with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
310
- tmp_file.write(contents)
 
 
 
 
 
 
311
  tmp_path = tmp_file.name
312
 
313
- file_ext = tmp_path.split('.')[-1].lower()
314
- text, error = extract_text(tmp_path, file_ext)
315
 
316
  if error:
317
- return JSONResponse({"detail": error}, status_code=400)
 
318
 
319
  if not text or len(text.split()) < 30:
320
- return JSONResponse({"detail": "Document too short to summarize"}, status_code=400)
 
321
 
 
322
  summary = generate_summary(text, length)
 
 
323
  audio_path = text_to_speech(summary)
324
- pdf_path = create_pdf(summary, file.filename)
325
 
 
 
 
 
326
  response = {"summary": summary}
327
  if audio_path:
328
  response["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
329
  if pdf_path:
330
  response["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
331
 
332
- return JSONResponse(response)
333
 
334
  except Exception as e:
335
- print(f"Error during summarization: {str(e)}")
336
- return JSONResponse({"detail": f"Internal server error: {str(e)}"}, status_code=500)
337
-
338
- @app.get("/files/{file_name}")
339
- async def serve_file(file_name: str):
340
- path = os.path.join(tempfile.gettempdir(), file_name)
341
- if os.path.exists(path):
342
- return FileResponse(path)
343
- return JSONResponse({"error": "File not found"}, status_code=404)
344
-
345
- @app.get("/")
346
- def home():
347
- return RedirectResponse(url="/")
348
-
 
139
  if pdf_path:
140
  result["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
141
  return result"""
142
+ # app.py
143
+
144
+ from fastapi import UploadFile, File
145
  from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
146
  import fitz # PyMuPDF
147
  import docx
 
159
  import datetime
160
  import hashlib
161
 
162
+ # Setup
163
  nltk.download('punkt', quiet=True)
 
164
 
165
+ # Load Models
166
  MODEL_NAME = "facebook/bart-large-cnn"
167
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
168
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
169
  model.eval()
170
  summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1, batch_size=4)
171
 
 
172
  reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
173
 
 
174
  summary_cache = {}
175
 
176
+ # Allowed file extensions
177
+ ALLOWED_EXTENSIONS = {'pdf', 'docx', 'pptx', 'xlsx'}
178
+
179
  # --- Helper Functions ---
180
 
181
  def clean_text(text: str) -> str:
 
185
  text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
186
  return text.strip()
187
 
188
+ def extract_text(file_path: str, extension: str):
189
  try:
190
+ if extension == "pdf":
191
  with fitz.open(file_path) as doc:
192
  text = "\n".join(page.get_text("text") for page in doc)
193
  if len(text.strip()) < 50:
 
197
  ocr_result = reader.readtext(temp_img.name, detail=0)
198
  os.unlink(temp_img.name)
199
  text = "\n".join(ocr_result) if ocr_result else text
200
+ elif extension == "docx":
 
 
201
  doc = docx.Document(file_path)
202
+ text = "\n".join(p.text for p in doc.paragraphs)
203
+ elif extension == "pptx":
 
204
  prs = pptx.Presentation(file_path)
205
+ text = "\n".join(shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text"))
206
+ elif extension == "xlsx":
 
 
207
  wb = openpyxl.load_workbook(file_path, read_only=True)
208
+ text = "\n".join(
209
+ [" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
210
+ )
211
+ else:
212
+ return "", "Unsupported file format."
213
+
214
+ return clean_text(text), ""
215
 
 
216
  except Exception as e:
217
+ return "", f"Error reading {extension.upper()} file: {str(e)}"
218
 
219
  def chunk_text(text: str, max_tokens: int = 950):
220
  try:
 
239
 
240
  return chunks
241
 
242
+ def generate_summary(text: str, length: str = "medium"):
243
  cache_key = hashlib.md5((text + length).encode()).hexdigest()
244
  if cache_key in summary_cache:
245
  return summary_cache[cache_key]
 
279
  except Exception:
280
  return ""
281
 
282
+ def create_pdf(summary: str, filename: str):
283
  try:
284
  pdf = FPDF()
285
  pdf.add_page()
286
  pdf.set_font("Arial", 'B', 16)
287
  pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
288
  pdf.set_font("Arial", size=12)
289
+ pdf.cell(200, 10, txt=f"Original file: {filename}", ln=1)
290
  pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
291
  pdf.ln(10)
292
  pdf.multi_cell(0, 10, txt=summary)
 
296
  except Exception:
297
  return ""
298
 
299
+ # --- Public API Function ---
300
 
301
+ async def summarize_document(file: UploadFile, length: str = "medium"):
 
302
  try:
303
+ filename = file.filename
304
+ extension = os.path.splitext(filename)[-1].lower().replace('.', '')
305
+
306
+ if extension not in ALLOWED_EXTENSIONS:
307
+ raise Exception(f"Unsupported file type: {extension.upper()}. Only PDF, DOCX, PPTX, XLSX are allowed.")
308
+
309
+ # Save uploaded file
310
+ with tempfile.NamedTemporaryFile(delete=False, suffix=f".{extension}") as tmp_file:
311
+ tmp_file.write(await file.read())
312
  tmp_path = tmp_file.name
313
 
314
+ # Extract text
315
+ text, error = extract_text(tmp_path, extension)
316
 
317
  if error:
318
+ os.unlink(tmp_path)
319
+ raise Exception(error)
320
 
321
  if not text or len(text.split()) < 30:
322
+ os.unlink(tmp_path)
323
+ raise Exception("Document too short to summarize.")
324
 
325
+ # Summarize
326
  summary = generate_summary(text, length)
327
+
328
+ # Create audio + PDF
329
  audio_path = text_to_speech(summary)
330
+ pdf_path = create_pdf(summary, filename)
331
 
332
+ # Clean temp file
333
+ os.unlink(tmp_path)
334
+
335
+ # Prepare response
336
  response = {"summary": summary}
337
  if audio_path:
338
  response["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
339
  if pdf_path:
340
  response["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
341
 
342
+ return response
343
 
344
  except Exception as e:
345
+ raise Exception(f"Summarization failed: {str(e)}")