File size: 9,446 Bytes
e2fade1
0c9548a
0540355
 
29f5581
0540355
 
 
 
 
29f5581
 
b1622cb
d74850e
 
 
29f5581
0540355
4e1a845
29f5581
 
 
4e1a845
29f5581
 
 
4e1a845
29f5581
239c804
0540355
29f5581
8e24199
29f5581
1be9899
0540355
 
4e1a845
0540355
4e1a845
0540355
 
4e1a845
0540355
 
 
29f5581
4e1a845
29f5581
4e1a845
 
 
29f5581
 
 
4e1a845
29f5581
4e1a845
 
 
2be14bd
29f5581
 
4e1a845
29f5581
4e1a845
 
 
2be14bd
0540355
29f5581
4e1a845
29f5581
4e1a845
 
 
29f5581
 
 
4e1a845
29f5581
0540355
 
29f5581
4e1a845
 
93ae425
4e1a845
0540355
29f5581
 
4e1a845
29f5581
 
4e1a845
29f5581
4e1a845
d74850e
0540355
0b363e7
 
 
93ae425
 
0540355
29f5581
0b363e7
29f5581
 
 
 
 
0b363e7
29f5581
2be14bd
4e1a845
d2931fe
0540355
2be14bd
4e1a845
d2931fe
0540355
7e5ddc3
29f5581
 
4e1a845
29f5581
 
 
753db53
4e1a845
2852c90
2be14bd
93ae425
0540355
 
 
 
93ae425
 
0540355
93ae425
0540355
 
93ae425
0540355
01cb6f1
4e1a845
0540355
f404b85
d74850e
1b0d519
 
f404b85
e2fade1
a768964
 
 
e2fade1
a768964
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8802df4
a768964
 
 
8802df4
a768964
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
"""import gradio as gr
import numpy as np
import fitz  # PyMuPDF
import torch
import asyncio
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from starlette.responses import RedirectResponse
from openpyxl import load_workbook
from docx import Document
from pptx import Presentation

# Initialize FastAPI
app = FastAPI()

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"βœ… Using device: {device}")

# Function to load models lazily
def get_qa_pipeline():
    print("πŸ”„ Loading QA pipeline model...")
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=torch.float16)

def get_image_captioning_pipeline():
    print("πŸ”„ Loading Image Captioning model...")
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
MAX_INPUT_LENGTH = 1024  # Limit input length for faster processing

# βœ… Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        print(f"πŸ“ File extension detected: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"❌ Unsupported file format: {ext}")
            return f"❌ Unsupported file format: {ext}"
        return None
    print("❌ Invalid file format!")
    return "❌ Invalid file format!"

# βœ… Extract Text from PDF
async def extract_text_from_pdf(file):
    print(f"πŸ“„ Extracting text from PDF: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([page.get_text() for page in fitz.open(file.name)]))
    print(f"βœ… Extracted {len(text)} characters from PDF")
    return text

# βœ… Extract Text from DOCX
async def extract_text_from_docx(file):
    print(f"πŸ“„ Extracting text from DOCX: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([p.text for p in Document(file).paragraphs]))
    print(f"βœ… Extracted {len(text)} characters from DOCX")
    return text

# βœ… Extract Text from PPTX
async def extract_text_from_pptx(file):
    print(f"πŸ“„ Extracting text from PPTX: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([shape.text for slide in Presentation(file).slides for shape in slide.shapes if hasattr(shape, "text")]))
    print(f"βœ… Extracted {len(text)} characters from PPTX")
    return text

# βœ… Extract Text from Excel
async def extract_text_from_excel(file):
    print(f"πŸ“„ Extracting text from Excel: {file.name}")
    loop = asyncio.get_event_loop()
    text = await loop.run_in_executor(None, lambda: "\n".join([" ".join(str(cell) for cell in row if cell) for sheet in load_workbook(file.name, data_only=True).worksheets for row in sheet.iter_rows(values_only=True)]))
    print(f"βœ… Extracted {len(text)} characters from Excel")
    return text

# βœ… Truncate Long Text
def truncate_text(text):
    print(f"βœ‚οΈ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
    return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text

# βœ… Answer Questions from Image or Document
async def answer_question(file, question: str):
    print(f"❓ Question received: {question}")

    if isinstance(file, np.ndarray):  # Image Processing
        print("πŸ–ΌοΈ Processing image for captioning...")
        image = Image.fromarray(file)
        image_captioning = get_image_captioning_pipeline()
        caption = image_captioning(image)[0]['generated_text']
        print(f"πŸ“ Generated caption: {caption}")

        qa = get_qa_pipeline()
        print("πŸ€– Running QA model...")
        response = qa(f"Question: {question}\nContext: {caption}")
        print(f"βœ… Model response: {response[0]['generated_text']}")
        return response[0]["generated_text"]

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error
    
    file_ext = file.name.split(".")[-1].lower()

    # Extract text asynchronously
    if file_ext == "pdf":
        text = await extract_text_from_pdf(file)
    elif file_ext == "docx":
        text = await extract_text_from_docx(file)
    elif file_ext == "pptx":
        text = await extract_text_from_pptx(file)
    elif file_ext == "xlsx":
        text = await extract_text_from_excel(file)
    else:
        print("❌ Unsupported file format!")
        return "❌ Unsupported file format!"

    if not text:
        print("⚠️ No text extracted from the document.")
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)
    
    # Run QA model asynchronously
    print("πŸ€– Running QA model...")
    loop = asyncio.get_event_loop()
    qa = get_qa_pipeline()
    response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")

    print(f"βœ… Model response: {response[0]['generated_text']}")
    return response[0]["generated_text"]

# βœ… Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## πŸ“„ AI-Powered Document & Image QA")

    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")
    
    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")
    
    submit_btn.click(answer_question, inputs=[file_input, question_input], outputs=answer_output)


# βœ… Mount Gradio with FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

@app.get("/")
def home():
    return RedirectResponse(url="/")
"""
import asyncio
import functools

