ikraamkb committed (verified)
Commit 0540355 Β· Parent(s): d74850e

Update app.py

Files changed (1)
1. app.py (+69 -57)
app.py CHANGED
@@ -1,74 +1,83 @@
- from fastapi import FastAPI, File, UploadFile
- from fastapi.responses import RedirectResponse
- import fitz  # PyMuPDF for PDF parsing
- from tika import parser  # Apache Tika for document parsing
- import openpyxl
- from pptx import Presentation
- from PIL import Image
- from transformers import pipeline
  import gradio as gr
+ import uvicorn
  import numpy as np
+ import fitz  # PyMuPDF
+ import tika
+ import torch
+ from fastapi import FastAPI
+ from transformers import pipeline
+ from PIL import Image
+ from io import BytesIO
+ from starlette.responses import RedirectResponse
+ from tika import parser
+ from openpyxl import load_workbook
+
+ # Initialize Tika for DOCX & PPTX parsing
+ tika.initVM()

  # Initialize FastAPI
  app = FastAPI()

- print("πŸ”„ Loading models...")
-
- # Load Hugging Face Models
- qa_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)
- image_captioning_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base", device=-1, use_fast=True)
+ # Load models
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ qa_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device)
+ image_captioning_pipeline = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

- print("βœ… Models loaded (Optimized for Speed)")
-
- # Allowed File Extensions
- ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx", "jpg", "jpeg", "png"}
+ ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}

+ # βœ… Function to Validate File Type
  def validate_file_type(file):
-     ext = file.filename.split(".")[-1].lower()
-     if ext not in ALLOWED_EXTENSIONS:
-         return f"❌ Unsupported file format: {ext}"
-     return None
-
- # Function to truncate text to 450 tokens
- def truncate_text(text, max_tokens=450):
-     words = text.split()
-     return " ".join(words[:max_tokens])
-
- # Document Text Extraction Functions
+     if isinstance(file, str):  # Text-based input (NamedString)
+         return None
+     if hasattr(file, "name"):
+         ext = file.name.split(".")[-1].lower()
+         if ext not in ALLOWED_EXTENSIONS:
+             return f"❌ Unsupported file format: {ext}"
+         return None
+     return "❌ Invalid file format!"
+
+ # βœ… Extract Text from PDF
  def extract_text_from_pdf(pdf_bytes):
      doc = fitz.open(stream=pdf_bytes, filetype="pdf")
-     text = "\n".join([page.get_text("text") for page in doc])
-     return text if text else "⚠️ No text found."
+     return "\n".join([page.get_text() for page in doc])

+ # βœ… Extract Text from DOCX & PPTX using Tika
  def extract_text_with_tika(file_bytes):
-     parsed = parser.from_buffer(file_bytes)
-     return parsed.get("content", "⚠️ No text found.").strip()
+     return parser.from_buffer(file_bytes)["content"]

- def extract_text_from_excel(excel_bytes):
-     wb = openpyxl.load_workbook(excel_bytes, read_only=True)
+ # βœ… Extract Text from Excel
+ def extract_text_from_excel(file_bytes):
+     wb = load_workbook(BytesIO(file_bytes), data_only=True)
      text = []
      for sheet in wb.worksheets:
          for row in sheet.iter_rows(values_only=True):
-             text.append(" ".join(map(str, row)))
-     return "\n".join(text) if text else "⚠️ No text found."
+             text.append(" ".join(str(cell) for cell in row if cell))
+     return "\n".join(text)

- # Function to process file (document or image) and answer question
+ # βœ… Truncate Long Text for Model
+ def truncate_text(text, max_length=2048):
+     return text[:max_length] if len(text) > max_length else text
+
+ # βœ… Answer Questions from Image or Document
  def answer_question(file, question: str):
+     # Image Processing (Gradio sends images as NumPy arrays)
      if isinstance(file, np.ndarray):
-         # Image processing
-         image = Image.fromarray(file)
+         image = Image.fromarray(file)
          caption = image_captioning_pipeline(image)[0]['generated_text']
          response = qa_pipeline(f"Question: {question}\nContext: {caption}")
          return response[0]["generated_text"]
-
-     # Document processing
+
+     # Validate File
      validation_error = validate_file_type(file)
      if validation_error:
          return validation_error
-
-     file_ext = file.name.split(".")[-1].lower()
-     file_bytes = file.read()

+     file_ext = file.name.split(".")[-1].lower() if hasattr(file, "name") else None
+     file_bytes = file.read() if hasattr(file, "read") else None
+     if not file_bytes:
+         return "❌ Could not read file content!"
+
+     # Extract Text from Supported Documents
      if file_ext == "pdf":
          text = extract_text_from_pdf(file_bytes)
      elif file_ext in ["docx", "pptx"]:
@@ -77,32 +86,35 @@ def answer_question(file, question: str):
          text = extract_text_from_excel(file_bytes)
      else:
          return "❌ Unsupported file format!"
-
+
      if not text:
          return "⚠️ No text extracted from the document."
-
+
      truncated_text = truncate_text(text)
      response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")

      return response[0]["generated_text"]

- # Gradio Interface for both images & documents
- interface = gr.Interface(
-     fn=answer_question,
-     inputs=[gr.File(label="πŸ“‚ Upload Document or Image"), gr.Textbox(label="πŸ’¬ Ask a Question")],
-     outputs="text",
-     title="πŸ“„πŸ–ΌοΈ AI Document & Image Question Answering"
- )
+ # βœ… Gradio Interface (Unified for Images & Documents)
+ with gr.Blocks() as demo:
+     gr.Markdown("## πŸ“„ AI-Powered Document & Image QA")
+
+     with gr.Row():
+         file_input = gr.File(label="Upload Document / Image")
+         question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
+
+     answer_output = gr.Textbox(label="Answer")
+
+     submit_btn = gr.Button("Get Answer")
+     submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)

- # Mount Gradio with FastAPI
- demo = interface
+ # βœ… Mount Gradio with FastAPI
  app = gr.mount_gradio_app(app, demo, path="/")

  @app.get("/")
  def home():
      return RedirectResponse(url="/")

- # Run FastAPI + Gradio together
+ # βœ… Run FastAPI + Gradio
  if __name__ == "__main__":
-     import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
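For reference, the reworked Excel and truncation helpers can be exercised on their own. The sketch below is illustrative only and not part of the commit: it copies their logic instead of importing app.py (which would initialize Tika and load the Hugging Face models) and feeds them a throwaway in-memory workbook.

# Illustrative sketch only (not part of the commit): mirrors the new
# extract_text_from_excel() and truncate_text() helpers from app.py so they
# can be tested without Tika or the models.
from io import BytesIO
from openpyxl import Workbook, load_workbook

def extract_text_from_excel(file_bytes):
    # Same logic as the committed helper: join non-empty cells row by row.
    wb = load_workbook(BytesIO(file_bytes), data_only=True)
    text = []
    for sheet in wb.worksheets:
        for row in sheet.iter_rows(values_only=True):
            text.append(" ".join(str(cell) for cell in row if cell))
    return "\n".join(text)

def truncate_text(text, max_length=2048):
    # Character-based cap (the previous version capped at 450 whitespace-split words).
    return text[:max_length] if len(text) > max_length else text

# Build a small workbook in memory and run it through the helpers.
wb = Workbook()
ws = wb.active
ws.append(["Quarter", "Revenue"])
ws.append(["Q1", 1250])
buf = BytesIO()
wb.save(buf)

extracted = extract_text_from_excel(buf.getvalue())
print(extracted)                 # Quarter Revenue / Q1 1250
print(truncate_text(extracted))  # unchanged: well under the 2048-character cap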
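Similarly, the new validate_file_type() guard accepts either a plain string path or any object exposing a .name attribute. The short driver below is again illustrative, copying the committed logic rather than importing the app, and walks each branch.

# Illustrative sketch only (not part of the commit): copies the committed
# validate_file_type() guard and exercises each of its branches.
from types import SimpleNamespace

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}

def validate_file_type(file):
    if isinstance(file, str):      # plain string path: accepted without an extension check
        return None
    if hasattr(file, "name"):      # file-like object: check its extension
        ext = file.name.split(".")[-1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            return f"❌ Unsupported file format: {ext}"
        return None
    return "❌ Invalid file format!"

print(validate_file_type("report.pdf"))                       # None
print(validate_file_type(SimpleNamespace(name="deck.pptx")))  # None
print(validate_file_type(SimpleNamespace(name="song.mp3")))   # ❌ Unsupported file format: mp3
print(validate_file_type(42))                                 # ❌ Invalid file format!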