# Hugging Face Space app (page status when captured: Running)
"""from fastapi import FastAPI, Form, File, UploadFile | |
from fastapi.responses import RedirectResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from transformers import pipeline | |
import os | |
from PIL import Image | |
import io | |
import pdfplumber | |
import docx | |
import openpyxl | |
import pytesseract | |
from io import BytesIO | |
import fitz # PyMuPDF | |
import easyocr | |
from fastapi.templating import Jinja2Templates | |
from starlette.requests import Request | |
# Initialize the app | |
app = FastAPI() | |
# Mount the static directory to serve HTML, CSS, JS files | |
app.mount("/static", StaticFiles(directory="static"), name="static") | |
# Initialize transformers pipelines | |
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2") | |
image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base") | |
# Initialize EasyOCR for image-based text extraction | |
reader = easyocr.Reader(['en']) | |
# Define a template for rendering HTML | |
templates = Jinja2Templates(directory="templates") | |
# Ensure temp_files directory exists | |
temp_dir = "temp_files" | |
os.makedirs(temp_dir, exist_ok=True) | |
# Function to process PDFs | |
def extract_pdf_text(file_path: str): | |
with pdfplumber.open(file_path) as pdf: | |
text = "" | |
for page in pdf.pages: | |
text += page.extract_text() | |
return text | |
# Function to process DOCX files | |
def extract_docx_text(file_path: str): | |
doc = docx.Document(file_path) | |
text = "\n".join([para.text for para in doc.paragraphs]) | |
return text | |
# Function to process PPTX files | |
def extract_pptx_text(file_path: str): | |
from pptx import Presentation | |
prs = Presentation(file_path) | |
text = "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")]) | |
return text | |
# Function to extract text from images using OCR | |
def extract_text_from_image(image: Image): | |
return pytesseract.image_to_string(image) | |
# Home route | |
@app.get("/") | |
def home(): | |
return RedirectResponse(url="/docs") | |
# Function to answer questions based on document content | |
@app.post("/question-answering-doc") | |
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)): | |
file_path = os.path.join(temp_dir, file.filename) | |
with open(file_path, "wb") as f: | |
f.write(await file.read()) | |
if file.filename.endswith(".pdf"): | |
text = extract_pdf_text(file_path) | |
elif file.filename.endswith(".docx"): | |
text = extract_docx_text(file_path) | |
elif file.filename.endswith(".pptx"): | |
text = extract_pptx_text(file_path) | |
else: | |
return {"error": "Unsupported file format"} | |
qa_result = qa_pipeline(question=question, context=text) | |
return {"answer": qa_result['answer']} | |
# Function to answer questions based on images | |
@app.post("/question-answering-image") | |
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)): | |
image = Image.open(BytesIO(await image_file.read())) | |
image_text = extract_text_from_image(image) | |
image_qa_result = image_qa_pipeline({"image": image, "question": question}) | |
return {"answer": image_qa_result[0]['answer'], "image_text": image_text} | |
# Serve the application in Hugging Face space | |
@app.get("/docs") | |
async def get_docs(request: Request): | |
return templates.TemplateResponse("index.html", {"request": request}) | |
""" | |
import os

from fastapi import FastAPI
from fastapi.responses import RedirectResponse
import gradio as gr
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import torch
import fitz  # PyMuPDF for PDF
app = FastAPI()

# ========== Document QA Setup ==========
# Causal LM that generates answers from PDF text. The tokenizer and the model
# are loaded from one shared checkpoint id so the two always match.
_DOC_QA_CHECKPOINT = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
doc_tokenizer = AutoTokenizer.from_pretrained(_DOC_QA_CHECKPOINT)
doc_model = AutoModelForCausalLM.from_pretrained(_DOC_QA_CHECKPOINT)
def read_pdf(file):
    """Return the plain text of every page of a PDF, concatenated in page order.

    ``file`` may be any of the following:
      * a filesystem path (``str`` or ``os.PathLike``). Recent Gradio
        ``gr.File`` components pass a path by default; confirm for the pinned
        Gradio version.
      * raw PDF ``bytes`` / ``bytearray``.
      * an open binary file-like object exposing ``.read()``. This is the
        original calling convention.
      * any other object whose ``.name`` points at the file on disk.

    The PyMuPDF document is closed before the function returns.
    """
    if isinstance(file, (bytes, bytearray)):
        doc = fitz.open(stream=bytes(file), filetype="pdf")
    elif isinstance(file, (str, os.PathLike)):
        # A str has no .read(); open the path directly instead.
        doc = fitz.open(os.fspath(file), filetype="pdf")
    elif hasattr(file, "read"):
        doc = fitz.open(stream=file.read(), filetype="pdf")
    else:
        doc = fitz.open(os.fspath(file.name), filetype="pdf")
    # Close the document even if text extraction raises.
    with doc:
        return "".join(page.get_text() for page in doc)
# Token limits for the document-QA prompt. 2048 is the cap the tokenizer call
# has always used; the generated answer now has to fit under it as well.
_DOC_MAX_TOTAL_TOKENS = 2048
_DOC_MAX_NEW_TOKENS = 100
# Slack, because tokenizing string pieces separately is not exactly additive.
_DOC_TOKEN_SAFETY_MARGIN = 16


