ikraamkb commited on
Commit
b0f9e39
·
verified ·
1 Parent(s): 734e13c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +209 -2
app.py CHANGED
@@ -1,5 +1,5 @@
1
  # app.py
2
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
  import fitz, docx, pptx, openpyxl, re, nltk, tempfile, os, easyocr, datetime, hashlib
4
  from nltk.tokenize import sent_tokenize
5
  from fpdf import FPDF
@@ -138,4 +138,211 @@ async def summarize_document(file, length="medium"):
138
  result["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
139
  if pdf_path:
140
  result["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
141
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # app.py
2
+ """from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
  import fitz, docx, pptx, openpyxl, re, nltk, tempfile, os, easyocr, datetime, hashlib
4
  from nltk.tokenize import sent_tokenize
5
  from fpdf import FPDF
 
138
  result["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
139
  if pdf_path:
140
  result["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
141
+ return result"""
142
+ from fastapi import FastAPI, UploadFile, File, Form
143
+ from fastapi.responses import RedirectResponse, FileResponse, JSONResponse
144
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
145
+ import fitz # PyMuPDF
146
+ import docx
147
+ import pptx
148
+ import openpyxl
149
+ import re
150
+ import nltk
151
+ import torch
152
+ from nltk.tokenize import sent_tokenize
153
+ from gtts import gTTS
154
+ from fpdf import FPDF
155
+ import tempfile
156
+ import os
157
+ import easyocr
158
+ import datetime
159
+ import hashlib
160
+
161
+ # Initialize
162
+ nltk.download('punkt', quiet=True)
163
+ app = FastAPI()
164
+
165
+ # Load Summarizer Model
166
+ MODEL_NAME = "facebook/bart-large-cnn"
167
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
168
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
169
+ model.eval()
170
+ summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1, batch_size=4)
171
+
172
+ # Load OCR Reader
173
+ reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
174
+
175
+ # Cache
176
+ summary_cache = {}
177
+
178
+ # --- Helper Functions ---
179
+
180
+ def clean_text(text: str) -> str:
181
+ text = re.sub(r'\s+', ' ', text)
182
+ text = re.sub(r'\u2022\s*|\d\.\s+', '', text)
183
+ text = re.sub(r'\[.*?\]|\(.*?\)', '', text)
184
+ text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
185
+ return text.strip()
186
+
187
+ def extract_text(file_path: str, file_extension: str):
188
+ try:
189
+ if file_extension == "pdf":
190
+ with fitz.open(file_path) as doc:
191
+ text = "\n".join(page.get_text("text") for page in doc)
192
+ if len(text.strip()) < 50:
193
+ images = [page.get_pixmap() for page in doc]
194
+ temp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
195
+ images[0].save(temp_img.name)
196
+ ocr_result = reader.readtext(temp_img.name, detail=0)
197
+ os.unlink(temp_img.name)
198
+ text = "\n".join(ocr_result) if ocr_result else text
199
+ return clean_text(text), ""
200
+
201
+ elif file_extension == "docx":
202
+ doc = docx.Document(file_path)
203
+ return clean_text("\n".join(p.text for p in doc.paragraphs)), ""
204
+
205
+ elif file_extension == "pptx":
206
+ prs = pptx.Presentation(file_path)
207
+ text = [shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]
208
+ return clean_text("\n".join(text)), ""
209
+
210
+ elif file_extension == "xlsx":
211
+ wb = openpyxl.load_workbook(file_path, read_only=True)
212
+ text = [" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
213
+ return clean_text("\n".join(text)), ""
214
+
215
+ elif file_extension in ["jpg", "jpeg", "png"]:
216
+ ocr_result = reader.readtext(file_path, detail=0)
217
+ return clean_text("\n".join(ocr_result)), ""
218
+
219
+ return "", "Unsupported file format"
220
+ except Exception as e:
221
+ return "", f"Error reading {file_extension.upper()} file: {str(e)}"
222
+
223
+ def chunk_text(text: str, max_tokens: int = 950):
224
+ try:
225
+ sentences = sent_tokenize(text)
226
+ except:
227
+ words = text.split()
228
+ sentences = [' '.join(words[i:i+20]) for i in range(0, len(words), 20)]
229
+
230
+ chunks = []
231
+ current_chunk = ""
232
+ for sentence in sentences:
233
+ token_length = len(tokenizer.encode(current_chunk + " " + sentence))
234
+ if token_length <= max_tokens:
235
+ current_chunk += " " + sentence
236
+ else:
237
+ if current_chunk.strip():
238
+ chunks.append(current_chunk.strip())
239
+ current_chunk = sentence
240
+
241
+ if current_chunk.strip():
242
+ chunks.append(current_chunk.strip())
243
+
244
+ return chunks
245
+
246
+ def generate_summary(text: str, length: str = "medium") -> str:
247
+ cache_key = hashlib.md5((text + length).encode()).hexdigest()
248
+ if cache_key in summary_cache:
249
+ return summary_cache[cache_key]
250
+
251
+ length_params = {
252
+ "short": {"max_length": 80, "min_length": 30},
253
+ "medium": {"max_length": 200, "min_length": 80},
254
+ "long": {"max_length": 300, "min_length": 210}
255
+ }
256
+ chunks = chunk_text(text)
257
+
258
+ summaries = summarizer(
259
+ chunks,
260
+ max_length=length_params[length]["max_length"],
261
+ min_length=length_params[length]["min_length"],
262
+ do_sample=False,
263
+ truncation=True,
264
+ no_repeat_ngram_size=2,
265
+ num_beams=2,
266
+ early_stopping=True
267
+ )
268
+ summary_texts = [s['summary_text'] for s in summaries]
269
+
270
+ final_summary = " ".join(summary_texts)
271
+ final_summary = ". ".join(s.strip().capitalize() for s in final_summary.split(". ") if s.strip())
272
+ final_summary = final_summary if len(final_summary) > 25 else "Summary too short - document may be too brief"
273
+
274
+ summary_cache[cache_key] = final_summary
275
+ return final_summary
276
+
277
+ def text_to_speech(text: str):
278
+ try:
279
+ tts = gTTS(text)
280
+ temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
281
+ tts.save(temp_audio.name)
282
+ return temp_audio.name
283
+ except Exception:
284
+ return ""
285
+
286
+ def create_pdf(summary: str, original_filename: str):
287
+ try:
288
+ pdf = FPDF()
289
+ pdf.add_page()
290
+ pdf.set_font("Arial", 'B', 16)
291
+ pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
292
+ pdf.set_font("Arial", size=12)
293
+ pdf.cell(200, 10, txt=f"Original file: {original_filename}", ln=1)
294
+ pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
295
+ pdf.ln(10)
296
+ pdf.multi_cell(0, 10, txt=summary)
297
+ temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
298
+ pdf.output(temp_pdf.name)
299
+ return temp_pdf.name
300
+ except Exception:
301
+ return ""
302
+
303
+ # --- API Endpoints ---
304
+
305
+ @app.post("/summarize/")
306
+ async def summarize_api(file: UploadFile = File(...), length: str = Form("medium")):
307
+ try:
308
+ contents = await file.read()
309
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
310
+ tmp_file.write(contents)
311
+ tmp_path = tmp_file.name
312
+
313
+ file_ext = tmp_path.split('.')[-1].lower()
314
+ text, error = extract_text(tmp_path, file_ext)
315
+
316
+ if error:
317
+ return JSONResponse({"detail": error}, status_code=400)
318
+
319
+ if not text or len(text.split()) < 30:
320
+ return JSONResponse({"detail": "Document too short to summarize"}, status_code=400)
321
+
322
+ summary = generate_summary(text, length)
323
+ audio_path = text_to_speech(summary)
324
+ pdf_path = create_pdf(summary, file.filename)
325
+
326
+ response = {"summary": summary}
327
+ if audio_path:
328
+ response["audioUrl"] = f"/files/{os.path.basename(audio_path)}"
329
+ if pdf_path:
330
+ response["pdfUrl"] = f"/files/{os.path.basename(pdf_path)}"
331
+
332
+ return JSONResponse(response)
333
+
334
+ except Exception as e:
335
+ print(f"Error during summarization: {str(e)}")
336
+ return JSONResponse({"detail": f"Internal server error: {str(e)}"}, status_code=500)
337
+
338
+ @app.get("/files/{file_name}")
339
+ async def serve_file(file_name: str):
340
+ path = os.path.join(tempfile.gettempdir(), file_name)
341
+ if os.path.exists(path):
342
+ return FileResponse(path)
343
+ return JSONResponse({"error": "File not found"}, status_code=404)
344
+
345
+ @app.get("/")
346
+ def home():
347
+ return RedirectResponse(url="/")
348
+