# Hugging Face Space app (Space status when captured: Running)
import os

import fitz  # PyMuPDF for PDF parsing
import gradio as gr
import numpy as np
import openpyxl
import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import RedirectResponse
from PIL import Image
from pptx import Presentation
from tika import parser  # Apache Tika for document parsing
from torchvision import transforms
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from transformers import pipeline
# Initialize FastAPI
print("π FastAPI server is starting...")
app = FastAPI()

# Load the text-generation model used for document question answering:
# TinyLlama-1.1B-Chat through a Hugging Face `pipeline`; device=-1 keeps it
# on the CPU.
# NOTE(review): AutoModelForCausalLM / AutoTokenizer are not referenced in the
# visible code; kept in case another module relies on them.
from transformers import AutoModelForCausalLM, AutoTokenizer

print("π Loading models")
qa_pipeline = pipeline(
    "text-generation",
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=-1,
)
# Object detection model: torchvision Faster R-CNN (ResNet-50 FPN) built with
# the library's DEFAULT weights and switched to evaluation mode.
# NOTE(review): `model` and `transform` are not referenced anywhere in the
# visible code, so loading them only costs start-up time and memory — confirm
# nothing else imports them before removing.
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = fasterrcnn_resnet50_fpn(weights=weights).eval()  # eval() returns the module itself

# Image preprocessing: a single ToTensor conversion.
transform = transforms.Compose([transforms.ToTensor()])
# Allowed File Extensions
ALLOWED_EXTENSIONS = {"pdf", "docx", "pptx", "xlsx"}


def validate_file_type(file):
    """Check that an uploaded file has one of the ALLOWED_EXTENSIONS.

    Args:
        file: The upload. May be a path (``str`` / ``os.PathLike``) or an
            object exposing its path as ``.name`` (e.g. a temp-file wrapper;
            which one arrives depends on the Gradio version / ``gr.File``
            type — confirm for the deployed version). ``None`` means nothing
            was uploaded.

    Returns:
        ``None`` if the extension is allowed, otherwise a user-facing error
        message string.
    """
    if file is None:
        return "No file uploaded."
    path = getattr(file, "name", file)
    # Only look at the final path component so a dot in a directory name
    # (e.g. "/tmp/v1.2/report") is not taken as the extension separator.
    ext = os.path.basename(os.fspath(path)).rpartition(".")[2].lower()
    print(f"π Validating file type: {ext}")
    if ext not in ALLOWED_EXTENSIONS:
        return f"β Unsupported file format: {ext}"
    return None
# Function to truncate text to a maximum number of whitespace-separated words
def truncate_text(text, max_tokens=450):
    """Keep only the first ``max_tokens`` whitespace-separated words of ``text``.

    The "tokens" counted here are the words produced by ``str.split()``, not
    model-tokenizer tokens, so the model may receive more tokens than
    ``max_tokens``. Runs of whitespace (including newlines) are collapsed to
    single spaces in the result.

    Args:
        text: The text to shorten.
        max_tokens: Maximum number of words to keep; must be non-negative.

    Returns:
        The first ``max_tokens`` words joined by single spaces.

    Raises:
        ValueError: If ``max_tokens`` is negative (a negative slice would
            silently drop words from the end instead of truncating).
    """
    if max_tokens < 0:
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")
    words = text.split()
    truncated = " ".join(words[:max_tokens])
    # Only report truncation when words were actually dropped.
    if len(words) > max_tokens:
        print(f"βοΈ Truncated text to {max_tokens} tokens.")
    return truncated
# Document Text Extraction Functions
def extract_text_from_pdf(pdf_file):
    """Extract the plain text of every page of a PDF.

    Args:
        pdf_file: Any document source accepted by ``fitz.open`` (e.g. a file
            path).

    Returns:
        The page texts joined with newlines; a "No text found" marker when
        the PDF yields no non-whitespace text (for example image-only
        pages); or an error message string if the file cannot be opened or
        parsed.
    """
    try:
        print("π Extracting text from PDF...")
        # The context manager closes the document (and its file handle) even
        # if text extraction fails part-way through.
        with fitz.open(pdf_file) as doc:
            text = "\n".join(page.get_text("text") for page in doc)
        print("β PDF text extraction completed.")
        # Pages with no text still contribute "\n" separators, so test the
        # stripped text rather than the raw join.
        return text if text.strip() else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading PDF: {str(e)}"
