ikraamkb committed on
Commit da9e0ce · verified · 1 Parent(s): 1e4a65e

Update app.py

Files changed (1):
  1. app.py  +13 -179

app.py CHANGED
@@ -1,157 +1,3 @@
- """import gradio as gr
- import numpy as np
- import fitz  # PyMuPDF
- import torch
- import asyncio
- from fastapi import FastAPI
- from transformers import pipeline
- from PIL import Image
- from starlette.responses import RedirectResponse
- from openpyxl import load_workbook
- from docx import Document
- from pptx import Presentation
-
- # Initialize FastAPI
- app = FastAPI()
-
- # Use GPU if available
- device = "cuda" if torch.cuda.is_available() else "cpu"
- print(f"✅ Using device: {device}")
-
- # Function to load models lazily
- def get_qa_pipeline():
-     print("🔄 Loading QA pipeline model...")
-     return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)
-
- def get_image_captioning_pipeline():
-     print("🔄 Loading Image Captioning model...")
-     return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
-
- ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
- MAX_INPUT_LENGTH = 1024  # Limit input length for faster processing
-
- # ✅ Validate File Type
- def validate_file_type(file):
-     if hasattr(file, "name"):
-         ext = file.name.split(".")[-1].lower()
-         print(f"📁 File extension detected: {ext}")
-         if ext not in ALLOWED_EXTENSIONS:
-             print(f"❌ Unsupported file format: {ext}")
-             return f"❌ Unsupported file format: {ext}"
-         return None
-     print("❌ Invalid file format!")
-     return "❌ Invalid file format!"
-
- # ✅ Extract Text from PDF
- async def extract_text_from_pdf(file):
-     print(f"📄 Extracting text from PDF: {file.name}")
-     loop = asyncio.get_event_loop()
-     text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
-     print(f"✅ Extracted {len(text)} characters from PDF")
-     return text
-
- # ✅ Extract Text from DOCX
- async def extract_text_from_docx(file):
-     print(f"📄 Extracting text from DOCX: {file.name}")
-     loop = asyncio.get_event_loop()
-     text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs]))
-     print(f"✅ Extracted {len(text)} characters from DOCX")
-     return text
-
- # ✅ Extract Text from PPTX
- async def extract_text_from_pptx(file):
-     print(f"📄 Extracting text from PPTX: {file.name}")
-     loop = asyncio.get_event_loop()
-     text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")]))
-     print(f"✅ Extracted {len(text)} characters from PPTX")
-     return text
-
- # ✅ Extract Text from Excel
- async def extract_text_from_excel(file):
-     print(f"📄 Extracting text from Excel: {file.name}")
-     loop = asyncio.get_event_loop()
-     text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
-     print(f"✅ Extracted {len(text)} characters from Excel")
-     return text
-
- # ✅ Truncate Long Text
- def truncate_text(text):
-     print(f"✂️ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
-     return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
-
- # ✅ Answer Questions from Image or Document
- async def answer_question(file, question: str):
-     print(f"❓ Question received: {question}")
-
-     if isinstance(file, np.ndarray):  # Image Processing
-         print("🖼️ Processing image for captioning...")
-         image = Image.fromarray(file)
-         image_captioning = get_image_captioning_pipeline()
-         caption = image_captioning(image)[0]['generated_text']
-         print(f"📝 Generated caption: {caption}")
-
-         qa = get_qa_pipeline()
-         print("🤖 Running QA model...")
-         response = qa(f"Question: {question}\nContext: {caption}")
-         print(f"✅ Model response: {response[0]['generated_text']}")
-         return response[0]["generated_text"]
-
-     validation_error = validate_file_type(file)
-     if validation_error:
-         return validation_error
-
-     file_ext = file.name.split(".")[-1].lower()
-
-     # Extract text asynchronously
-     if file_ext == "pdf":
-         text = await extract_text_from_pdf(file)
-     elif file_ext == "docx":
-         text = await extract_text_from_docx(file)
-     elif file_ext == "pptx":
-         text = await extract_text_from_pptx(file)
-     elif file_ext == "xlsx":
-         text = await extract_text_from_excel(file)
-     else:
-         print("❌ Unsupported file format!")
-         return "❌ Unsupported file format!"
-
-     if not text:
-         print("⚠️ No text extracted from the document.")
-         return "⚠️ No text extracted from the document."
-
-     truncated_text = truncate_text(text)
-
-     # Run QA model asynchronously
-     print("🤖 Running QA model...")
-     loop = asyncio.get_event_loop()
-     qa = get_qa_pipeline()
-     response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
-
-     print(f"✅ Model response: {response[0]['generated_text']}")
-     return response[0]["generated_text"]
-
- # ✅ Gradio Interface (Separate File & Image Inputs)
- with gr.Blocks() as demo:
-     gr.Markdown("## 📄 AI-Powered Document & Image QA")
-
-     with gr.Row():
-         file_input = gr.File(label="Upload Document")
-         image_input = gr.Image(label="Upload Image")
-
-     question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
-     answer_output = gr.Textbox(label="Answer")
-     submit_btn = gr.Button("Get Answer")
-
-     submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)
-
-
- # ✅ Mount Gradio with FastAPI
- app = gr.mount_gradio_app(app, demo, path="/")
-
- @app.get("/")
- def home():
-     return RedirectResponse(url="/")
- """
  from fastapi import FastAPI, Form, File, UploadFile
  from fastapi.responses import RedirectResponse
  from fastapi.staticfiles import StaticFiles
@@ -178,7 +24,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 
  # Initialize transformers pipelines
  qa_pipeline = pipeline("question-answering", model="microsoft/phi-2", tokenizer="microsoft/phi-2")
- image_qa_pipeline = pipeline("image-question-answering", model="Salesforce/blip-vqa-base", tokenizer="Salesforce/blip-vqa-base")
+ image_qa_pipeline = pipeline("vqa", model="Salesforce/blip-vqa-base")
 
  # Initialize EasyOCR for image-based text extraction
  reader = easyocr.Reader(['en'])
@@ -186,6 +32,10 @@ reader = easyocr.Reader(['en'])
  # Define a template for rendering HTML
  templates = Jinja2Templates(directory="templates")
 
+ # Ensure temp_files directory exists
+ temp_dir = "temp_files"
+ os.makedirs(temp_dir, exist_ok=True)
+
  # Function to process PDFs
  def extract_pdf_text(file_path: str):
      with pdfplumber.open(file_path) as pdf:
@@ -197,26 +47,19 @@ def extract_pdf_text(file_path: str):
  # Function to process DOCX files
  def extract_docx_text(file_path: str):
      doc = docx.Document(file_path)
-     text = ""
-     for para in doc.paragraphs:
-         text += para.text
+     text = "\n".join([para.text for para in doc.paragraphs])
      return text
 
  # Function to process PPTX files
  def extract_pptx_text(file_path: str):
      from pptx import Presentation
      prs = Presentation(file_path)
-     text = ""
-     for slide in prs.slides:
-         for shape in slide.shapes:
-             if hasattr(shape, "text"):
-                 text += shape.text
+     text = "\n".join([shape.text for slide in prs.slides for shape in slide.shapes if hasattr(shape, "text")])
      return text
 
  # Function to extract text from images using OCR
  def extract_text_from_image(image: Image):
-     text = pytesseract.image_to_string(image)
-     return text
+     return pytesseract.image_to_string(image)
 
  # Home route
  @app.get("/")
@@ -226,13 +69,10 @@ def home():
  # Function to answer questions based on document content
  @app.post("/question-answering-doc")
  async def question_answering_doc(question: str = Form(...), file: UploadFile = File(...)):
-     # Save the uploaded file temporarily
-     file_path = f"temp_files/{file.filename}"
-     os.makedirs(os.path.dirname(file_path), exist_ok=True)
+     file_path = os.path.join(temp_dir, file.filename)
      with open(file_path, "wb") as f:
          f.write(await file.read())
-
-     # Extract text based on file type
+
      if file.filename.endswith(".pdf"):
          text = extract_pdf_text(file_path)
      elif file.filename.endswith(".docx"):
242
  else:
243
  return {"error": "Unsupported file format"}
244
 
245
- # Use the model for question answering
246
  qa_result = qa_pipeline(question=question, context=text)
247
  return {"answer": qa_result['answer']}
248
 
249
  # Function to answer questions based on images
250
  @app.post("/question-answering-image")
251
  async def question_answering_image(question: str = Form(...), image_file: UploadFile = File(...)):
252
- # Open the uploaded image
253
  image = Image.open(BytesIO(await image_file.read()))
254
-
255
- # Use EasyOCR to extract text if the image has textual content
256
  image_text = extract_text_from_image(image)
 
 
257
 
258
- # Use the BLIP VQA model for question answering on the image
259
- image_qa_result = image_qa_pipeline(image=image, question=question)
260
-
261
- return {"answer": image_qa_result['answer'], "image_text": image_text}
262
 
263
  # Serve the application in Hugging Face space
264
  @app.get("/docs")
265
  async def get_docs(request: Request):
266
  return templates.TemplateResponse("index.html", {"request": request})
267
-
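
Note on the image route: the pipeline is now created with the "vqa" task alias and the endpoint reads image_qa_result[0]['answer'], because the visual-question-answering pipeline returns a list of answer dicts. A minimal standalone sketch of that call pattern (the image path and question are placeholders; assumes transformers and Pillow are installed):

from PIL import Image
from transformers import pipeline

# "vqa" is an alias for the visual-question-answering task in transformers.
vqa = pipeline("vqa", model="Salesforce/blip-vqa-base")

image = Image.open("photo.png")  # placeholder image path
result = vqa(image=image, question="What is shown in the image?")

# The pipeline returns a list of {"answer", "score"} dicts, which is why the
# updated endpoint indexes result[0]["answer"].
print(result[0]["answer"])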
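
For reference, a quick smoke test of the two endpoints after this commit. This is a sketch only; the base URL, port, and file names are assumptions about a local deployment (e.g. running the app with uvicorn):

import requests

BASE_URL = "http://localhost:8000"  # assumed local address; adjust to the deployment

# Document QA: the question travels as a form field, the document as a multipart file.
with open("sample.pdf", "rb") as f:  # placeholder file
    resp = requests.post(
        f"{BASE_URL}/question-answering-doc",
        data={"question": "What is this document about?"},
        files={"file": ("sample.pdf", f, "application/pdf")},
    )
print(resp.json())  # {"answer": "..."}

# Image QA: same pattern, but the upload field is named image_file.
with open("photo.png", "rb") as f:  # placeholder file
    resp = requests.post(
        f"{BASE_URL}/question-answering-image",
        data={"question": "What is shown in the image?"},
        files={"image_file": ("photo.png", f, "image/png")},
    )
print(resp.json())  # {"answer": "...", "image_text": "..."}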