ikraamkb committed on
Commit
e540abd
·
verified ·
1 Parent(s): 2a21b95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -14
app.py CHANGED
@@ -1,4 +1,5 @@
1
- # app.py
 
2
  import fitz # PyMuPDF for PDFs
3
  import easyocr # OCR for images
4
  import openpyxl # XLSX processing
@@ -7,12 +8,15 @@ import docx # DOCX processing
7
  from transformers import pipeline
8
  from gtts import gTTS
9
  import tempfile
 
 
10
  app = FastAPI()
11
- # Initialize AI Models
 
12
  qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
13
- reader = easyocr.Reader(['en', 'fr']) # OCR for English & French
14
 
15
- # ---- TEXT EXTRACTION FUNCTIONS ----
16
  def extract_text_from_pdf(pdf_file):
17
  text = []
18
  try:
@@ -25,13 +29,13 @@ def extract_text_from_pdf(pdf_file):
25
 
26
  def extract_text_from_docx(docx_file):
27
  doc = docx.Document(docx_file)
28
- return "\n".join([p.text for p in doc.paragraphs if p.text.strip()])
29
 
30
  def extract_text_from_pptx(pptx_file):
31
  text = []
32
  try:
33
- presentation = pptx.Presentation(pptx_file)
34
- for slide in presentation.slides:
35
  for shape in slide.shapes:
36
  if hasattr(shape, "text"):
37
  text.append(shape.text)
@@ -51,18 +55,22 @@ def extract_text_from_xlsx(xlsx_file):
51
  return f"Error reading XLSX: {e}"
52
  return "\n".join(text)
53
 
54
- # ---- MAIN QA FUNCTION ----
55
  def answer_question_from_doc(file, question):
56
- ext = file.name.split(".")[-1].lower()
 
 
 
 
57
 
58
  if ext == "pdf":
59
- context = extract_text_from_pdf(file.name)
60
  elif ext == "docx":
61
- context = extract_text_from_docx(file.name)
62
  elif ext == "pptx":
63
- context = extract_text_from_pptx(file.name)
64
  elif ext == "xlsx":
65
- context = extract_text_from_xlsx(file.name)
66
  else:
67
  return "Unsupported file format.", None
68
 
@@ -72,10 +80,25 @@ def answer_question_from_doc(file, question):
72
  try:
73
  result = qa_model({"question": question, "context": context})
74
  answer = result["answer"]
75
- tts = gTTS(text=answer)
76
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
77
  tts.save(tmp.name)
78
  audio_path = tmp.name
79
  return answer, audio_path
80
  except Exception as e:
81
  return f"Error generating answer: {e}", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, Form
2
+ from fastapi.responses import JSONResponse, FileResponse
3
  import fitz # PyMuPDF for PDFs
4
  import easyocr # OCR for images
5
  import openpyxl # XLSX processing
 
8
  from transformers import pipeline
9
  from gtts import gTTS
10
  import tempfile
11
+ import os
12
+
13
  app = FastAPI()
14
+
15
+ # Load AI models
16
  qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
17
+ reader = easyocr.Reader(['en', 'fr'])
18
 
19
+ # Text Extraction
20
  def extract_text_from_pdf(pdf_file):
21
  text = []
22
  try:
 
29
 
30
  def extract_text_from_docx(docx_file):
31
  doc = docx.Document(docx_file)
32
+ return "\n".join(p.text for p in doc.paragraphs if p.text.strip())
33
 
34
  def extract_text_from_pptx(pptx_file):
35
  text = []
36
  try:
37
+ prs = pptx.Presentation(pptx_file)
38
+ for slide in prs.slides:
39
  for shape in slide.shapes:
40
  if hasattr(shape, "text"):
41
  text.append(shape.text)
 
55
  return f"Error reading XLSX: {e}"
56
  return "\n".join(text)
57
 
58
+ # Main QA logic
59
  def answer_question_from_doc(file, question):
60
+ ext = file.filename.split(".")[-1].lower()
61
+ file_path = f"/tmp/{file.filename}"
62
+
63
+ with open(file_path, "wb") as f:
64
+ f.write(file.file.read())
65
 
66
  if ext == "pdf":
67
+ context = extract_text_from_pdf(file_path)
68
  elif ext == "docx":
69
+ context = extract_text_from_docx(file_path)
70
  elif ext == "pptx":
71
+ context = extract_text_from_pptx(file_path)
72
  elif ext == "xlsx":
73
+ context = extract_text_from_xlsx(file_path)
74
  else:
75
  return "Unsupported file format.", None
76
 
 
80
  try:
81
  result = qa_model({"question": question, "context": context})
82
  answer = result["answer"]
83
+ tts = gTTS(answer)
84
  with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
85
  tts.save(tmp.name)
86
  audio_path = tmp.name
87
  return answer, audio_path
88
  except Exception as e:
89
  return f"Error generating answer: {e}", None
90
+
91
+ # API route for prediction
92
@app.post("/predict")
async def predict(file: UploadFile, question: str = Form(...)):
    """Answer ``question`` about the uploaded document and return JSON.

    The response always contains "answer". An "audio" entry (a URL under
    ``/audio/``) is added only when ``answer_question_from_doc`` returned an
    audio path.
    """
    answer, audio_path = answer_question_from_doc(file, question)
    payload = {"answer": answer}
    if audio_path:
        # Expose only the file name; the /audio route resolves it on disk.
        payload["audio"] = f"/audio/{os.path.basename(audio_path)}"
    return JSONResponse(content=payload)
99
+
100
+ # Route to serve audio
101
@app.get("/audio/{filename}")
async def get_audio(filename: str):
    """Serve a generated MP3 from the system temp directory.

    Args:
        filename: Bare file name of the audio clip, as returned in the
            "audio" URL of the /predict response.

    Returns:
        A FileResponse streaming the file as ``audio/mpeg``.

    Raises:
        HTTPException: 404 when the name is not a bare ``*.mp3`` file name
            or no such file exists in the temp directory.
    """
    # Local import so this block does not depend on the file's import header.
    from fastapi import HTTPException

    # ``filename`` comes straight from the URL. Only accept a plain *.mp3
    # name: no directory components, no "..", and no other file types, so
    # unrelated temp files (e.g. uploaded documents) cannot be downloaded.
    is_bare_name = os.path.basename(filename) == filename
    if not is_bare_name or not filename.lower().endswith(".mp3"):
        raise HTTPException(status_code=404, detail="Audio file not found.")

    file_path = os.path.join(tempfile.gettempdir(), filename)
    # Without this check a missing file surfaces as a 500 while the response
    # is being sent; report it as a 404 instead.
    if not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Audio file not found.")

    return FileResponse(path=file_path, media_type="audio/mpeg")