# app/gradio_interface.py
import os
import gradio as gr
import time
import threading
import tempfile
import shutil
from typing import Dict, List, Optional, Tuple, Union, Any
import json
import markdown
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import io
import base64
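
# Note: this interface assumes the orchestrator object passed to it exposes at least
# create_session(), coordinate_workflow_with_synchronization(session, topic,
# text_file_paths, image_file_paths) and get_sustainability_metrics(). These names
# are taken from the calls made below, not from a formal interface definition.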


class GradioInterface:
    def __init__(self, orchestrator):
        """Initialize the Gradio interface with the orchestrator."""
        self.orchestrator = orchestrator
        self.active_sessions = {}
        self.processing_threads = {}

        # Create temporary directory for file uploads
        self.temp_dir = tempfile.mkdtemp()
        self.text_dir = os.path.join(self.temp_dir, "texts")
        self.image_dir = os.path.join(self.temp_dir, "images")
        os.makedirs(self.text_dir, exist_ok=True)
        os.makedirs(self.image_dir, exist_ok=True)

    def create_interface(self):
        """Create and return the Gradio interface."""
        with gr.Blocks(title="Deep Dive Analysis with Sustainable AI",
                       theme=gr.themes.Soft(primary_hue="teal")) as interface:
            # Session management
            session_id = gr.State("")
            processing_status = gr.State("idle")
            result_data = gr.State(None)
gr.Markdown("# ๐ฟ Deep Dive Analysis with Sustainable AI") | |
gr.Markdown("Upload text files and images to analyze a topic in depth, with optimized AI processing.") | |

            with gr.Row():
                with gr.Column(scale=2):
                    # Input section
                    with gr.Blocks():
gr.Markdown("## ๐ Input") | |
topic_input = gr.Textbox(label="Topic for Deep Dive", placeholder="Enter a topic to analyze...") | |
with gr.Row(): | |
text_files = gr.File(label="Upload Text Files", file_count="multiple", file_types=[".txt", ".md", ".pdf", ".docx"]) | |
image_files = gr.File(label="Upload Images", file_count="multiple", file_types=["image"]) | |
analyze_btn = gr.Button("Start Analysis", variant="primary") | |
status_msg = gr.Markdown("Ready to analyze.") | |

                with gr.Column(scale=1):
                    # Sustainability metrics
                    with gr.Blocks():
                        gr.Markdown("## 📊 Sustainability Metrics")
                        metrics_display = gr.Markdown("No metrics available yet.")
                        metrics_chart = gr.Plot(label="Energy Usage")
                        update_metrics_btn = gr.Button("Update Metrics")

            # Results section
            with gr.Blocks():
                gr.Markdown("## 📋 Analysis Results")
                with gr.Tabs() as tabs:
                    with gr.TabItem("Executive Summary"):
                        exec_summary = gr.Markdown("No results available yet.")
                        confidence_indicator = gr.Markdown("")
                    with gr.TabItem("Detailed Report"):
                        detailed_report = gr.Markdown("No detailed report available yet.")
                    with gr.TabItem("Text Analysis"):
                        text_analysis = gr.Markdown("No text analysis available yet.")
                    with gr.TabItem("Image Analysis"):
                        with gr.Row():
                            image_gallery = gr.Gallery(label="Analyzed Images")
                            image_analysis = gr.Markdown("No image analysis available yet.")
                    with gr.TabItem("Raw Data"):
                        raw_json = gr.JSON(None)

            # Define event handlers
            def initialize_session():
                """Initialize a new session."""
                new_session = self.orchestrator.create_session()
                return new_session, "idle", None

            def on_ui_load():
                new_session, status, result_data = initialize_session()
                return new_session, status, result_data

            def process_files(session, topic, text_files, image_files, status):
                """Process uploaded files and start analysis."""
                if not topic:
                    return session, "error", "Please enter a topic for analysis.", None
                if not text_files and not image_files:
                    return session, "error", "Please upload at least one text file or image.", None

                # Save uploaded files to temp directories
                text_file_paths = []
                if text_files:
                    for file in text_files:
                        dest_path = os.path.join(self.text_dir, os.path.basename(file.name))
                        shutil.copy(file.name, dest_path)
                        text_file_paths.append(dest_path)

                image_file_paths = []
                if image_files:
                    for file in image_files:
                        dest_path = os.path.join(self.image_dir, os.path.basename(file.name))
                        shutil.copy(file.name, dest_path)
                        image_file_paths.append(dest_path)

                # Start processing in a separate thread to avoid blocking the UI
                def process_thread():
                    try:
                        print("Starting workflow processing thread")
                        # Use synchronized workflow for better control
                        result = self.orchestrator.coordinate_workflow_with_synchronization(
                            session, topic, text_file_paths, image_file_paths)
                        print(f"Workflow completed with result status: {result.get('status', 'unknown')}")
                        print(f"Result keys: {result.keys() if isinstance(result, dict) else 'Not a dict'}")
                        # Store result for UI access
                        self.active_sessions[session] = result
                        print(f"Updated session {session} with result")
                    except Exception as e:
                        print(f"ERROR in processing thread: {str(e)}")
                        self.active_sessions[session] = {"error": str(e), "status": "error"}

                # Start processing thread
                thread = threading.Thread(target=process_thread)
                thread.daemon = True
                thread.start()
                self.processing_threads[session] = thread

                return session, "processing", "Analysis in progress... This may take a few minutes.", None

            def check_status(session, status):
                """Check the status of the current processing job."""
                if session and session in self.active_sessions:
                    result = self.active_sessions[session]
                    if isinstance(result, dict):
                        if "error" in result:
                            return "error", f"Error: {result['error']}", result
                        elif result.get("status") == "completed":
                            print(f"Result data: {list(result.keys())}")
                            if "report" in result:
                                print(f"Report data: {list(result['report'].keys()) if isinstance(result['report'], dict) else 'Not a dict'}")
                            return "completed", "Analysis completed successfully!", result
                if status == "processing":
                    return status, "Analysis in progress... This may take a few minutes.", None
                return status, "Ready to analyze.", None

            def update_results(result_data):
                """Update the UI with results."""
                if not result_data:
                    return ("No results available yet.",
                            "",
                            "No detailed report available yet.",
                            "No text analysis available yet.",
                            [],
                            "No image analysis available yet.",
                            None)

                # Extract results
                exec_summary_text = "No executive summary available."
                confidence_text = ""
                detailed_report_text = "No detailed report available."
                text_analysis_text = "No text analysis available."
                image_list = []
                image_analysis_text = "No image analysis available."

                # Process report data
                if "report" in result_data:
                    report = result_data["report"]
                    # Executive summary
                    if "executive_summary" in report:
                        exec_summary_text = report["executive_summary"]
                    # Confidence statement
                    if "confidence_statement" in report:
                        confidence_level = report.get("confidence_level", "unknown")
                        confidence_text = f"**Confidence Level: {confidence_level.title()}**\n\n"
                        confidence_text += report["confidence_statement"]
                    # Detailed report
                    if "detailed_report" in report:
                        detailed_report_text = report["detailed_report"]

                # Process text analysis
                if "results" in result_data and "text_analysis" in result_data["results"]:
                    text_data = result_data["results"]["text_analysis"]
                    if "document_analyses" in text_data:
                        text_analysis_text = "### Text Analysis Results\n\n"
                        text_analysis_text += f"Found {text_data.get('relevant_documents', 0)} relevant documents out of {text_data.get('total_documents', 0)}.\n\n"
                        for i, doc in enumerate(text_data["document_analyses"]):
                            text_analysis_text += f"#### Document {i+1}: {doc.get('filename', 'Unknown')}\n\n"
                            text_analysis_text += f"Relevance: {doc.get('relevance_score', 0):.2f}\n\n"
                            text_analysis_text += f"{doc.get('summary', 'No summary available.')}\n\n"

                # Process image analysis
                if "results" in result_data and "image_analysis" in result_data["results"]:
                    img_data = result_data["results"]["image_analysis"]
                    if "image_analyses" in img_data:
                        image_analysis_text = "### Image Analysis Results\n\n"
                        image_analysis_text += f"Found {img_data.get('relevant_images', 0)} relevant images out of {img_data.get('total_images', 0)}.\n\n"

                        # Get processed images for gallery
                        if "processed_images" in img_data:
                            for img_info in img_data["processed_images"]:
                                if img_info.get("is_relevant", False):
                                    try:
                                        img_path = img_info.get("filepath", "")
                                        if os.path.exists(img_path):
                                            # Add to gallery
                                            image_list.append((img_path, img_info.get("caption", "No caption")))
                                    except Exception as e:
                                        print(f"Error loading image: {e}")

                        # Format analysis text
                        for i, img in enumerate(img_data["image_analyses"]):
                            image_analysis_text += f"#### Image {i+1}: {img.get('filename', 'Unknown')}\n\n"
                            image_analysis_text += f"Caption: {img.get('caption', 'No caption available.')}\n\n"
                            image_analysis_text += f"Relevance: {img.get('relevance_score', 0):.2f}\n\n"
                            image_analysis_text += f"Model used: {img.get('model_used', 'unknown')}\n\n"

                return (exec_summary_text,
                        confidence_text,
                        detailed_report_text,
                        text_analysis_text,
                        image_list,
                        image_analysis_text,
                        result_data)
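
            # get_sustainability_metrics() is treated here as returning a plain dict; the
            # keys read below ("energy_usage" with "total"/"by_model", "carbon_footprint_kg",
            # "optimization_gains", "environmental_equivalents") are inferred from this
            # handler, not from a documented schema.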

            def update_metrics():
                """Update sustainability metrics display."""
                metrics = self.orchestrator.get_sustainability_metrics()
                if "error" in metrics:
                    return "No metrics available: " + metrics["error"], None

                # Format metrics for display
                metrics_text = "### Sustainability Metrics\n\n"

                # Energy usage
                energy_usage = metrics.get("energy_usage", {}).get("total", 0)
                metrics_text += f"**Total Energy Usage**: {energy_usage:.6f} Wh\n\n"

                # Carbon footprint
                carbon = metrics.get("carbon_footprint_kg", 0)
                metrics_text += f"**Carbon Footprint**: {carbon:.6f} kg CO₂\n\n"

                # Optimization gains
                opt_gains = metrics.get("optimization_gains", {})
                tokens_saved = opt_gains.get("tokens_saved", 0)
                tokens_saved_pct = opt_gains.get("tokens_saved_pct", 0)
                energy_saved = opt_gains.get("total_energy_saved", 0)
                metrics_text += f"**Tokens Saved**: {tokens_saved} ({tokens_saved_pct:.1f}%)\n\n"
                metrics_text += f"**Energy Saved**: {energy_saved:.6f} Wh\n\n"

                # Environmental equivalents
                env_equiv = metrics.get("environmental_equivalents", {})
                if env_equiv:
                    metrics_text += "### Environmental Impact\n\n"
                    for impact, value in env_equiv.items():
                        name = impact.replace("_", " ").title()
                        metrics_text += f"**{name}**: {value:.2f}\n\n"

                # Create chart
                fig, ax = plt.subplots(figsize=(6, 4))

                # Energy by model
                energy_by_model = metrics.get("energy_usage", {}).get("by_model", {})
                if energy_by_model:
                    models = list(energy_by_model.keys())
                    values = list(energy_by_model.values())
                    # Shorten model names for display
                    short_names = [m.split("/")[-1] if "/" in m else m for m in models]
                    ax.bar(short_names, values)
                    ax.set_ylabel("Energy (Wh)")
                    ax.set_title("Energy Usage by Model")
                    plt.xticks(rotation=45, ha="right")
                    plt.tight_layout()

                return metrics_text, fig

            # Connect event handlers
            # session_id = gr.on_load(initialize_session)[0]
            analyze_btn.click(
                process_files,
                inputs=[session_id, topic_input, text_files, image_files, processing_status],
                outputs=[session_id, processing_status, status_msg, result_data]
            )

            # Periodic status check
            # gr.on(
            #     "change",
            #     lambda s, st: check_status(s, st),
            #     inputs=[session_id, processing_status],
            #     outputs=[processing_status, status_msg, result_data],
            #     every=2  # Check every 2 seconds
            # )
            # Hidden button that the injected JS clicks on a timer; it needs an explicit
            # elem_id, since Gradio's auto-generated ids do not start with "refresh".
            refresh_btn = gr.Button("Refresh Status", visible=False, elem_id="refresh-status-btn")
            refresh_btn.click(
                fn=lambda s, st: check_status(s, st),
                inputs=[session_id, processing_status],
                outputs=[processing_status, status_msg, result_data],
            )

            interface.load(
                fn=on_ui_load,
                outputs=[session_id, processing_status, result_data],
                js="""
                function() {
                    setInterval(function() {
                        // elem_id may be set on the button itself or on its wrapping
                        // container depending on the Gradio version, so handle both.
                        const el = document.getElementById("refresh-status-btn");
                        if (!el) return;
                        const btn = el.tagName === "BUTTON" ? el : el.querySelector("button");
                        if (btn) btn.click();
                    }, 2000);  // 2000 milliseconds = 2 seconds
                    return [];  // Return empty array since JS doesn't handle the Python outputs
                }
                """
            )

            # Update results when result_data changes
            result_data.change(
                update_results,
                inputs=[result_data],
                outputs=[exec_summary, confidence_indicator, detailed_report, text_analysis,
                         image_gallery, image_analysis, raw_json]
            )

            # Update metrics
            update_metrics_btn.click(
                update_metrics,
                inputs=[],
                outputs=[metrics_display, metrics_chart],
            )

        return interface

    def launch(self, **kwargs):
        """Launch the Gradio interface."""
        interface = self.create_interface()
        interface.launch(**kwargs)

    def cleanup(self):
        """Clean up temporary files."""
        try:
            shutil.rmtree(self.temp_dir)
        except Exception as e:
            print(f"Error cleaning up temp files: {e}")