def extract_text_with_tika(file):
    """Extract plain text from a document using Apache Tika.

    Args:
        file: A path to an existing file (``str`` / ``os.PathLike``), an
            object exposing such a path as ``.name``, or the raw document
            contents (sent to Tika as a buffer).

    Returns:
        The extracted text with surrounding whitespace stripped; an empty
        string when Tika finds no text; or an error message string if
        parsing fails.
    """
    try:
        print("π Extracting text with Tika...")
        # `parser.from_buffer` treats its argument as the document *contents*:
        # given a path string it would parse the path text itself. Existing
        # files therefore go through `parser.from_file`.
        source = getattr(file, "name", file)
        if isinstance(source, (str, os.PathLike)) and os.path.isfile(source):
            parsed = parser.from_file(os.fspath(source))
        else:
            parsed = parser.from_buffer(file)
        print("β Tika text extraction completed.")
        # Tika can report "content": None for documents without text, which
        # `.get(key, default)` does not cover.
        return (parsed.get("content") or "").strip()
    except Exception as e:
        return f"β Error reading document: {str(e)}"
def extract_text_from_pptx(pptx_file):
    """Extract the text of every slide in a PowerPoint (.pptx) presentation.

    Text is collected in slide order from text-bearing shapes, from table
    cells (one line per table row), and from shapes nested inside group
    shapes. Shapes whose text is empty or whitespace-only are skipped.

    Args:
        pptx_file: A path or binary file-like object accepted by
            ``pptx.Presentation``.

    Returns:
        The collected text joined with newlines; a "No text found" marker
        when the presentation contains no text; or an error message string
        if the file cannot be read.
    """

    def _iter_shape_texts(shapes):
        # Tables sit in graphic frames that have no `.text` of their own, and
        # group shapes only expose their children via `.shapes`; a plain
        # `hasattr(shape, "text")` check would miss both.
        for shape in shapes:
            if getattr(shape, "has_table", False):
                for row in shape.table.rows:
                    yield " ".join(cell.text for cell in row.cells)
            elif hasattr(shape, "shapes"):
                yield from _iter_shape_texts(shape.shapes)
            elif hasattr(shape, "text"):
                yield shape.text

    try:
        print("π Extracting text from PPTX...")
        ppt = Presentation(pptx_file)
        text = [
            chunk
            for slide in ppt.slides
            for chunk in _iter_shape_texts(slide.shapes)
            if chunk.strip()
        ]
        print("β PPTX text extraction completed.")
        return "\n".join(text) if text else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading PPTX: {str(e)}"
def extract_text_from_excel(excel_file):
    """Extract the cell values of every worksheet in an .xlsx workbook.

    Each row with at least one non-empty cell becomes one output line: the
    row's non-empty values converted with ``str`` and separated by single
    spaces.

    Args:
        excel_file: A path or binary file-like object accepted by
            ``openpyxl.load_workbook``.

    Returns:
        The rows joined with newlines; a "No text found" marker when the
        workbook has no non-empty cells; or an error message string if the
        file cannot be read.
    """
    wb = None
    try:
        print("π Extracting text from Excel...")
        wb = openpyxl.load_workbook(excel_file, read_only=True)
        text = []
        for sheet in wb.worksheets:
            for row in sheet.iter_rows(values_only=True):
                # Empty cells come back as None; drop them instead of writing
                # the literal word "None" into the model context.
                cells = [str(value) for value in row if value is not None]
                if cells:
                    text.append(" ".join(cells))
        print("β Excel text extraction completed.")
        return "\n".join(text) if text else "β οΈ No text found."
    except Exception as e:
        return f"β Error reading Excel: {str(e)}"
    finally:
        # A read-only workbook keeps its source file open until closed.
        if wb is not None:
            wb.close()
