"""import gradio as gr
import numpy as np
import fitz  # PyMuPDF
import torch
import asyncio
from fastapi import FastAPI
from transformers import pipeline
from PIL import Image
from starlette.responses import RedirectResponse
from openpyxl import load_workbook
from docx import Document
from pptx import Presentation

# Initialize FastAPI
app = FastAPI()

# Use GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"βœ… Using device: {device}")

# Functions that build the pipelines on demand (each call creates a fresh pipeline)
def get_qa_pipeline():
    print("πŸ”„ Loading QA pipeline model...")
    # float16 inference is only supported on GPU; fall back to float32 on CPU
    dtype = torch.float16 if device == "cuda" else torch.float32
    return pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", device=device, torch_dtype=dtype)

def get_image_captioning_pipeline():
    print("πŸ”„ Loading Image Captioning model...")
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")

ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}
MAX_INPUT_LENGTH = 1024  # Limit input length for faster processing

# βœ… Validate File Type
def validate_file_type(file):
    if hasattr(file, "name"):
        ext = file.name.split(".")[-1].lower()
        print(f"πŸ“ File extension detected: {ext}")
        if ext not in ALLOWED_EXTENSIONS:
            print(f"❌ Unsupported file format: {ext}")
            return f"❌ Unsupported file format: {ext}"
        return None
    print("❌ Invalid file format!")
    return "❌ Invalid file format!"

# βœ… Extract Text from PDF
async def extract_text_from_pdf(file):
    print(f"πŸ“„ Extracting text from PDF: {file.name}")
    loop = asyncio.get_event_loop()
    def _read_pdf(path):
        # Use a context manager so the document handle is closed promptly
        with fitz.open(path) as doc:
            return "\n".join(page.get_text() for page in doc)
    text = await loop.run_in_executor(None, _read_pdf, file.name)
    print(f"βœ… Extracted {len(text)} characters from PDF")
    return text

# βœ… Extract Text from DOCX
async def extract_text_from_docx(file):
    print(f"πŸ“„ Extracting text from DOCX: {file.name}")
    loop = asyncio.get_event_loop()
    # Open by path (file.name), matching the other extractors
    text = await loop.run_in_executor(None, lambda: "\n".join(p.text for p in Document(file.name).paragraphs))
    print(f"βœ… Extracted {len(text)} characters from DOCX")
    return text

# βœ… Extract Text from PPTX
async def extract_text_from_pptx(file):
    print(f"πŸ“„ Extracting text from PPTX: {file.name}")
    loop = asyncio.get_event_loop()
    # Open by path (file.name); keep only shapes that actually carry text
    text = await loop.run_in_executor(None, lambda: "\n".join(
        shape.text for slide in Presentation(file.name).slides
        for shape in slide.shapes if hasattr(shape, "text")))
    print(f"βœ… Extracted {len(text)} characters from PPTX")
    return text

# βœ… Extract Text from Excel
async def extract_text_from_excel(file):
    print(f"πŸ“„ Extracting text from Excel: {file.name}")
    loop = asyncio.get_event_loop()
    def _read_excel(path):
        wb = load_workbook(path, data_only=True)
        # "is not None" keeps legitimate zero/False cell values
        return "\n".join(" ".join(str(cell) for cell in row if cell is not None)
                         for sheet in wb.worksheets
                         for row in sheet.iter_rows(values_only=True))
    text = await loop.run_in_executor(None, _read_excel, file.name)
    print(f"βœ… Extracted {len(text)} characters from Excel")
    return text

# βœ… Truncate Long Text
def truncate_text(text):
    print(f"βœ‚οΈ Truncating text to {MAX_INPUT_LENGTH} characters (if needed)...")
    return text[:MAX_INPUT_LENGTH] if len(text) > MAX_INPUT_LENGTH else text
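# e.g. truncate_text("a" * 2000) returns only the first 1024 characters,
# since MAX_INPUT_LENGTH is 1024.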

# βœ… Answer Questions from Image or Document
async def answer_question(file, question: str):
    print(f"❓ Question received: {question}")

    if isinstance(file, np.ndarray):  # Image Processing
        print("πŸ–ΌοΈ Processing image for captioning...")
        image = Image.fromarray(file)
        image_captioning = get_image_captioning_pipeline()
        caption = image_captioning(image)[0]['generated_text']
        print(f"πŸ“ Generated caption: {caption}")

        qa = get_qa_pipeline()
        print("πŸ€– Running QA model...")
        response = qa(f"Question: {question}\nContext: {caption}")
        print(f"βœ… Model response: {response[0]['generated_text']}")
        return response[0]["generated_text"]

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error
    
    file_ext = file.name.split(".")[-1].lower()

    # Extract text asynchronously
    if file_ext == "pdf":
        text = await extract_text_from_pdf(file)
    elif file_ext == "docx":
        text = await extract_text_from_docx(file)
    elif file_ext == "pptx":
        text = await extract_text_from_pptx(file)
    elif file_ext == "xlsx":
        text = await extract_text_from_excel(file)
    else:
        print("❌ Unsupported file format!")
        return "❌ Unsupported file format!"

    if not text:
        print("⚠️ No text extracted from the document.")
        return "⚠️ No text extracted from the document."

    truncated_text = truncate_text(text)
    
    # Run QA model asynchronously
    print("πŸ€– Running QA model...")
    loop = asyncio.get_event_loop()
    qa = get_qa_pipeline()
    response = await loop.run_in_executor(None, qa, f"Question: {question}\nContext: {truncated_text}")

    print(f"βœ… Model response: {response[0]['generated_text']}")
    return response[0]["generated_text"]

# βœ… Gradio Interface (Separate File & Image Inputs)
with gr.Blocks() as demo:
    gr.Markdown("## πŸ“„ AI-Powered Document & Image QA")

    with gr.Row():
        file_input = gr.File(label="Upload Document")
        image_input = gr.Image(label="Upload Image")

    question_input = gr.Textbox(label="Ask a Question", placeholder="What is this document about?")
    answer_output = gr.Textbox(label="Answer")
    submit_btn = gr.Button("Get Answer")

    async def route_question(file, image, question):
        # Prefer the uploaded image when present; otherwise fall back to the document
        return await answer_question(image if image is not None else file, question)

    submit_btn.click(route_question, inputs=[file_input, image_input, question_input], outputs=answer_output)


# βœ… Mount Gradio with FastAPI (mounting at "/" would shadow other routes and
# make the redirect below loop, so mount under /gradio and redirect the root)
app = gr.mount_gradio_app(app, demo, path="/gradio")

@app.get("/")
def home():
    return RedirectResponse(url="/gradio")
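
# To serve the combined FastAPI + Gradio app (assuming this file is saved as
# app.py), one standard invocation would be:
#
#     uvicorn app:app --host 0.0.0.0 --port 7860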
"""
# CUDA diagnostics: currently the only code in this file that executes
import torch
print("CUDA Available:", torch.cuda.is_available())
print("Torch Device Count:", torch.cuda.device_count())
print("Current Device:", torch.cuda.current_device() if torch.cuda.is_available() else "CPU")
print("CUDA Device Name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")