# Spaces: Running
from fastapi import FastAPI, UploadFile, Form | |
from fastapi.responses import JSONResponse, FileResponse | |
import fitz # PyMuPDF for PDFs | |
import easyocr # OCR for images | |
import openpyxl # XLSX processing | |
import pptx # PPTX processing | |
import docx # DOCX processing | |
from transformers import pipeline | |
from gtts import gTTS | |
import tempfile | |
import os | |
# FastAPI application object for this service.
app = FastAPI()

# Load AI models once, at import time, so they are shared by every call.
qa_model = pipeline(
    "question-answering",
    model="deepset/roberta-base-squad2",
)
# OCR reader for English and French.
# NOTE(review): no use of `reader` is visible in this part of the file —
# confirm it is still needed before paying its load cost.
reader = easyocr.Reader(["en", "fr"])
# Text Extraction | |
def extract_text_from_pdf(pdf_file):
    """Return the plain text of every page of a PDF, one page per line group.

    Args:
        pdf_file: Value handed to ``fitz.open`` (the caller in this file
            passes a filesystem path).

    Returns:
        The per-page texts joined with newlines. If opening or reading the
        document raises, a string starting with ``"Error reading PDF: "``
        is returned instead of propagating the exception.
    """
    try:
        with fitz.open(pdf_file) as document:
            page_texts = [page.get_text("text") for page in document]
    except Exception as e:
        return f"Error reading PDF: {e}"
    return "\n".join(page_texts)
def extract_text_from_docx(docx_file):
    """Return the text of all non-blank paragraphs of a DOCX file.

    Args:
        docx_file: Value handed to ``docx.Document`` (the caller in this
            file passes a filesystem path).

    Returns:
        Paragraph texts joined with newlines; paragraphs that are empty or
        whitespace-only are skipped. Unlike the other extractors in this
        file, exceptions are not caught here and propagate to the caller.
    """
    document = docx.Document(docx_file)
    non_blank = [
        paragraph.text
        for paragraph in document.paragraphs
        if paragraph.text.strip()
    ]
    return "\n".join(non_blank)
def extract_text_from_pptx(pptx_file):
    """Return the text of every text-bearing shape in a PPTX presentation.

    Args:
        pptx_file: Value handed to ``pptx.Presentation`` (the caller in this
            file passes a filesystem path).

    Returns:
        The shape texts, in slide order, joined with newlines. If loading or
        walking the presentation raises, a string starting with
        ``"Error reading PPTX: "`` is returned instead.
    """
    shape_texts = []
    try:
        presentation = pptx.Presentation(pptx_file)
        for slide in presentation.slides:
            # Only shapes exposing a ``text`` attribute contribute output.
            shape_texts.extend(
                shape.text for shape in slide.shapes if hasattr(shape, "text")
            )
    except Exception as e:
        return f"Error reading PPTX: {e}"
    return "\n".join(shape_texts)
def extract_text_from_xlsx(xlsx_file):
    """Return the cell values of every worksheet in an XLSX workbook as text.

    Each row becomes one output line made of its non-empty cell values,
    converted with ``str`` and separated by single spaces.

    Args:
        xlsx_file: Value handed to ``openpyxl.load_workbook`` (the caller in
            this file passes a filesystem path).

    Returns:
        The row lines joined with newlines. If loading or reading the
        workbook raises, a string starting with ``"Error reading XLSX: "``
        is returned instead.
    """
    lines = []
    try:
        workbook = openpyxl.load_workbook(xlsx_file)
        # Iterate ``worksheets`` rather than ``sheetnames``: the latter also
        # lists chartsheets, which have no ``iter_rows`` and would make the
        # whole extraction fail.
        for worksheet in workbook.worksheets:
            for row in worksheet.iter_rows(values_only=True):
                # Skip only truly empty cells; a plain truthiness test would
                # also discard legitimate values such as 0 and False.
                lines.append(
                    " ".join(str(cell) for cell in row if cell not in (None, ""))
                )
    except Exception as e:
        return f"Error reading XLSX: {e}"
    return "\n".join(lines)
