ikraamkb commited on
Commit
a377214
·
verified ·
1 Parent(s): 5c4195a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +2 -299
main.py CHANGED
@@ -1,4 +1,4 @@
1
- """from fastapi import FastAPI, UploadFile, File, Form, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
@@ -72,301 +72,4 @@ async def predict(
72
  else:
73
  return JSONResponse({"error": "Invalid option"}, status_code=400)
74
  except Exception as e:
75
- return JSONResponse({"error": f"Prediction failed: {str(e)}"}, status_code=500) """
76
- from fastapi import FastAPI, UploadFile, File, Form, Request, HTTPException
77
- from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
78
- from fastapi.staticfiles import StaticFiles
79
- from fastapi.templating import Jinja2Templates
80
- from fastapi.middleware.cors import CORSMiddleware
81
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM, AutoProcessor, AutoModelForCausalLM
82
- from PIL import Image
83
- import torch
84
- import fitz # PyMuPDF
85
- import docx
86
- import pptx
87
- import openpyxl
88
- import re
89
- import nltk
90
- from nltk.tokenize import sent_tokenize
91
- from gtts import gTTS
92
- from fpdf import FPDF
93
- import tempfile
94
- import os
95
- import shutil
96
- import datetime
97
- import hashlib
98
- import easyocr
99
- from typing import Optional
100
-
101
- # Initialize app
102
- app = FastAPI()
103
-
104
- # CORS Configuration
105
- app.add_middleware(
106
- CORSMiddleware,
107
- allow_origins=["*"],
108
- allow_credentials=True,
109
- allow_methods=["*"],
110
- allow_headers=["*"],
111
- )
112
-
113
- # Static assets
114
- app.mount("/static", StaticFiles(directory="static"), name="static")
115
- app.mount("/resources", StaticFiles(directory="resources"), name="resources")
116
-
117
- # Templates
118
- templates = Jinja2Templates(directory="templates")
119
-
120
- # Initialize models
121
- nltk.download('punkt', quiet=True)
122
-
123
- # Document processing models
124
- try:
125
- tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
126
- model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn")
127
- model.eval()
128
- summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=-1)
129
- reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())
130
- except Exception as e:
131
- print(f"Error loading summarization models: {e}")
132
- summarizer = None
133
-
134
- # Image captioning models
135
- try:
136
- processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
137
- git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
138
- git_model.eval()
139
- USE_GIT = True
140
- except Exception as e:
141
- print(f"Error loading GIT model, falling back to ViT: {e}")
142
- captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
143
- USE_GIT = False
144
-
145
- # Helper functions
146
- def clean_text(text: str) -> str:
147
- text = re.sub(r'\s+', ' ', text)
148
- text = re.sub(r'\u2022\s*|\d\.\s+', '', text)
149
- text = re.sub(r'\[.*?\]|\(.*?\)', '', text)
150
- text = re.sub(r'\bPage\s*\d+\b', '', text, flags=re.IGNORECASE)
151
- return text.strip()
152
-
153
- def extract_text(file_path: str, file_extension: str):
154
- try:
155
- if file_extension == "pdf":
156
- with fitz.open(file_path) as doc:
157
- text = "\n".join(page.get_text("text") for page in doc)
158
- if len(text.strip()) < 50:
159
- images = [page.get_pixmap() for page in doc]
160
- temp_img = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
161
- images[0].save(temp_img.name)
162
- ocr_result = reader.readtext(temp_img.name, detail=0)
163
- os.unlink(temp_img.name)
164
- text = "\n".join(ocr_result) if ocr_result else text
165
- return clean_text(text), ""
166
-
167
- elif file_extension == "docx":
168
- doc = docx.Document(file_path)
169
- return clean_text("\n".join(p.text for p in doc.paragraphs)), ""
170
-
171
- elif file_extension == "pptx":
172
- prs = pptx.Presentation(file_path)
173
- text = [shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]
174
- return clean_text("\n".join(text)), ""
175
-
176
- elif file_extension == "xlsx":
177
- wb = openpyxl.load_workbook(file_path, read_only=True)
178
- text = [" ".join(str(cell) for cell in row if cell) for sheet in wb.sheetnames for row in wb[sheet].iter_rows(values_only=True)]
179
- return clean_text("\n".join(text)), ""
180
-
181
- return "", "Unsupported file format"
182
- except Exception as e:
183
- return "", f"Error reading {file_extension.upper()} file: {str(e)}"
184
-
185
- def chunk_text(text: str, max_tokens: int = 950):
186
- try:
187
- sentences = sent_tokenize(text)
188
- except:
189
- words = text.split()
190
- sentences = [' '.join(words[i:i+20]) for i in range(0, len(words), 20)]
191
-
192
- chunks = []
193
- current_chunk = ""
194
- for sentence in sentences:
195
- token_length = len(tokenizer.encode(current_chunk + " " + sentence))
196
- if token_length <= max_tokens:
197
- current_chunk += " " + sentence
198
- else:
199
- chunks.append(current_chunk.strip())
200
- current_chunk = sentence
201
-
202
- if current_chunk:
203
- chunks.append(current_chunk.strip())
204
-
205
- return chunks
206
-
207
- def generate_summary(text: str, length: str = "medium") -> str:
208
- cache_key = hashlib.md5((text + length).encode()).hexdigest()
209
-
210
- length_params = {
211
- "short": {"max_length": 80, "min_length": 30},
212
- "medium": {"max_length": 200, "min_length": 80},
213
- "long": {"max_length": 300, "min_length": 210}
214
- }
215
-
216
- chunks = chunk_text(text)
217
- try:
218
- summaries = summarizer(
219
- chunks,
220
- max_length=length_params[length]["max_length"],
221
- min_length=length_params[length]["min_length"],
222
- do_sample=False,
223
- truncation=True
224
- )
225
- summary_texts = [s['summary_text'] for s in summaries]
226
- except Exception as e:
227
- summary_texts = [f"[Error: {str(e)}"]
228
-
229
- final_summary = " ".join(summary_texts)
230
- final_summary = ". ".join(s.strip().capitalize() for s in final_summary.split(". ") if s.strip())
231
- return final_summary if len(final_summary) > 25 else "Summary too short"
232
-
233
- def text_to_speech(text: str):
234
- try:
235
- tts = gTTS(text)
236
- temp_audio = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
237
- tts.save(temp_audio.name)
238
- return temp_audio.name
239
- except Exception as e:
240
- print(f"Error in text-to-speech: {e}")
241
- return ""
242
-
243
- def create_pdf(summary: str, original_filename: str):
244
- try:
245
- pdf = FPDF()
246
- pdf.add_page()
247
- pdf.set_font("Arial", 'B', 16)
248
- pdf.cell(200, 10, txt="Document Summary", ln=1, align='C')
249
- pdf.set_font("Arial", size=12)
250
- pdf.cell(200, 10, txt=f"Original file: {original_filename}", ln=1)
251
- pdf.cell(200, 10, txt=f"Generated on: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}", ln=1)
252
- pdf.ln(10)
253
- pdf.multi_cell(0, 10, txt=summary)
254
- temp_pdf = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
255
- pdf.output(temp_pdf.name)
256
- return temp_pdf.name
257
- except Exception as e:
258
- print(f"Error creating PDF: {e}")
259
- return ""
260
-
261
- def generate_caption(image_path: str) -> str:
262
- try:
263
- if USE_GIT:
264
- image = Image.open(image_path).convert("RGB")
265
- inputs = processor(images=image, return_tensors="pt")
266
- outputs = git_model.generate(**inputs, max_length=50)
267
- caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]
268
- else:
269
- result = captioner(image_path)
270
- caption = result[0]['generated_text']
271
- return caption
272
- except Exception as e:
273
- raise Exception(f"Caption generation failed: {str(e)}")
274
-
275
- # API Endpoints
276
- @app.post("/summarize/")
277
- async def summarize_document(file: UploadFile = File(...), length: str = Form("medium")):
278
- valid_types = [
279
- 'application/pdf',
280
- 'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
281
- 'application/vnd.openxmlformats-officedocument.presentationml.presentation',
282
- 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet'
283
- ]
284
-
285
- if file.content_type not in valid_types:
286
- raise HTTPException(
287
- status_code=400,
288
- detail="Please upload a valid document (PDF, DOCX, PPTX, or XLSX)"
289
- )
290
-
291
- try:
292
- # Save temp file
293
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp:
294
- shutil.copyfileobj(file.file, temp)
295
- temp_path = temp.name
296
-
297
- # Process file
298
- text, error = extract_text(temp_path, os.path.splitext(file.filename)[1][1:].lower())
299
- if error:
300
- raise HTTPException(status_code=400, detail=error)
301
-
302
- if not text or len(text.split()) < 30:
303
- raise HTTPException(status_code=400, detail="Document too short to summarize")
304
-
305
- summary = generate_summary(text, length)
306
- audio_path = text_to_speech(summary)
307
- pdf_path = create_pdf(summary, file.filename)
308
-
309
- return {
310
- "summary": summary,
311
- "audio_url": f"/files/{os.path.basename(audio_path)}" if audio_path else None,
312
- "pdf_url": f"/files/{os.path.basename(pdf_path)}" if pdf_path else None
313
- }
314
-
315
- except HTTPException:
316
- raise
317
- except Exception as e:
318
- raise HTTPException(
319
- status_code=500,
320
- detail=f"Summarization failed: {str(e)}"
321
- )
322
- finally:
323
- if 'temp_path' in locals() and os.path.exists(temp_path):
324
- os.unlink(temp_path)
325
-
326
- @app.post("/imagecaption/")
327
- async def caption_image(file: UploadFile = File(...)):
328
- valid_types = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']
329
- if file.content_type not in valid_types:
330
- raise HTTPException(
331
- status_code=400,
332
- detail="Please upload a valid image (JPEG, PNG, GIF, or WEBP)"
333
- )
334
-
335
- try:
336
- # Save temp file
337
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as temp:
338
- shutil.copyfileobj(file.file, temp)
339
- temp_path = temp.name
340
-
341
- # Generate caption
342
- caption = generate_caption(temp_path)
343
-
344
- # Generate audio
345
- audio_path = text_to_speech(caption)
346
-
347
- return {
348
- "answer": caption,
349
- "audio": f"/files/{os.path.basename(audio_path)}" if audio_path else None
350
- }
351
-
352
- except HTTPException:
353
- raise
354
- except Exception as e:
355
- raise HTTPException(
356
- status_code=500,
357
- detail=str(e)
358
- )
359
- finally:
360
- if 'temp_path' in locals() and os.path.exists(temp_path):
361
- os.unlink(temp_path)
362
-
363
- @app.get("/files/{filename}")
364
- async def serve_file(filename: str):
365
- file_path = os.path.join(tempfile.gettempdir(), filename)
366
- if os.path.exists(file_path):
367
- return FileResponse(file_path)
368
- raise HTTPException(status_code=404, detail="File not found")
369
-
370
- @app.get("/", response_class=HTMLResponse)
371
- async def serve_home(request: Request):
372
- return templates.TemplateResponse("HomeS.html", {"request": request})
 
1
+ from fastapi import FastAPI, UploadFile, File, Form, Request
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
4
  from fastapi.staticfiles import StaticFiles
 
72
  else:
73
  return JSONResponse({"error": "Invalid option"}, status_code=400)
74
  except Exception as e:
75
+ return JSONResponse({"error": f"Prediction failed: {str(e)}"}, status_code=500)