def answer_question_from_document(file, question):
    """Answer a question using the text of an uploaded document as context.

    The document text is extracted with the extractor for its extension,
    shortened with ``truncate_text`` and passed with the question to
    ``qa_pipeline``.

    Args:
        file: The uploaded document from the ``gr.File`` input — a path or an
            object exposing its path as ``.name``. PDF, DOCX, PPTX and XLSX
            are supported.
        question: The user's question.

    Returns:
        The generated answer, or a user-facing status/error message.
    """
    print("π Processing document for QA...")
    if file is None:
        return "Please upload a document."
    if not question or not question.strip():
        return "Please enter a question."

    validation_error = validate_file_type(file)
    if validation_error:
        return validation_error

    # The upload may be a plain path or a wrapper exposing its path as `.name`.
    path = getattr(file, "name", file)
    file_ext = os.path.basename(os.fspath(path)).rpartition(".")[2].lower()

    # PPTX uses the python-pptx extractor (just as PDF and XLSX use their
    # native libraries), so only DOCX depends on a Tika server.
    extractors = {
        "pdf": extract_text_from_pdf,
        "docx": extract_text_with_tika,
        "pptx": extract_text_from_pptx,
        "xlsx": extract_text_from_excel,
    }
    extractor = extractors.get(file_ext)
    if extractor is None:
        return "β Unsupported file format!"
    text = (extractor(file) or "").strip()

    # The extract_text_* helpers report problems in-band as short status
    # strings ("... No text found.", "... Error reading <kind>: ...") instead
    # of raising; never hand those to the model as if they were document text.
    if not text or (text.endswith("No text found.") and len(text) <= 30):
        return "β οΈ No text extracted from the document."
    if "Error reading " in text[:20]:
        return text

    truncated_text = truncate_text(text)
    print("π€ Generating response...")
    # Ending with "Answer:" cues the model to answer rather than to continue
    # the document text.
    prompt = (
        "Answer the question using only the context below.\n\n"
        f"Context: {truncated_text}\n\n"
        f"Question: {question.strip()}\n"
        "Answer:"
    )
    # return_full_text=False: otherwise generated_text starts with the prompt,
    # echoing the whole document excerpt back to the user. max_new_tokens
    # bounds generation time (the pipeline was created with device=-1, CPU).
    response = qa_pipeline(prompt, max_new_tokens=256, return_full_text=False)
    print("β AI response generated.")
    return response[0]["generated_text"].strip()
print("β Models loaded successfully.")

# Gradio UI: a single "Document QA" tab mounted at the root of the FastAPI app.
doc_interface = gr.Interface(
    fn=answer_question_from_document,
    inputs=[
        # Limit the file picker to the extensions validate_file_type accepts.
        gr.File(
            label="Document (PDF, DOCX, PPTX or XLSX)",
            file_types=[f".{ext}" for ext in sorted(ALLOWED_EXTENSIONS)],
        ),
        gr.Textbox(label="Question"),
    ],
    outputs=gr.Textbox(label="Answer"),
)
demo = gr.TabbedInterface([doc_interface], ["Document QA"])
app = gr.mount_gradio_app(app, demo, path="/")
def home():
    """Build a redirect response pointing at the site root ("/").

    NOTE(review): no route decorator is attached, so FastAPI never calls this
    function. Registering it on "/" as written would redirect "/" to itself;
    the Gradio UI is already mounted at "/".
    """
    target_url = "/"
    return RedirectResponse(url=target_url)
