Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,157 +1,3 @@
|
|
1 |
-
"""import gradio as gr
|
2 |
-
import numpy as np
|
3 |
-
import fitz # PyMuPDF
|
4 |
-
import torch
|
5 |
-
import asyncio
|
6 |
-
from fastapi import FastAPI
|
7 |
-
from transformers import pipeline
|
8 |
-
from PIL import Image
|
9 |
-
from starlette.responses import RedirectResponse
|
10 |
-
from openpyxl import load_workbook
|
11 |
-
from docx import Document
|
12 |
-
from pptx import Presentation
|
13 |
-
|
14 |
-
# Initialize FastAPI
|
15 |
-
app = FastAPI()
|
16 |
-
|
17 |
-
# Use GPU if available
|
18 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
19 |
-
print(f"✅ Using device: {device}")
|
20 |
-
|
21 |
-
# Function to load models lazily
|
22 |
-
def get_qa_pipeline():
|
23 |
-
print("π Loading QA pipeline model...")
|
24 |
-
return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)
|
25 |
-
|
26 |
-
def get_image_captioning_pipeline():
|
27 |
-
print("π Loading Image Captioning model...")
|
28 |
-
return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
|
29 |
-
|
30 |
-
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
|
31 |
-
MAX_INPUT_LENGTH = 1024 # Limit input length for faster processing
|
32 |
-
|
33 |
-
# ✅ Validate File Type
|
34 |
-
def validate_file_type(file):
|
35 |
-
if hasattr(file, "name"):
|
36 |
-
ext = file.name.split(".")[-1].lower()
|
37 |
-
print(f"π File extension detected: {ext}")
|
38 |
-
if ext not in ALLOWED_EXTENSIONS:
|
39 |
-
print(f"β Unsupported file format: {ext}")
|
40 |
-
return f"β Unsupported file format: {ext}"
|
41 |
-
return None
|
42 |
-
print("β Invalid file format!")
|
43 |
-
return "β Invalid file format!"
|
44 |
-
|
45 |
-
# ✅ Extract Text from PDF
|
46 |
-
async def extract_text_from_pdf(file):
|
47 |
-
print(f"π Extracting text from PDF: {file.name}")
|
48 |
-
loop = asyncio.get_event_loop()
|
49 |
-
text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
|
50 |
-
print(f"✅ Extracted {len(text)} characters from PDF")
|
51 |
-
return text
|
52 |
-
|
53 |
-
# ✅ Extract Text from DOCX
|
54 |
-
async def extract_text_from_docx(file):
|
55 |
-
print(f"π Extracting text from DOCX: {file.name}")
|
56 |
-
loop = asyncio.get_event_loop()
|
57 |
-
text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs]))
|
58 |
-
print(f"✅ Extracted {len(text)} characters from DOCX")
|
59 |
-
return text
|
60 |
-
|
61 |
-
# ✅ Extract Text from PPTX
|
62 |
-
async def extract_text_from_pptx(file):
|
63 |
-
print(f"π Extracting text from PPTX: {file.name}")
|
64 |
-
loop = asyncio.get_event_loop()
|
65 |
-
text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")]))
|
66 |
-
print(f"✅ Extracted {len(text)} characters from PPTX")
|
67 |
-
return text
|
68 |
-
|
69 |
-
# ✅ Extract Text from Excel
|
70 |
-
async def extract_text_from_excel(file):
|
71 |
-
print(f"π Extracting text from Excel: {file.name}")
|
72 |
-
loop = asyncio.get_event_loop()
|
73 |
-
text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
|
74 |
-
print(f"✅ Extracted {len(text)} characters from Excel")
|
75 |
-
return text
|
76 |
-
|
77 |
-
# ✅ Truncate Long Text
|
78 |
-
def truncate_text(text):
|
79 |
-
print(f"✂️ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
|
80 |
-
return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
|
81 |
-
|
82 |
-
# ✅ Answer Questions from Image or Document
|
83 |
-
async def answer_question(file, question: str):
|
84 |
-
print(f"β Question received: {question}")
|
85 |
-
|
86 |
-
if isinstance(file, np.ndarray): # Image Processing
|
87 |
-
print("🖼️ Processing image for captioning...")
|
88 |
-
image = Image.fromarray(file)
|
89 |
-
image_captioning = get_image_captioning_pipeline()
|
90 |
-
caption = image_captioning(image)[0]['generated_text']
|
91 |
-
print(f"π Generated caption: {caption}")
|
92 |
-
|
93 |
-
qa = get_qa_pipeline()
|
94 |
-
print("🤖 Running QA model...")
|
95 |
-
response = qa(f"Question: {question}\nContext: {caption}")
|
96 |
-
print(f"✅ Model response: {response[0]['generated_text']}")
|
97 |
-
return response[0]["generated_text"]
|
98 |
-
|
99 |
-
validation_error = validate_file_type(file)
|
100 |
-
if validation_error:
|
101 |
-
return validation_error
|
102 |
-
|
103 |
-
file_ext = file.name.split(".")[-1].lower()
|
104 |
-
|
105 |
-
# Extract text asynchronously
|
106 |
-
if file_ext == "pdf":
|
107 |
-
text = await extract_text_from_pdf(file)
|
108 |
-
elif file_ext == "docx":
|
109 |
-
text = await extract_text_from_docx(file)
|
110 |
-
elif file_ext == "pptx":
|
111 |
-
text = await extract_text_from_pptx(file)
|
112 |
-
elif file_ext == "xlsx":
|
113 |
-
text = await extract_text_from_excel(file)
|
114 |
-
else:
|
115 |
-
print("β Unsupported file format!")
|
116 |
-
return "β Unsupported file format!"
|
117 |
-
|
118 |
-
if not text:
|
119 |
-
print("⚠️ No text extracted from the document.")
|
120 |
-
return "⚠️ No text extracted from the document."
|
121 |
-
|
122 |
-
truncated_text = truncate_text(text)
|
123 |
-
|
124 |
-
# Run QA model asynchronously
|
125 |
-
print("🤖 Running QA model...")
|
126 |
-
loop = asyncio.get_event_loop()
|
127 |
-
qa = get_qa_pipeline()
|
128 |
-
response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
|
129 |
-
|
130 |
-
print(f"✅ Model response: {response[0]['generated_text']}")
|
131 |
-
return response[0]["generated_text"]
|
132 |
-
|
133 |
-
# ✅ Gradio Interface (Separate File & Image Inputs)
|
134 |
-
with gr.Blocks() as demo:
|
135 |
-
gr.Markdown("## π AI-Powered Document & Image QA")
|
136 |
-
|
137 |
-
with gr.Row():
|
138 |
-
file_input = gr.File(label="Upload Document")
|
139 |
-
image_input = gr.Image(label="Upload Image")
|
140 |
-
|
141 |
-
question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
|
142 |
-
answer_output = gr.Textbox(label="Answer")
|
143 |
-
submit_btn = gr.Button("Get Answer")
|
144 |
-
|
145 |
-
submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
|
146 |
-
|
147 |
-
|
148 |
-
# ✅ Mount Gradio with FastAPI
|
149 |
-
app = gr.mount_gradio_app(app, demo, path="/")
|
150 |
-
|
151 |
-
@app.get("/")
|
152 |
-
def home():
|
153 |
-
return RedirectResponse(url="/")
|
154 |
-
"""
|
155 |
from fastapi import FastAPI, Form, File, UploadFile
|
156 |
from fastapi.responses import RedirectResponse
|
157 |
from fastapi.staticfiles import StaticFiles
|
@@ -178,7 +24,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
178 |
|
179 |
# Initialize transformers pipelines
|
180 |
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
|
181 |
-
image_qa_pipeline = pipeline("
|
182 |
|
183 |
# Initialize EasyOCR for image-based text extraction
|
184 |
reader = easyocr.Reader(['en'])
|
@@ -186,6 +32,10 @@ reader = easyocr.Reader(['en'])
|
|
186 |
# Define a template for rendering HTML
|
187 |
templates = Jinja2Templates(directory="templates")
|
188 |
|
|
|
|
|
|
|
|
|
189 |
# Function to process PDFs
|
190 |
def extract_pdf_text(file_path: str):
|
191 |
with pdfplumber.open(file_path) as pdf:
|
@@ -197,26 +47,19 @@ def extract_pdf_text(file_path: str):
|
|
197 |
# Function to process DOCX files
|
198 |
def extract_docx_text(file_path: str):
|
199 |
doc = docx.Document(file_path)
|
200 |
-
text = ""
|
201 |
-
for para in doc.paragraphs:
|
202 |
-
text += para.text
|
203 |
return text
|
204 |
|
205 |
# Function to process PPTX files
|
206 |
def extract_pptx_text(file_path: str):
|
207 |
from pptx import Presentation
|
208 |
prs = Presentation(file_path)
|
209 |
-
text = ""
|
210 |
-
for slide in prs.slides:
|
211 |
-
for shape in slide.shapes:
|
212 |
-
if hasattr(shape, "text"):
|
213 |
-
text += shape.text
|
214 |
return text
|
215 |
|
216 |
# Function to extract text from images using OCR
|
217 |
def extract_text_from_image(image: Image):
|
218 |
-
|
219 |
-
return text
|
220 |
|
221 |
# Home route
|
222 |
@app.get("/")
|
@@ -226,13 +69,10 @@ def home():
|
|
226 |
# Function to answer questions based on document content
|
227 |
@app.post("/question-answering-doc")
|
228 |
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
|
229 |
-
|
230 |
-
file_path = f"temp_files/{file.filename}"
|
231 |
-
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
232 |
with open(file_path, "wb") as f:
|
233 |
f.write(await file.read())
|
234 |
-
|
235 |
-
# Extract text based on file type
|
236 |
if file.filename.endswith(".pdf"):
|
237 |
text = extract_pdf_text(file_path)
|
238 |
elif file.filename.endswith(".docx"):
|
@@ -242,26 +82,20 @@ async def question_answering_doc(question: str = Form(...), file: UploadFile = F
|
|
242 |
else:
|
243 |
return {"error": "Unsupported file format"}
|
244 |
|
245 |
-
# Use the model for question answering
|
246 |
qa_result = qa_pipeline(question=question, context=text)
|
247 |
return {"answer": qa_result['answer']}
|
248 |
|
249 |
# Function to answer questions based on images
|
250 |
@app.post("/question-answering-image")
|
251 |
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
|
252 |
-
# Open the uploaded image
|
253 |
image = Image.open(BytesIO(await image_file.read()))
|
254 |
-
|
255 |
-
# Use EasyOCR to extract text if the image has textual content
|
256 |
image_text = extract_text_from_image(image)
|
|
|
|
|
257 |
|
258 |
-
|
259 |
-
image_qa_result = image_qa_pipeline(image=image, question=question)
|
260 |
-
|
261 |
-
return {"answer": image_qa_result['answer'], "image_text": image_text}
|
262 |
|
263 |
# Serve the application in Hugging Face space
|
264 |
@app.get("/docs")
|
265 |
async def get_docs(request: Request):
|
266 |
return templates.TemplateResponse("index.html", {"request": request})
|
267 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from fastapi import FastAPI, Form, File, UploadFile
|
2 |
from fastapi.responses import RedirectResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
|
|
24 |
|
25 |
# Initialize transformers pipelines
|
26 |
qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
|
27 |
+
image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
|
28 |
|
29 |
# Initialize EasyOCR for image-based text extraction
|
30 |
reader = easyocr.Reader(['en'])
|
|
|
32 |
# Define a template for rendering HTML
|
33 |
templates = Jinja2Templates(directory="templates")
|
34 |
|
35 |
+
# Ensure temp_files directory exists
|
36 |
+
temp_dir = "temp_files"
|
37 |
+
os.makedirs(temp_dir, exist_ok=True)
|
38 |
+
|
39 |
# Function to process PDFs
|
40 |
def extract_pdf_text(file_path: str):
|
41 |
with pdfplumber.open(file_path) as pdf:
|
|
|
47 |
# Function to process DOCX files
|
48 |
def extract_docx_text(file_path: str):
|
49 |
doc = docx.Document(file_path)
|
50 |
+
text = "\n".join([para.text for para in doc.paragraphs])
|
|
|
|
|
51 |
return text
|
52 |
|
53 |
# Function to process PPTX files
|
54 |
def extract_pptx_text(file_path: str):
|
55 |
from pptx import Presentation
|
56 |
prs = Presentation(file_path)
|
57 |
+
text = "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
|
|
|
|
|
|
|
|
|
58 |
return text
|
59 |
|
60 |
# Function to extract text from images using OCR
|
61 |
def extract_text_from_image(image: Image):
|
62 |
+
return pytesseract.image_to_string(image)
|
|
|
63 |
|
64 |
# Home route
|
65 |
@app.get("/")
|
|
|
69 |
# Function to answer questions based on document content
|
70 |
@app.post("/question-answering-doc")
|
71 |
async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
|
72 |
+
file_path = os.path.join(temp_dir, file.filename)
|
|
|
|
|
73 |
with open(file_path, "wb") as f:
|
74 |
f.write(await file.read())
|
75 |
+
|
|
|
76 |
if file.filename.endswith(".pdf"):
|
77 |
text = extract_pdf_text(file_path)
|
78 |
elif file.filename.endswith(".docx"):
|
|
|
82 |
else:
|
83 |
return {"error": "Unsupported file format"}
|
84 |
|
|
|
85 |
qa_result = qa_pipeline(question=question, context=text)
|
86 |
return {"answer": qa_result['answer']}
|
87 |
|
88 |
# Function to answer questions based on images
|
89 |
@app.post("/question-answering-image")
|
90 |
async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
|
|
|
91 |
image = Image.open(BytesIO(await image_file.read()))
|
|
|
|
|
92 |
image_text = extract_text_from_image(image)
|
93 |
+
|
94 |
+
image_qa_result = image_qa_pipeline({"image": image, "question": question})
|
95 |
|
96 |
+
return {"answer": image_qa_result[0]['answer'], "image_text": image_text}
|
|
|
|
|
|
|
97 |
|
98 |
# Serve the application in Hugging Face space
|
99 |
@app.get("/docs")
|
100 |
async def get_docs(request: Request):
|
101 |
return templates.TemplateResponse("index.html", {"request": request})
|
|