# Main QA logic | |
def answer_question_from_doc(file, question):
    """Answer a question from an uploaded document and synthesize the answer.

    The upload is spooled to a private temporary file, its text is extracted
    with the extractor matching the file extension, the question-answering
    model is run over that text, and the answer is rendered to an MP3 file.

    Args:
        file: Upload object exposing ``filename`` (str or None) and ``file``
            (a binary file-like object with ``read()``).
        question: The question to answer from the document text.

    Returns:
        A ``(message, audio_path)`` tuple. On success ``message`` is the
        model's answer and ``audio_path`` is the path of the generated MP3.
        On any failure ``message`` describes the problem and ``audio_path``
        is ``None``.
    """
    supported_extensions = ("pdf", "docx", "pptx", "xlsx")

    # ``filename`` is client-controlled and may be missing; it is used only
    # to derive the extension, never to build a filesystem path.
    filename = file.filename or ""
    ext = filename.rsplit(".", 1)[-1].lower()
    if ext not in supported_extensions:
        # Reject before touching the disk so unsupported uploads leave no file.
        return "Unsupported file format.", None

    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_from_docx,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_xlsx,
    }

    # A generated temp name avoids path traversal through the upload's
    # filename and collisions between uploads that share a name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=f".{ext}") as upload:
        upload.write(file.file.read())
        file_path = upload.name
    try:
        context = extractors[ext](file_path)
    finally:
        # The document is only needed for extraction; do not leave it behind.
        os.remove(file_path)

    # The extractors report failures as "Error reading <EXT>: ..." strings.
    # Surface those directly instead of letting the model "answer" from an
    # error message.
    if context.startswith(f"Error reading {ext.upper()}: "):
        return context, None

    if not context.strip():
        return "No text found in the document.", None

    try:
        result = qa_model({"question": question, "context": context})
        answer = result["answer"]
        tts = gTTS(answer)
        # Reserve a unique .mp3 path, then save after the handle is closed so
        # the write does not race with the still-open temporary file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp:
            audio_path = tmp.name
        tts.save(audio_path)
        return answer, audio_path
    except Exception as e:
        return f"Error generating answer: {e}", None
# API route for prediction | |
async def predict(file: UploadFile, question: str = Form(...)):
    """Answer ``question`` from the uploaded ``file`` and return JSON.

    Args:
        file: The uploaded document.
        question: The question, supplied as a required form field.

    Returns:
        A ``JSONResponse`` whose body always has an ``"answer"`` key and, when
        an audio file was produced, an ``"audio"`` key holding the
        ``/audio/<basename>`` URL path of that file.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # available source — confirm it is registered on `app`.
    answer, audio_path = answer_question_from_doc(file, question)
    payload = {"answer": answer}
    if audio_path:
        payload["audio"] = f"/audio/{os.path.basename(audio_path)}"
    return JSONResponse(content=payload)
# Route to serve audio | |
async def get_audio(filename: str):
    """Serve a previously generated MP3 from the temporary directory.

    Args:
        filename: Base name of the audio file, as returned in the ``"audio"``
            URL of the prediction response.

    Returns:
        A ``FileResponse`` with media type ``audio/mpeg`` for an existing
        ``.mp3`` file in the temp directory; otherwise a 404 ``JSONResponse``.
    """
    # NOTE(review): no route decorator is visible on this handler in the
    # available source — confirm it is registered on `app`.

    # ``filename`` is client-controlled: accept only a bare ``.mp3`` file name
    # so the route cannot escape the temp directory or expose other files
    # stored there (e.g. uploaded documents).
    safe_name = os.path.basename(filename)
    file_path = os.path.join(tempfile.gettempdir(), safe_name)
    is_servable = (
        bool(safe_name)
        and safe_name == filename
        and safe_name.lower().endswith(".mp3")
        and os.path.isfile(file_path)
    )
    if not is_servable:
        return JSONResponse(
            status_code=404, content={"detail": "Audio file not found."}
        )
    return FileResponse(path=file_path, media_type="audio/mpeg")