# NOTE(review): The string literal below is an earlier, disabled version of
# this app: an Excel-to-chart generator built on pandas/matplotlib/seaborn.
# It is a bare module-level string, so none of it runs. It cannot be revived
# just by deleting the outer quotes, for two reasons:
#   * the function docstring and the f-string prompt inside it have lost their
#     own triple quotes;
#   * the indentation (and each line's trailing " | |", apparently residue
#     from copying a web page) would also need repairing.
# Issues to address before reviving it:
#   * Model-generated code is passed straight to exec(); ast.parse() only
#     checks syntax, so this would run arbitrary code on the server.
#   * The "/" route returns RedirectResponse(url="/"), a redirect to itself.
#   * The log mentions TAPAS, but the model loaded is deepset/tinyroberta-squad2
#     via the "question-answering" pipeline, called with {"table", "query"}.
#     Presumably that pipeline expects question/context inputs instead; verify.
#   * `table_answer` is computed but never used.
"""import gradio as gr | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
from fastapi import FastAPI | |
from transformers import pipeline | |
from fastapi.responses import RedirectResponse | |
import io | |
import ast | |
from PIL import Image | |
import re | |
# β Load AI models | |
print("π Initializing application...") | |
table_analyzer = pipeline("question-answering", model="deepset/tinyroberta-squad2", device=-1) | |
code_generator = pipeline("text-generation", model="distilgpt2", device=-1) | |
print("β AI models loaded successfully!") | |
# β Initialize FastAPI | |
app = FastAPI() | |
def generate_visualization(excel_file, viz_type, user_request): | |
Generates Python visualization code and insights based on user requests and Excel data. | |
try: | |
print("π Loading Excel file...") | |
df = pd.read_excel(excel_file) | |
print("β File loaded successfully! Columns:", df.columns) | |
# Convert date columns | |
for col in df.select_dtypes(include=["object", "datetime64"]): | |
try: | |
df[col] = pd.to_datetime(df[col], errors='coerce').dt.strftime('%Y-%m-%d %H:%M:%S') | |
except Exception: | |
pass | |
df = df.fillna(0) # Fill NaN values | |
formatted_table = [{col: str(value) for col, value in row.items()} for row in df.to_dict(orient="records")] | |
print(f"π Formatted table: {formatted_table[:5]}") | |
print(f"π User request: {user_request}") | |
if not isinstance(user_request, str): | |
raise ValueError("User request must be a string") | |
print("π§ Sending data to TAPAS model for analysis...") | |
table_answer = table_analyzer({"table": formatted_table, "query": user_request}) | |
print("β Table analysis completed!") | |
# β AI-generated code | |
prompt = f Generate clean and executable Python code to visualize the following dataset: | |
Columns: {list(df.columns)} | |
Visualization type: {viz_type} | |
User request: {user_request} | |
Use the provided DataFrame 'df' without reloading it. | |
Ensure 'plt.show()' is at the end. | |
print("π€ Sending request to AI code generator...") | |
generated_code = code_generator(prompt, max_length=200)[0]['generated_text'] | |
print("π AI-generated code:") | |
print(generated_code) | |
# β Validate generated code | |
valid_syntax = re.match(r".*plt\.show\(\).*", generated_code, re.DOTALL) | |
if not valid_syntax: | |
print("β οΈ AI code generation failed! Using fallback visualization...") | |
return generated_code, "Error: The AI did not generate a valid Matplotlib script." | |
try: | |
ast.parse(generated_code) # Syntax validation | |
except SyntaxError as e: | |
return generated_code, f"Syntax error: {e}" | |
# β Execute AI-generated code | |
try: | |
print("β‘ Executing AI-generated code...") | |
exec_globals = {"plt": plt, "sns": sns, "pd": pd, "df": df.copy(), "io": io} | |
exec(generated_code, exec_globals) | |
fig = plt.gcf() | |
img_buf = io.BytesIO() | |
fig.savefig(img_buf, format='png') | |
img_buf.seek(0) | |
plt.close(fig) | |
except Exception as e: | |
print(f"β Error executing AI-generated code: {str(e)}") | |
return generated_code, f"Error executing visualization: {str(e)}" | |
img = Image.open(img_buf) | |
return generated_code, img | |
except Exception as e: | |
print(f"β An error occurred: {str(e)}") | |
return f"Error: {str(e)}", "Table analysis failed." | |
# β Gradio UI setup | |
print("π οΈ Setting up Gradio interface...") | |
gradio_ui = gr.Interface( | |
fn=generate_visualization, | |
inputs=[ | |
gr.File(label="Upload Excel File"), | |
gr.Radio([ | |
"Bar Chart", "Line Chart", "Scatter Plot", "Histogram", | |
"Boxplot", "Heatmap", "Pie Chart", "Area Chart", "Bubble Chart", "Violin Plot" | |
], label="Select Visualization Type"), | |
gr.Textbox(label="Enter visualization request (e.g., 'Sales trend over time')") | |
], | |
outputs=[ | |
gr.Code(label="Generated Python Code"), | |
gr.Image(label="Visualization Result") | |
], | |
title="AI-Powered Data Visualization π", | |
description="Upload an Excel file, choose your visualization type, and ask a question about your data!" | |
) | |
print("β Gradio interface configured successfully!") | |
# β Mount Gradio app | |
print("π Mounting Gradio interface on FastAPI...") | |
app = gr.mount_gradio_app(app, gradio_ui, path="/") | |
print("β Gradio interface mounted successfully!") | |
@app.get("/") | |
def home(): | |
print("π Redirecting to UI...") | |
return RedirectResponse(url="/")""" | |