Spaces:
Running
Running
from fastapi import FastAPI, Form, File, UploadFile | |
from fastapi.responses import RedirectResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from transformers import pipeline | |
import os | |
from PIL import Image | |
import io | |
import pdfplumber | |
import docx | |
import openpyxl | |
import pytesseract | |
from io import BytesIO | |
import fitz # PyMuPDF | |
import easyocr | |
from fastapi.templating import Jinja2Templates | |
from starlette.requests import Request | |
# Initialize the app
app = FastAPI()

# Mount the static directory to serve HTML, CSS, JS files
app.mount("/static", StaticFiles(directory="static"), name="static")

# Initialize transformers pipelines. These are module-level statements, so the
# models are loaded once, when this module is imported.
# NOTE(review): "microsoft/phi-2" is a causal language model, while the
# "question-answering" task expects an extractive-QA model; this may load an
# untrained QA head or fail to load at all — confirm the intended model.
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
# Visual question answering over uploaded images.
image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")

# Initialize EasyOCR for image-based text extraction
# NOTE(review): `reader` is not referenced by the visible code;
# extract_text_from_image uses pytesseract instead. Confirm it is still needed.
reader = easyocr.Reader(['en'])

# Define a template for rendering HTML (templates are looked up in ./templates)
templates = Jinja2Templates(directory="templates")

# Ensure temp_files directory exists; uploaded documents are written here
# before text extraction. exist_ok=True makes repeated imports harmless.
temp_dir = "temp_files"
os.makedirs(temp_dir, exist_ok=True)
# Function to process PDFs
def extract_pdf_text(file_path: str) -> str:
    """Return the text of every page of the PDF at *file_path*.

    Pages are joined with a newline so the last word of one page is not fused
    with the first word of the next. A page with no extractable text
    contributes an empty string. Depending on the pdfplumber version,
    ``extract_text`` may return ``None`` for such a page, and plain string
    concatenation would raise ``TypeError`` on it.
    """
    with pdfplumber.open(file_path) as pdf:
        # `or ""` tolerates pages for which extract_text yields None.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
# Function to process DOCX files
def extract_docx_text(file_path: str):
    """Open a .docx file and return its paragraph texts, one per line."""
    paragraphs = docx.Document(file_path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
# Function to process PPTX files
def extract_pptx_text(file_path: str):
    """Return the text of every text-bearing shape in a .pptx file.

    Slides are visited in order, and shapes within each slide in order.
    Shapes that have no ``text`` attribute are skipped. Each remaining
    shape's text becomes one newline-separated entry in the result.
    """
    # Function-scope import: python-pptx is only required when a .pptx file
    # is actually processed.
    from pptx import Presentation

    deck = Presentation(file_path)
    chunks = []
    for slide in deck.slides:
        chunks.extend(shape.text for shape in slide.shapes if hasattr(shape, "text"))
    return "\n".join(chunks)
# Function to extract text from images using OCR
def extract_text_from_image(image: Image.Image) -> str:
    """Run Tesseract OCR (through pytesseract) on *image* and return the text.

    The parameter is annotated with ``Image.Image``, the PIL image class.
    The bare name ``Image`` refers to the ``PIL.Image`` module and is not a
    valid type for this argument.
    """
    return pytesseract.image_to_string(image)
# Home route
# NOTE(review): no route decorator is visible on this function, so as written
# it is not registered with `app`; presumably it was meant to carry
# `@app.get("/")` — confirm.
def home():
    """Redirect the client to ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
# Function to answer questions based on document content
# NOTE(review): no route decorator is visible on this function; confirm how it
# is registered with `app`.
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
    """Answer *question* using the text of an uploaded PDF, DOCX or PPTX file.

    The format is chosen from the upload's file extension, matched
    case-insensitively. Unsupported uploads are rejected before anything is
    written to disk.

    Supported uploads are written to a uniquely named file inside
    ``temp_dir`` for the extractor to read. That file is deleted afterwards,
    even if extraction fails.

    Returns:
        ``{"answer": <str>}`` on success, or ``{"error": <str>}`` when the
        format is unsupported or the document yields no text.
    """
    import tempfile

    extractors = {
        ".pdf": extract_pdf_text,
        ".docx": extract_docx_text,
        ".pptx": extract_pptx_text,
    }
    # The client-supplied filename is untrusted (and may be None). Only its
    # extension is used, so names like "../../app.py" cannot escape temp_dir.
    filename = (file.filename or "").lower()
    suffix = next((ext for ext in extractors if filename.endswith(ext)), None)
    if suffix is None:
        return {"error": "Unsupported file format"}

    # A unique name per request keeps concurrent uploads that share a
    # filename from overwriting each other.
    fd, file_path = tempfile.mkstemp(suffix=suffix, dir=temp_dir)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await file.read())
        text = extractors[suffix](file_path)
    finally:
        # Without this, every upload would stay in temp_dir forever.
        os.remove(file_path)

    # Do not hand an empty context to the QA pipeline.
    if not text.strip():
        return {"error": "No text could be extracted from the document"}

    # NOTE(review): extraction and inference are synchronous calls inside an
    # async handler, so they block the event loop while running; consider
    # starlette.concurrency.run_in_threadpool.
    qa_result = qa_pipeline(question=question, context=text)
    return {"answer": qa_result["answer"]}
# Function to answer questions based on images
# NOTE(review): no route decorator is visible on this function; confirm how it
# is registered with `app`.
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
    """Answer *question* about an uploaded image and OCR any text it contains.

    Returns:
        ``{"answer": <str>, "image_text": <str>}``. Here ``answer`` is the
        first entry of ``image_qa_pipeline``'s output and ``image_text`` is
        the OCR text from ``extract_text_from_image``.

        If the upload cannot be decoded as an image, returns
        ``{"error": <str>}`` instead of failing with a server error.
    """
    from PIL import UnidentifiedImageError

    data = await image_file.read()
    try:
        image = Image.open(BytesIO(data))
    except UnidentifiedImageError:
        # Same error-dict convention as question_answering_doc.
        return {"error": "Unsupported image format"}

    image_text = extract_text_from_image(image)
    image_qa_result = image_qa_pipeline({"image": image, "question": question})
    return {"answer": image_qa_result[0]["answer"], "image_text": image_text}
# Serve the application in Hugging Face space
# NOTE(review): no route decorator is visible on this function; confirm which
# path is meant to serve this page.
async def get_docs(request: Request):
    """Render ``index.html`` from the ``templates`` directory for *request*."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)