import gradio as gr
import numpy as np
import fitz  # PyMuPDF
import torch
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from starlette.responses import RedirectResponse
from openpyxl import load_workbook
from docx import Document
from pptx import Presentation

# Initialize the FastAPI application that the Gradio UI is mounted onto below.
app = FastAPI()

# Inference is pinned to the CPU: CUDA availability is NOT checked. Both
# pipeline factories pass device=-1 (CPU) themselves, so this value is only
# reported at startup and does not control where models are placed.
device = "cpu"
print(f"βœ… Running on: {device}")

# Lazily build the QA pipeline on first use and memoize it, so the model
# weights are loaded once per process instead of on every request.
@functools.lru_cache(maxsize=1)
def get_qa_pipeline():
    """Return the shared TinyLlama text-generation pipeline, running on CPU.

    The pipeline is constructed on the first call and cached. Later calls
    return the same object, so the load message is printed only once.
    ``device=-1`` selects the CPU in transformers.
    """
    print("πŸ”„ Loading QA Model on CPU...")
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=-1)

@functools.lru_cache(maxsize=1)
def get_image_captioning_pipeline():
    """Return the shared ViT-GPT2 image-to-text pipeline, running on CPU.

    Built on the first call and cached, so the model is loaded only once per
    process. ``device=-1`` selects the CPU in transformers.
    NOTE(review): the current UI has no image input, so nothing calls this yet.
    """
    print("πŸ”„ Loading Image Captioning Model on CPU...")
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning", device=-1)

# Document types the QA flow can extract text from
# (lower-case extensions, without the leading dot).
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}

def validate_file_type(file):
    """Check that an uploaded file has a supported extension.

    Args:
        file: The Gradio upload. This is either an object exposing its path
            as ``.name`` or a plain path string, depending on the Gradio
            version and ``gr.File`` type. It is ``None`` when nothing was
            uploaded.

    Returns:
        ``None`` if the extension (case-insensitive) is in
        ``ALLOWED_EXTENSIONS``, otherwise a user-facing error message.
    """
    if file is None:
        # Previously this raised AttributeError on ``None.name``.
        print("❌ No file uploaded.")
        return "❌ No file uploaded."
    # Path strings have no ``.name`` attribute, so fall back to the value itself.
    path = getattr(file, "name", file)
    print(f"πŸ“‚ Validating file: {path}")
    ext = path.split(".")[-1].lower()
    return None if ext in ALLOWED_EXTENSIONS else f"❌ Unsupported file format: {ext}"

# Text extractors, one per supported document format.
def extract_text_from_pdf(file):
    """Return the text of every page of a PDF, joined with single spaces.

    Args:
        file: Upload object exposing the path as ``.name``, or a path string.
    """
    print("πŸ“„ Extracting text from PDF...")
    path = getattr(file, "name", file)
    # The context manager closes the document handle even if extraction fails.
    with fitz.open(path) as doc:
        return " ".join(page.get_text() for page in doc)

def extract_text_from_docx(file):
    """Return the text of a .docx document's ``paragraphs``, joined with spaces.

    NOTE(review): python-docx's ``Document.paragraphs`` presumably excludes
    text inside tables. Confirm this if table content matters.

    Args:
        file: Upload object exposing the path as ``.name``, or a path string.
    """
    print("πŸ“„ Extracting text from DOCX...")
    doc = Document(getattr(file, "name", file))
    return " ".join(p.text for p in doc.paragraphs)

