ikraamkb committed on
Commit
29f5581
·
verified ·
1 Parent(s): d51b69d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -36
app.py CHANGED
@@ -1,30 +1,33 @@
1
  import gradio as gr
2
  import numpy as np
3
  import fitz # PyMuPDF
4
- import tika
5
  import torch
 
6
  from fastapi import FastAPI
7
  from transformers import pipeline
8
  from PIL import Image
9
- from io import BytesIO
10
  from starlette.responses import RedirectResponse
11
- from tika import parser
12
  from openpyxl import load_workbook
13
-
14
- # Initialize Tika for DOCX & PPTX parsing (Ensure Java is installed)
15
- tika.initVM()
16
 
17
  # Initialize FastAPI
18
  app = FastAPI()
19
 
20
- # Load models
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
- qa_pipeline = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device)
23
- image_captioning_pipeline = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
 
 
 
 
 
24
 
25
  ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
 
26
 
27
- # ✅ Function to Validate File Type
28
  def validate_file_type(file):
29
  if hasattr(file, "name"):
30
  ext = file.name.split(".")[-1].lower()
@@ -34,33 +37,38 @@ def validate_file_type(file):
34
  return "❌ Invalid file format!"
35
 
36
  # ✅ Extract Text from PDF
37
def extract_text_from_pdf(file):
    """Return the text of every page of a PDF, joined with newlines.

    Args:
        file: Object whose ``name`` attribute is the path of the PDF to open.

    Returns:
        One string containing each page's text in page order.
    """
    # The context manager releases the document once all pages are read.
    with fitz.open(file.name) as pdf:
        page_texts = [page.get_text() for page in pdf]
    return "\n".join(page_texts)
 
 
 
 
 
40
 
41
- # ✅ Extract Text from DOCX & PPTX using Tika
42
def extract_text_with_tika(file):
    """Parse a document with Tika and return its ``"content"`` entry.

    Args:
        file: Object whose ``name`` attribute is the path handed to Tika.

    Returns:
        Whatever Tika stores under the ``"content"`` key of its parse result.
    """
    parsed = parser.from_file(file.name)
    return parsed["content"]
 
44
 
45
  # ✅ Extract Text from Excel
46
def extract_text_from_excel(file):
    """Return the cell values of every worksheet as newline-separated rows.

    Each row becomes one line whose non-empty cells are converted with
    ``str`` and joined by single spaces; rows from all worksheets are
    concatenated in workbook order.

    Args:
        file: Object whose ``name`` attribute is the path of the ``.xlsx``
            workbook to load.

    Returns:
        One string with one line per worksheet row (blank rows yield empty
        lines).
    """
    workbook = load_workbook(file.name, data_only=True)
    lines = []
    for sheet in workbook.worksheets:
        for row in sheet.iter_rows(values_only=True):
            # Skip only genuinely empty cells: a plain truthiness test would
            # also discard legitimate values such as 0 and False.
            cells = [
                str(cell) for cell in row
                if cell is not None and cell != ""
            ]
            lines.append(" ".join(cells))
    return "\n".join(lines)
53
-
54
- # ✅ Truncate Long Text for Model
55
def truncate_text(text, max_length=2048):
    """Cap ``text`` at ``max_length`` characters.

    Args:
        text: The string to shorten.
        max_length: Maximum number of characters to keep (default 2048).

    Returns:
        ``text`` unchanged when it already fits, otherwise its first
        ``max_length`` characters.
    """
    if len(text) <= max_length:
        return text
    return text[:max_length]
57
 
58
  # ✅ Answer Questions from Image or Document
59
- def answer_question(file, question: str):
60
  if isinstance(file, np.ndarray): # Image Processing
61
  image = Image.fromarray(file)
62
- caption = image_captioning_pipeline(image)[0]['generated_text']
63
- response = qa_pipeline(f"Question: {question}\nContext: {caption}")
 
 
 
64
  return response[0]["generated_text"]
65
 
66
  validation_error = validate_file_type(file)
@@ -69,13 +77,15 @@ def answer_question(file, question: str):
69
 
70
  file_ext = file.name.split(".")[-1].lower()
71
 
