# AI document & image question answering with data visualization.
# FastAPI app with a mounted Gradio UI (deployable as a Hugging Face Space).
# Standard library
import io

# Third-party
import docx
import easyocr
import gradio as gr
import matplotlib.pyplot as plt
import openpyxl
import pandas as pd
import pdfplumber
import seaborn as sns
from fastapi import FastAPI
from fastapi.responses import RedirectResponse
from pptx import Presentation
from transformers import pipeline
# --- Application and model setup ---
app = FastAPI()

# Hugging Face pipelines (model weights are downloaded on first run).
vqa_pipeline = pipeline("image-to-text", model="Salesforce/blip-vqa-base")
code_generator = pipeline("text-generation", model="openai-community/gpt2-medium")
table_analyzer = pipeline("table-question-answering", model="google/tapas-large-finetuned-wtq")

# Extractive QA model shared by the document and image question answering tabs.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
# --- Text-extraction helpers for document and image QA ---
def extract_text_from_pdf(pdf_file):
    """Extract plain text from every page of a PDF.

    Args:
        pdf_file: Path or file-like object accepted by ``pdfplumber.open``.

    Returns:
        str: Page texts joined by newlines, with surrounding whitespace
        stripped; "" when no text could be extracted.
    """
    pages = []
    with pdfplumber.open(pdf_file) as pdf:
        for page in pdf.pages:
            # extract_text() may return None (e.g. image-only pages); call it
            # once per page — the original called it twice per page.
            page_text = page.extract_text()
            if page_text:
                pages.append(page_text)
    return "\n".join(pages).strip()
def extract_text_from_docx(docx_file):
    """Return the text of a Word document, one paragraph per line."""
    document = docx.Document(docx_file)
    paragraphs = (paragraph.text for paragraph in document.paragraphs)
    return "\n".join(paragraphs)
def extract_text_from_pptx(pptx_file):
    """Return all text found in the shapes of a PowerPoint deck, newline-joined."""
    presentation = Presentation(pptx_file)
    # Not every shape carries text (pictures, connectors, ...), so filter first.
    fragments = [
        shape.text
        for slide in presentation.slides
        for shape in slide.shapes
        if hasattr(shape, "text")
    ]
    return "\n".join(fragments)
def extract_text_from_excel(excel_file):
    """Return the textual content of every worksheet, one row per line."""
    workbook = openpyxl.load_workbook(excel_file)
    lines = []
    for worksheet in workbook.worksheets:
        for row in worksheet.iter_rows(values_only=True):
            # Empty cells come back as None; render them as "".
            cells = ("" if value is None else str(value) for value in row)
            lines.append(" ".join(cells))
    return "\n".join(lines)
def extract_text_from_image(image_file):
    """Run OCR on an image and return the recognized text as one string.

    Building an ``easyocr.Reader`` loads detection/recognition models, which
    is slow — the original created a fresh reader on every call.  The reader
    is now created once on first use and cached on the function object.

    Args:
        image_file: Image path / array accepted by ``easyocr.Reader.readtext``.

    Returns:
        str: All recognized text fragments joined with spaces.
    """
    reader = getattr(extract_text_from_image, "_reader", None)
    if reader is None:
        reader = easyocr.Reader(["en"])
        extract_text_from_image._reader = reader
    # readtext yields (bbox, text, confidence) tuples; keep only the text.
    return " ".join(detection[1] for detection in reader.readtext(image_file))
def answer_question_from_document(file, question): | |
file_ext = file.name.split(".")[-1].lower() | |
if file_ext == "pdf": | |
text = extract_text_from_pdf(file) | |
elif file_ext == "docx": | |
text = extract_text_from_docx(file) | |
elif file_ext == "pptx": | |
text = extract_text_from_pptx(file) | |
elif file_ext == "xlsx": | |
text = extract_text_from_excel(file) | |
else: | |
return "Unsupported file format!" | |
if not text: | |
return "No text extracted from the document." | |
response = qa_pipeline({"question": question, "context": text}) | |
return response["answer"] | |
def answer_question_from_image(image, question):
    """Answer a question using text OCR'd from an uploaded image.

    Args:
        image: Image input forwarded to ``extract_text_from_image``.
        question: Natural-language question about the image's text.

    Returns:
        str: The extracted answer, or a message when no text is found.
    """
    context = extract_text_from_image(image)
    if not context:
        return "No text detected in the image."
    result = qa_pipeline({"question": question, "context": context})
    return result["answer"]