def extract_text_from_pptx(file):
    """Return the text of every text-bearing shape on every slide, joined with spaces.

    Args:
        file: Upload object exposing the path as ``.name``, or a path string.
    """
    print("πŸ“„ Extracting text from PPTX...")
    ppt = Presentation(getattr(file, "name", file))
    # Only shapes that expose a ``text`` attribute contribute to the output.
    return " ".join(
        shape.text
        for slide in ppt.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    )

def extract_text_from_excel(file):
    """Return the cell values of every worksheet in an .xlsx workbook as one string.

    Within a row, non-blank cells are joined with spaces. Non-blank rows are
    then joined with spaces. ``data_only=True`` makes openpyxl return the
    stored value of formula cells rather than the formula text.

    Only blank cells (``None`` or ``""``) are skipped. Falsy values such as
    ``0`` and ``False`` are real data and are kept; the previous ``if cell``
    test silently dropped them.

    Args:
        file: Upload object exposing the path as ``.name``, or a path string.
    """
    print("πŸ“Š Extracting text from Excel...")
    wb = load_workbook(getattr(file, "name", file), data_only=True)
    row_texts = (
        " ".join(str(cell) for cell in row if cell is not None and cell != "")
        for sheet in wb.worksheets
        for row in sheet.iter_rows(values_only=True)
    )
    # Skip blank rows so they do not add runs of separator spaces.
    return " ".join(text for text in row_texts if text)

# Question answering over an uploaded document.
async def answer_question(file, question: str):
    """Answer ``question`` from the text of an uploaded document.

    Text extraction, pipeline construction and model inference are blocking
    calls. They run in the event loop's default executor (a thread pool)
    instead of directly in this coroutine. Otherwise a single request would
    stall the server's event loop for its whole duration.

    Args:
        file: The Gradio upload. This is an object exposing the temp-file
            path as ``.name`` or a plain path string, or ``None`` when
            nothing was uploaded.
        question: The user's question.

    Returns:
        The text generated by the QA model (without the echoed prompt), or a
        user-facing error/warning message.
    """
    print("πŸ” Processing file for QA...")

    # Guard clauses: return a message instead of raising on missing input.
    if file is None:
        return "❌ No file uploaded."
    if not question or not question.strip():
        return "⚠️ Please enter a question."

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    # Dispatch on the lower-cased extension. Path strings have no ``.name``.
    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_from_docx,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_excel,
    }
    file_ext = getattr(file, "name", file).split(".")[-1].lower()
    extractor = extractors.get(file_ext)
    if extractor is None:
        # Only reachable if ALLOWED_EXTENSIONS and this table drift apart.
        return f"❌ Unsupported file format: {file_ext}"

    loop = asyncio.get_running_loop()
    text = await loop.run_in_executor(None, extractor, file)

    if not text.strip():
        return "⚠️ No text extracted from the document."

    print("βœ‚οΈ Truncating text for faster processing...")
    truncated_text = text[:1024]  # Reduce to 1024 characters for better speed
    prompt = f"Question: {question}\nContext: {truncated_text}"

    def _generate():
        # get_qa_pipeline() may construct the model pipeline (slow), so it is
        # called in the worker thread too, not on the event loop.
        qa_pipeline = get_qa_pipeline()
        # return_full_text=False drops the prompt (question + context) from
        # the output, so the user sees only the generated answer.
        return qa_pipeline(prompt, return_full_text=False)

    response = await loop.run_in_executor(None, _generate)
    return response[0]["generated_text"]

# Gradio front end: a document upload and a question box, plus a button that
# sends both to answer_question and shows the result in the answer box.
# NOTE(review): the heading still says "& Image", but this layout has no image input.
with gr.Blocks() as demo:
    gr.Markdown("## πŸ“„ AI-Powered Document & Image QA")

    # Upload control and question box sit side by side.
    with gr.Row():
        file_input = gr.File(label="Upload Document")
        question_input = gr.Textbox(
            label="Ask a Question",
            placeholder="What is this document about?",
        )

    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")

    submit_btn.click(
        fn=answer_question,
        inputs=[file_input, question_input],
        outputs=answer_output,
    )

# Serve the Gradio UI from the root of the FastAPI app.
app = gr.mount_gradio_app(app, demo, path="/")


def home():
    """Return a redirect to the site root.

    Deliberately NOT registered as a ``GET /`` route. The Gradio app mounted
    at "/" above is registered first and matches every path, so such a route
    could never be reached. If it were reachable, redirecting "/" to "/"
    would loop forever. Kept only so the module-level name stays available.
    """
    return RedirectResponse(url="/")