72
- # Extract Text from Supported Documents
73
  if file_ext == "pdf":
74
- text = extract_text_from_pdf(file)
75
- elif file_ext in ["docx", "pptx"]:
76
- text = extract_text_with_tika(file)
 
 
77
  elif file_ext == "xlsx":
78
- text = extract_text_from_excel(file)
79
  else:
80
  return "❌ Unsupported file format!"
81
 
@@ -83,7 +93,11 @@ def answer_question(file, question: str):
83
  return "⚠️ No text extracted from the document."
84
 
85
  truncated_text = truncate_text(text)
86
- response = qa_pipeline(f"Question: {question}\nContext: {truncated_text}")
 
 
 
 
87
 
88
  return response[0]["generated_text"]
89
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import fitz # PyMuPDF
 
4
  import torch
5
+ import asyncio
6
  from fastapi import FastAPI
7
  from transformers import pipeline
8
  from PIL import Image
 
9
  from starlette.responses import RedirectResponse
 
10
  from openpyxl import load_workbook
11
+ from docx import Document
12
+ from pptx import Presentation
 
13
 
14
  # Initialize FastAPI
15
  app = FastAPI()
16
 
17
+ # Use GPU if available
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ # Function to load models lazily
21
_QA_PIPELINE = None  # cache for the text-generation pipeline, filled on first use


def get_qa_pipeline():
    """Return the TinyLlama text-generation pipeline, building it only once.

    The pipeline is created on the first call and the same object is
    returned afterwards, so the model is not reloaded for every question.

    Returns:
        The ``transformers`` text-generation pipeline placed on ``device``.
    """
    global _QA_PIPELINE
    if _QA_PIPELINE is None:
        # Request half precision only on CUDA. On CPU keep float32, because
        # float16 kernels may be unavailable or very slow there.
        dtype = torch.float16 if device == "cuda" else torch.float32
        _QA_PIPELINE = pipeline(
            "text-generation",
            model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            device=device,
            torch_dtype=dtype,
        )
    return _QA_PIPELINE
23
+
24
_IMAGE_CAPTIONING_PIPELINE = None  # cache for the captioning pipeline, filled on first use


def get_image_captioning_pipeline():
    """Return the image-captioning pipeline, building it only once.

    The pipeline is created on the first call and the same object is
    returned afterwards, so the model is not reloaded for every image.

    Returns:
        The ``transformers`` image-to-text pipeline for
        ``nlpconnect/vit-gpt2-image-captioning``.
    """
    global _IMAGE_CAPTIONING_PIPELINE
    if _IMAGE_CAPTIONING_PIPELINE is None:
        _IMAGE_CAPTIONING_PIPELINE = pipeline(
            "image-to-text",
            model="nlpconnect/vit-gpt2-image-captioning",
        )
    return _IMAGE_CAPTIONING_PIPELINE
26
 
27
  ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
28
+ MAX_INPUT_LENGTH = 1024 # Limit input length for faster processing
29
 
30
+ # ✅ Validate File Type
31
  def validate_file_type(file):
32
  if hasattr(file, "name"):
33
  ext = file.name.split(".")[-1].lower()
 
37
  return "❌ Invalid file format!"
38
 
39
  # ✅ Extract Text from PDF
40
async def extract_text_from_pdf(file):
    """Extract the text of every page of a PDF without blocking the event loop.

    The blocking PyMuPDF work runs in the default executor.

    Args:
        file: Object whose ``name`` attribute is the path of the PDF to open.

    Returns:
        One string containing each page's text in page order, joined with
        newlines.
    """

    def _read_pdf():
        # The context manager closes the document after reading; opening it
        # inline inside a comprehension left it open.
        with fitz.open(file.name) as pdf:
            return "\n".join(page.get_text() for page in pdf)

    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_pdf)
43
+
44
+ # ✅ Extract Text from DOCX
45
async def extract_text_from_docx(file):
    """Extract the paragraph text of a DOCX file without blocking the event loop.

    The blocking python-docx work runs in the default executor.

    Args:
        file: Uploaded file object exposing a ``name`` path. A plain path or
            file-like object without ``name`` is passed through unchanged.

    Returns:
        The text of every paragraph in document order, joined with newlines.
    """
    # Open by path when one is available, matching the PDF and Excel
    # extractors; fall back to the object itself otherwise.
    source = getattr(file, "name", file)

    def _read_docx():
        return "\n".join(paragraph.text for paragraph in Document(source).paragraphs)

    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_docx)