def _build_doc_prompt(text, question, max_prompt_tokens):
    """Build the Context/Question/Answer prompt, trimming only the context.

    The document text is cut from its end until the prompt fits within
    ``max_prompt_tokens``. The question and the trailing ``Answer:`` cue are
    therefore never truncated away.
    """
    prefix = "Context: "
    suffix = f"\nQuestion: {question}\nAnswer:"
    scaffold_tokens = len(doc_tokenizer(prefix + suffix)["input_ids"])
    budget = max(max_prompt_tokens - scaffold_tokens - _DOC_TOKEN_SAFETY_MARGIN, 0)
    context_ids = doc_tokenizer(text, add_special_tokens=False)["input_ids"]
    if len(context_ids) > budget:
        text = doc_tokenizer.decode(context_ids[:budget], skip_special_tokens=True)
    return f"{prefix}{text}{suffix}"


def answer_question_from_doc(file, question):
    """Answer ``question`` using the text of an uploaded PDF as context.

    Args:
        file: the upload from the Gradio ``gr.File`` input (anything
            ``read_pdf`` accepts), or ``None``.
        question: free-text question. ``None`` or a blank string is rejected.

    Returns:
        The model's generated answer. Returns a message instead when the
        file or question is missing, or when the PDF yields no extractable
        text.
    """
    if file is None or not question or not question.strip():
        return "Please upload a document and ask a question."
    text = read_pdf(file)
    if not text.strip():
        return "No extractable text was found in the document."
    # Reserve room for the generated tokens inside the overall token budget.
    max_prompt_tokens = _DOC_MAX_TOTAL_TOKENS - _DOC_MAX_NEW_TOKENS
    prompt = _build_doc_prompt(text, question, max_prompt_tokens)
    # Truncation here is only a backstop; _build_doc_prompt already fits the
    # prompt. Default right-side truncation would otherwise drop the question.
    inputs = doc_tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_prompt_tokens
    )
    prompt_length = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = doc_model.generate(**inputs, max_new_tokens=_DOC_MAX_NEW_TOKENS)
    # Decode only the tokens after the prompt. Splitting on "Answer:" is
    # fragile: the document or the model's own output may contain it.
    new_tokens = outputs[0][prompt_length:]
    return doc_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ========== Image QA Setup ==========
# ViLT checkpoint fine-tuned for visual question answering. The processor and
# the model are loaded from one shared checkpoint id so they stay in sync.
_VQA_CHECKPOINT = "dandelin/vilt-b32-finetuned-vqa"
vqa_processor = ViltProcessor.from_pretrained(_VQA_CHECKPOINT)
vqa_model = ViltForQuestionAnswering.from_pretrained(_VQA_CHECKPOINT)
def answer_question_from_image(image, question):
    """Answer a short natural-language question about an image with ViLT.

    Args:
        image: the value delivered by the Gradio ``gr.Image`` input, or
            ``None``. It is presumably a NumPy array with the component's
            default settings; confirm if the component configuration changes.
        question: free-text question. ``None`` or a blank string is rejected.

    Returns:
        The highest-scoring label from the model's ``id2label`` mapping, or
        a message asking for input when either argument is missing.
    """
    if image is None or not question or not question.strip():
        return "Please upload an image and ask a question."
    # ViLT's text embeddings only cover config.max_position_embeddings
    # positions. Truncate so that longer questions do not overflow them.
    inputs = vqa_processor(
        image,
        question,
        return_tensors="pt",
        truncation=True,
        max_length=vqa_model.config.max_position_embeddings,
    )
    with torch.no_grad():
        outputs = vqa_model(**inputs)
    predicted_id = outputs.logits.argmax(-1).item()
    return vqa_model.config.id2label[predicted_id]
# ========== Gradio Interfaces ==========
# Document QA tab. read_pdf opens every upload with filetype="pdf", so the
# upload widget is restricted to .pdf files rather than accepting anything.
doc_interface = gr.Interface(
    fn=answer_question_from_doc,
    inputs=[
        gr.File(label="Upload Document (PDF)", file_types=[".pdf"]),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="Document Question Answering",
)
# Image QA tab: an uploaded image plus a free-text question, answered by
# answer_question_from_image and shown as plain text.
img_interface = gr.Interface(
    answer_question_from_image,
    [gr.Image(label="Upload Image"), gr.Textbox(label="Ask a Question")],
    "text",
    title="Image Question Answering",
)
# ========== Combine and Mount ==========
# Expose both interfaces as tabs of one Gradio app and mount it on the
# FastAPI instance at the site root. The returned app replaces `app`.
demo = gr.TabbedInterface(
    interface_list=[doc_interface, img_interface],
    tab_names=["Document QA", "Image QA"],
)
app = gr.mount_gradio_app(app, demo, path="/")
def root():
    """Return a redirect response pointing at the site root ("/").

    NOTE(review): this function has no route decorator, so FastAPI never
    calls it. The Gradio app is already mounted at "/". Registering this
    function at "/" would make that path redirect to itself.
    """
    target_url = "/"
    return RedirectResponse(url=target_url)