File size: 4,257 Bytes
7e5ddc3
2be14bd
 
 
 
dbe3ba4
 
 
 
2be14bd
a5ffabc
 
0c9548a
7e5ddc3
2be14bd
 
7e5ddc3
 
d36238a
dbe3ba4
 
 
 
 
 
 
 
 
7e5ddc3
c724805
 
49b29c3
2be14bd
7e5ddc3
2be14bd
 
 
 
bb7eb3d
2be14bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb7eb3d
2be14bd
 
dbe3ba4
7e5ddc3
0c9548a
 
 
 
7e5ddc3
0c9548a
 
 
7e5ddc3
a5ffabc
 
2be14bd
 
a5ffabc
2be14bd
a5ffabc
2be14bd
a5ffabc
2be14bd
a5ffabc
2be14bd
a5ffabc
2be14bd
 
a5ffabc
7e5ddc3
 
 
 
 
 
2be14bd
7e5ddc3
 
 
 
dbe3ba4
7e5ddc3
 
c724805
f57a980
7e5ddc3
f57a980
a5ffabc
7e5ddc3
a5ffabc
 
 
 
 
 
 
7e5ddc3
 
 
 
 
 
 
 
 
a5ffabc
2be14bd
a5ffabc
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from fastapi import FastAPI, File, UploadFile
import pdfplumber
import docx
import openpyxl
from pptx import Presentation
import torch
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from PIL import Image
from transformers import pipeline
import gradio as gr
from fastapi.responses import RedirectResponse
import numpy as np
# Initialize FastAPI
app = FastAPI()

# Load AI Model for Question Answering
# flan-t5-large is driven as a text2text model: the prompt string embeds
# both the question and the (truncated) document/OCR context.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-large", tokenizer="google/flan-t5-large", use_fast=True)

# Load Pretrained Object Detection Model (Torchvision)
# NOTE(review): `model` and `transform` below are never used anywhere else
# in this file — the image path goes through easyocr OCR instead. They load
# heavyweight weights at import time for nothing; candidates for removal
# (confirm no external module imports them first).
model = fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

# Image Transformations
transform = transforms.Compose([
    transforms.ToTensor()
])

# Clip prompt context so it stays within the model's input limit.
def truncate_text(text, max_tokens=450):
    """Return at most the first *max_tokens* whitespace-separated words.

    "Tokens" is approximated by words here, not by the model tokenizer's
    actual token count.
    """
    return " ".join(text.split()[:max_tokens])

# Functions to extract text from different file formats
def extract_text_from_pdf(pdf_file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        pdf_file: Path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        Page texts joined by newlines, stripped; empty string when no page
        yields any text.
    """
    pages = []
    with pdfplumber.open(pdf_file) as pdf:
        for page in pdf.pages:
            # extract_text() returns None for pages with no text layer
            # (e.g. scanned images); the original `None + "\n"` raised
            # TypeError on such pages.
            page_text = page.extract_text()
            if page_text:
                pages.append(page_text)
    return "\n".join(pages).strip()

def extract_text_from_docx(docx_file):
    """Return all paragraph text of a .docx file, one paragraph per line."""
    document = docx.Document(docx_file)
    return "\n".join(paragraph.text for paragraph in document.paragraphs)

def extract_text_from_pptx(pptx_file):
    """Collect the text of every text-bearing shape on every slide."""
    presentation = Presentation(pptx_file)
    chunks = [
        shape.text
        for slide in presentation.slides
        for shape in slide.shapes
        # Only some shape types (e.g. text boxes) expose a .text attribute.
        if hasattr(shape, "text")
    ]
    return "\n".join(chunks)

def extract_text_from_excel(excel_file):
    """Flatten every worksheet of an .xlsx workbook into plain text.

    Each row becomes one line of space-separated cell values. Empty cells
    (None) are skipped — the original rendered them as the literal string
    "None", polluting the QA context.

    Args:
        excel_file: Path or file-like object accepted by openpyxl.

    Returns:
        All rows of all sheets joined by newlines.
    """
    wb = openpyxl.load_workbook(excel_file)
    try:
        lines = []
        for sheet in wb.worksheets:
            for row in sheet.iter_rows(values_only=True):
                lines.append(" ".join(str(cell) for cell in row if cell is not None))
        return "\n".join(lines)
    finally:
        # Release the underlying file handle; the original leaked it.
        wb.close()

# Function to extract text from an image via OCR (easyocr)
def extract_text_from_image(image_file):
    """Run OCR over an image and return the recognized text.

    Args:
        image_file: Either a NumPy array (as supplied by ``gr.Image``) or a
            path/file-like object openable by PIL.

    Returns:
        All recognized text fragments joined by single spaces; empty string
        when nothing is recognized.
    """
    # easyocr was referenced but never imported anywhere in this module,
    # so the original raised NameError on first call. Import lazily here:
    # it is heavyweight and only needed on the image-QA path.
    import easyocr

    if isinstance(image_file, np.ndarray):  # gr.Image hands us an ndarray
        image = Image.fromarray(image_file)
    else:
        image = Image.open(image_file).convert("RGB")

    # Constructing a Reader loads the OCR models from disk — far too
    # expensive to repeat per request, so cache it on the function object.
    reader = getattr(extract_text_from_image, "_reader", None)
    if reader is None:
        reader = easyocr.Reader(["en"])
        extract_text_from_image._reader = reader

    # readtext() yields (bbox, text, confidence) tuples; keep the text.
    results = reader.readtext(np.array(image))
    return " ".join(res[1] for res in results)
# Function to answer questions based on document content
def answer_question_from_document(file, question):
    """Answer *question* using text extracted from an uploaded document.

    Supported formats: pdf, docx, pptx, xlsx (chosen by file extension).
    Returns a human-readable error string for unsupported or empty files.
    """
    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_from_docx,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_excel,
    }

    extension = file.name.split(".")[-1].lower()
    extractor = extractors.get(extension)
    if extractor is None:
        return "Unsupported file format!"

    text = extractor(file)
    if not text:
        return "No text extracted from the document."

    prompt = f"Question: {question} Context: {truncate_text(text)}"
    answer = qa_pipeline(prompt)
    return answer[0]["generated_text"]

# Function to answer questions based on image content
def answer_question_from_image(image, question):
    """Answer *question* using OCR text recognized in an uploaded image."""
    recognized = extract_text_from_image(image)
    if not recognized:
        return "No meaningful content detected in the image."

    prompt = f"Question: {question} Context: {truncate_text(recognized)}"
    answer = qa_pipeline(prompt)
    return answer[0]["generated_text"]

# Gradio UI for Document & Image QA
# Two independent interfaces, later combined into one tabbed app below.
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[gr.File(label="Upload Document"), gr.Textbox(label="Ask a Question")],
    outputs="text",
    title="AI Document Question Answering"
)

# Image tab: gr.Image delivers the upload as a NumPy array, which
# extract_text_from_image handles explicitly.
img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[gr.Image(label="Upload Image"), gr.Textbox(label="Ask a Question")],
    outputs="text",
    title="AI Image Question Answering"
)

# Mount Gradio Interfaces
# The tabbed Gradio app is mounted at the root path and serves "/" itself.
demo = gr.TabbedInterface([doc_interface, img_interface], ["Document QA", "Image QA"])
app = gr.mount_gradio_app(app, demo, path="/")

# NOTE(review): the previous `@app.get("/")` handler returned
# RedirectResponse(url="/") — an infinite self-redirect loop — and it was
# registered after Gradio already claimed "/", so it was unreachable dead
# code either way. Removed.