# --- Gradio UIs for document and image question answering ---
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[
        gr.File(label="Upload Document"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="AI Document Question Answering",
)

img_interface = gr.Interface(
    fn=answer_question_from_image,
    inputs=[
        gr.Image(label="Upload Image"),
        gr.Textbox(label="Ask a Question"),
    ],
    outputs="text",
    title="AI Image Question Answering",
)
# --- Data visualization ---
def generate_visualization(excel_file, viz_type, user_request):
    """Generate Matplotlib/Seaborn code for an uploaded Excel table and run it.

    Args:
        excel_file: Excel file (path or file-like) readable by pandas.
        viz_type: One of the UI choices, e.g. "Bar Chart".
        user_request: Free-text description of the desired plot.

    Returns:
        tuple: (generated_code, PNG image buffer) on success, or
        (code_or_error_message, error_description) on failure.
    """
    try:
        df = pd.read_excel(excel_file)
        # Fill missing values BEFORE casting to str: astype(str) turns NaN
        # into the literal string "nan", which made the original fillna a no-op.
        df = df.fillna("").astype(str)

        # TAPAS expects the table as a DataFrame (or dict of column lists),
        # not the list-of-records the original passed.  The answer is not yet
        # surfaced in the UI; the call doubles as a table sanity check.
        query = user_request.strip() if isinstance(user_request, str) else "What is the summary?"
        table_analyzer(table=df, query=query)

        prompt = (
            f"Given a dataset with columns {list(df.columns)}, generate Python code using Matplotlib and Seaborn "
            f"to create a {viz_type.lower()} based on: {user_request}. Only return valid Python code, no explanations."
        )
        code_response = code_generator(prompt, max_new_tokens=150, do_sample=True)
        if isinstance(code_response, list) and "generated_text" in code_response[0]:
            generated_code = code_response[0]["generated_text"]
        else:
            generated_code = "Error: Model did not return valid code."

        # Cheap sanity check that the model actually produced plotting code.
        if "plt" not in generated_code or "sns" not in generated_code:
            return generated_code, "Generated code seems incorrect."

        try:
            # SECURITY NOTE: exec() runs arbitrary model-generated code with
            # access to the loaded DataFrame.  Inherently unsafe on untrusted
            # input — consider sandboxing or an AST allow-list before exposing
            # this publicly.
            exec_globals = {"plt": plt, "sns": sns, "pd": pd, "df": df, "io": io}
            exec(generated_code, exec_globals)
            fig = plt.gcf()
            img_buf = io.BytesIO()
            fig.savefig(img_buf, format='png')
            img_buf.seek(0)
            plt.close(fig)
        except Exception as e:
            return generated_code, f"Error in executing visualization: {str(e)}"
        return generated_code, img_buf
    except Exception as e:
        return f"Error: {str(e)}", "Failed to analyze table."
# --- Gradio UI for data visualization ---
viz_interface = gr.Interface(
    fn=generate_visualization,
    inputs=[
        gr.File(label="Upload Excel File"),
        gr.Radio(
            ["Bar Chart", "Line Chart", "Scatter Plot", "Histogram"],
            label="Choose Visualization Type",
        ),
        gr.Textbox(label="Enter Visualization Request"),
    ],
    outputs=[
        gr.Code(label="Generated Python Code"),
        gr.Image(label="Visualization Output"),
    ],
    title="AI-Powered Data Visualization",
)
# --- Mount the combined Gradio UI on the FastAPI app ---
demo = gr.TabbedInterface(
    [doc_interface, img_interface, viz_interface],
    ["Document QA", "Image QA", "Data Visualization"],
)
app = gr.mount_gradio_app(app, demo, path="/")


@app.get("/home")
def home():
    """Redirect to the Gradio UI mounted at the root path.

    The original function had no route decorator, so it was never registered
    and was dead code.  It is exposed at /home (not "/", which Gradio owns) to
    avoid a redirect loop.
    """
    return RedirectResponse(url="/")