Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import gradio as gr
|
2 |
import uvicorn
|
3 |
import numpy as np
|
4 |
import fitz # PyMuPDF
|
@@ -111,6 +111,137 @@ with gr.Blocks() as demo:
|
|
111 |
# β
Mount Gradio with FastAPI
|
112 |
app = gr.mount_gradio_app(app, demo, path="/")
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
@app.get("/")
|
115 |
def home():
|
116 |
return RedirectResponse(url="/")
|
|
|
1 |
+
"""import gradio as gr
|
2 |
import uvicorn
|
3 |
import numpy as np
|
4 |
import fitz # PyMuPDF
|
|
|
111 |
# β
Mount Gradio with FastAPI
|
112 |
app = gr.mount_gradio_app(app, demo, path="/")
|
113 |
|
114 |
+
@app.get("/")
|
115 |
+
def home():
|
116 |
+
return RedirectResponse(url="/")
|
117 |
+
|
118 |
+
# β
Run FastAPI + Gradio
|
119 |
+
if __name__ == "__main__":
|
120 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
121 |
+
"""
|
122 |
+
import gradio as gr
|
123 |
+
import uvicorn
|
124 |
+
import numpy as np
|
125 |
+
import fitz # PyMuPDF
|
126 |
+
import tika
|
127 |
+
import torch
|
128 |
+
from fastapi import FastAPI
|
129 |
+
from transformers import pipeline
|
130 |
+
from PIL import Image
|
131 |
+
from io import BytesIO
|
132 |
+
from starlette.responses import RedirectResponse
|
133 |
+
from tika import parser
|
134 |
+
from openpyxl import load_workbook
|
135 |
+
import os
|
136 |
+
|
137 |
+
# Initialize Tika for DOCX & PPTX parsing
|
138 |
+
tika.initVM()
|
139 |
+
|
140 |
+
# Initialize FastAPI
|
141 |
+
app = FastAPI()
|
142 |
+
|
143 |
+
# Load models
|
144 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
145 |
+
qa_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device)
|
146 |
+
image_captioning_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
|
147 |
+
|
148 |
+
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
|
149 |
+
|
150 |
+
# β
Function to Validate File Type
|
151 |
+
def validate_file_type(file):
|
152 |
+
if isinstance(file, str): # If it's text input (NamedString)
|
153 |
+
return None
|
154 |
+
if hasattr(file, "name") and file.name:
|
155 |
+
ext = file.name.split(".")[-1].lower()
|
156 |
+
if ext not in ALLOWED_EXTENSIONS:
|
157 |
+
return f"β Unsupported file format: {ext}"
|
158 |
+
return None
|
159 |
+
return "β Invalid file format!"
|
160 |
+
|
161 |
+
# β
Extract Text from PDF
|
162 |
+
def extract_text_from_pdf(pdf_bytes):
|
163 |
+
doc = fitz.open(stream=pdf_bytes, filetype="pdf")
|
164 |
+
return "\n".join([page.get_text() for page in doc])
|
165 |
+
|
166 |
+
# β
Extract Text from DOCX & PPTX using Tika
|
167 |
+
def extract_text_with_tika(file_bytes):
|
168 |
+
parsed = parser.from_buffer(file_bytes)
|
169 |
+
return parsed.get("content", "").strip() if parsed else ""
|
170 |
+
|
171 |
+
# β
Extract Text from Excel
|
172 |
+
def extract_text_from_excel(file_path):
|
173 |
+
wb = load_workbook(file_path, data_only=True)
|
174 |
+
text = []
|
175 |
+
for sheet in wb.worksheets:
|
176 |
+
for row in sheet.iter_rows(values_only=True):
|
177 |
+
text.append(" ".join(str(cell) for cell in row if cell))
|
178 |
+
return "\n".join(text)
|
179 |
+
|
180 |
+
# β
Truncate Long Text for Model
|
181 |
+
def truncate_text(text, max_length=2048):
|
182 |
+
return text[:max_length] if len(text) > max_length else text
|
183 |
+
|
184 |
+
# β
Answer Questions from Image or Document
|
185 |
+
def answer_question(file, question: str):
|
186 |
+
# πΌοΈ Handle Image Input (Gradio sends NumPy arrays)
|
187 |
+
if isinstance(file, np.ndarray):
|
188 |
+
image = Image.fromarray(file)
|
189 |
+
caption = image_captioning_pipeline(image)[0]['generated_text']
|
190 |
+
response = qa_pipeline(f"Question: {question}\nContext: {caption}")
|
191 |
+
return response[0]["generated_text"]
|
192 |
+
|
193 |
+
# Validate File
|
194 |
+
validation_error = validate_file_type(file)
|
195 |
+
if validation_error:
|
196 |
+
return validation_error
|
197 |
+
|
198 |
+
file_ext = file.name.split(".")[-1].lower() if hasattr(file, "name") else None
|
199 |
+
|
200 |
+
# π οΈ Fix: Read File Bytes Correctly (Gradio Provides File Path)
|
201 |
+
try:
|
202 |
+
with open(file.name, "rb") as f:
|
203 |
+
file_bytes = f.read()
|
204 |
+
except Exception as e:
|
205 |
+
return f"β Error reading file: {str(e)}"
|
206 |
+
|
207 |
+
if not file_bytes:
|
208 |
+
return "β Could not read file content!"
|
209 |
+
|
210 |
+
# π Extract Text from Supported Documents
|
211 |
+
if file_ext == "pdf":
|
212 |
+
text = extract_text_from_pdf(file_bytes)
|
213 |
+
elif file_ext in ["docx", "pptx"]:
|
214 |
+
text = extract_text_with_tika(file_bytes)
|
215 |
+
elif file_ext == "xlsx":
|
216 |
+
text = extract_text_from_excel(file.name)
|
217 |
+
else:
|
218 |
+
return "β Unsupported file format!"
|
219 |
+
|
220 |
+
if not text.strip():
|
221 |
+
return "β οΈ No text extracted from the document."
|
222 |
+
|
223 |
+
# π₯ Run Model on Extracted Text
|
224 |
+
truncated_text = truncate_text(text)
|
225 |
+
response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")
|
226 |
+
|
227 |
+
return response[0]["generated_text"]
|
228 |
+
|
229 |
+
# β
Gradio Interface (Unified for Images & Documents)
|
230 |
+
with gr.Blocks() as demo:
|
231 |
+
gr.Markdown("## π AI-Powered Document & Image QA")
|
232 |
+
|
233 |
+
with gr.Row():
|
234 |
+
file_input = gr.File(label="Upload Document / Image")
|
235 |
+
question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
|
236 |
+
|
237 |
+
answer_output = gr.Textbox(label="Answer")
|
238 |
+
|
239 |
+
submit_btn = gr.Button("Get Answer")
|
240 |
+
submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
|
241 |
+
|
242 |
+
# β
Mount Gradio with FastAPI
|
243 |
+
app = gr.mount_gradio_app(app, demo, path="/")
|
244 |
+
|
245 |
@app.get("/")
|
246 |
def home():
|
247 |
return RedirectResponse(url="/")
|