48
 
49
+ # ✅ Extract Text from PPTX
50
async def extract_text_from_pptx(file):
    """Extract the shape text of a PPTX file without blocking the event loop.

    The blocking python-pptx work runs in the default executor.

    Args:
        file: Uploaded file object exposing a ``name`` path. A plain path or
            file-like object without ``name`` is passed through unchanged.

    Returns:
        The ``text`` of every shape that has one, in slide order, joined with
        newlines.
    """
    # Open by path when one is available, matching the PDF and Excel
    # extractors; fall back to the object itself otherwise.
    source = getattr(file, "name", file)

    def _read_pptx():
        deck = Presentation(source)
        return "\n".join(
            shape.text
            for slide in deck.slides
            for shape in slide.shapes
            # Shapes without a ``text`` attribute are skipped.
            if hasattr(shape, "text")
        )

    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_pptx)
53
 
54
  # ✅ Extract Text from Excel
55
async def extract_text_from_excel(file):
    """Extract the cell values of an XLSX workbook without blocking the event loop.

    Each row becomes one line whose non-empty cells are converted with
    ``str`` and joined by single spaces; rows from all worksheets are
    concatenated in workbook order. The blocking openpyxl work runs in the
    default executor.

    Args:
        file: Object whose ``name`` attribute is the path of the workbook.

    Returns:
        One string with one line per worksheet row (blank rows yield empty
        lines).
    """

    def _read_workbook():
        workbook = load_workbook(file.name, data_only=True)
        lines = []
        for sheet in workbook.worksheets:
            for row in sheet.iter_rows(values_only=True):
                # Skip only genuinely empty cells: a plain truthiness test
                # would also discard legitimate values such as 0 and False.
                cells = [
                    str(cell) for cell in row
                    if cell is not None and cell != ""
                ]
                lines.append(" ".join(cells))
        return "\n".join(lines)

    # Inside a coroutine the running loop is the one to schedule work on.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _read_workbook)
58
+
59
+ # ✅ Truncate Long Text
60
def truncate_text(text):
    """Cap ``text`` at ``MAX_INPUT_LENGTH`` characters.

    Args:
        text: The string to shorten.

    Returns:
        ``text`` unchanged when it already fits, otherwise its first
        ``MAX_INPUT_LENGTH`` characters.
    """
    if len(text) <= MAX_INPUT_LENGTH:
        return text
    return text[:MAX_INPUT_LENGTH]
 
 
 
 
62
 
63
  # ✅ Answer Questions from Image or Document
64
+ async def answer_question(file, question: str):
65
  if isinstance(file, np.ndarray): # Image Processing
66
  image = Image.fromarray(file)
67
+ image_captioning = get_image_captioning_pipeline()
68
+ caption = image_captioning(image)[0]['generated_text']
69
+
70
+ qa = get_qa_pipeline()
71
+ response = qa(f"Question: {question}\nContext: {caption}")
72
  return response[0]["generated_text"]
73
 
74
  validation_error = validate_file_type(file)
 
77
 
78
  file_ext = file.name.split(".")[-1].lower()
79
 
80
+ # Extract text asynchronously
81
  if file_ext == "pdf":
82
+ text = await extract_text_from_pdf(file)
83
+ elif file_ext == "docx":
84
+ text = await extract_text_from_docx(file)
85
+ elif file_ext == "pptx":
86
+ text = await extract_text_from_pptx(file)
87
  elif file_ext == "xlsx":
88
+ text = await extract_text_from_excel(file)
89
  else:
90
  return "❌ Unsupported file format!"
91
 
 
93
  return "⚠️ No text extracted from the document."
94
 
95
  truncated_text = truncate_text(text)
96
+
97
+ # Run QA model asynchronously
98
+ loop = asyncio.get_event_loop()
99
+ qa = get_qa_pipeline()
100
+ response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")
101
 
102
  return response[0]["generated_text"]
103