Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
import pandas as pd | |
import numpy as np | |
from litellm import OpenAI | |
import json | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.cluster import KMeans | |
from sklearn.decomposition import PCA | |
import matplotlib.pyplot as plt | |
import time | |
import torch | |
import traceback | |
import logging | |
# Import local modules | |
from classifiers import TFIDFClassifier, LLMClassifier | |
from utils import load_data, export_data, visualize_results, validate_results | |
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)

# OpenAI API key taken from the environment; empty string when unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")

# Shared API client. It stays None when no key is configured or when building
# the client raises; code in this module tests it for None/falsiness before use.
client = None
if OPENAI_API_KEY:
    try:
        client = OpenAI(api_key=OPENAI_API_KEY)
    except Exception as e:
        logging.error(f"Failed to initialize OpenAI client: {e!s}")
    else:
        logging.info("OpenAI client initialized successfully")
def update_api_key(api_key):
    """Verify an OpenAI API key and, if it works, make it the active key.

    A throwaway client is built from the key and used for one tiny chat
    completion. Only when that request succeeds are the module-level
    ``OPENAI_API_KEY`` and ``client`` replaced, so a rejected key never
    displaces a previously working client (or leaves a broken one behind
    that the ``client is None`` checks would not catch).

    Args:
        api_key: Key typed in the Setup tab. Surrounding whitespace is
            stripped; ``None`` or blank input is rejected.

    Returns:
        A status message for the UI.
    """
    global OPENAI_API_KEY, client
    api_key = (api_key or "").strip()
    if not api_key:
        return "API Key cannot be empty"
    try:
        candidate = OpenAI(api_key=api_key)
        # Minimal round-trip (max 5 tokens) to confirm the key is accepted.
        candidate.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "test"}],
            max_tokens=5,
        )
    except Exception as e:
        error_msg = str(e)
        logging.error(f"API key update failed: {error_msg}")
        return f"Failed to update API Key: {error_msg}"
    # Commit only after successful verification.
    OPENAI_API_KEY = api_key
    client = candidate
    return "API Key updated and verified successfully"
def _combine_text_columns(df, text_columns):
    """Join the selected columns of every row into one space-separated string.

    Missing cells (NaN/None/NaT) are skipped rather than rendered as the
    literal text "nan", which would otherwise be sent to the classifier.
    """
    return [
        " ".join(str(value) for value in row if not pd.isna(value))
        for row in df[list(text_columns)].itertuples(index=False, name=None)
    ]


def _auto_classifier_type(num_texts):
    """Pick a classifier for ``"auto"`` mode from the number of rows."""
    if num_texts <= 500:
        return "gpt4"
    if num_texts <= 1000:
        return "gpt35"
    if num_texts <= 5000:
        return "hybrid"
    return "tfidf"


def _hybrid_classify(texts, category_list, threshold=70):
    """Classify with TF-IDF, then redo results below *threshold* confidence with GPT-3.5."""
    tfidf_results = TFIDFClassifier().classify(texts, category_list)
    llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
    results = []
    for text, tfidf_result in zip(texts, tfidf_results):
        if tfidf_result["confidence"] < threshold:
            results.append(llm_classifier.classify([text], category_list)[0])
        else:
            results.append(tfidf_result)
    return results


def process_file(file, text_columns, categories, classifier_type, show_explanations):
    """Classify the text found in an uploaded Excel/CSV file.

    Args:
        file: Upload from ``gr.File``; either a path string or an object
            exposing the path as ``.name``.
        text_columns: Columns whose values are concatenated (space separated)
            into the text that gets classified.
        categories: Comma-separated category names; blank entries are dropped.
            May be empty, in which case the classifier gets an empty list.
        classifier_type: ``"tfidf"``, ``"gpt35"``, ``"gpt4"``, ``"hybrid"`` or
            ``"auto"`` (chosen from the row count). Any other value is
            handled like ``"hybrid"``.
        show_explanations: When true, an ``Explanation`` column is added.

    Returns:
        ``(result_df, validation_report)`` on success. On failure the first
        element is ``None`` and the second is a message for the user.
    """
    if file is None:
        return None, "Please upload a file first"
    try:
        df = load_data(file if isinstance(file, str) else file.name)

        if not text_columns:
            return None, "Please select at least one text column"
        missing_columns = [col for col in text_columns if col not in df.columns]
        if missing_columns:
            # map(str, ...) so non-string labels (e.g. integer headers) can be listed.
            return None, (
                f"Columns not found in the file: {', '.join(map(str, missing_columns))}. "
                f"Available columns: {', '.join(map(str, df.columns))}"
            )

        texts = _combine_text_columns(df, text_columns)
        # Blank entries (from "a,,b" or a trailing comma) are not real categories.
        category_list = [cat.strip() for cat in (categories or "").split(",") if cat.strip()]

        if classifier_type == "auto":
            classifier_type = _auto_classifier_type(len(texts))

        # Every option except plain TF-IDF needs the OpenAI client.
        if classifier_type != "tfidf" and client is None:
            return None, "Erreur : Le client API n'est pas initialisé. Veuillez configurer une clé API valide dans l'onglet 'Setup'."

        if classifier_type == "tfidf":
            results = TFIDFClassifier().classify(texts, category_list)
        elif classifier_type == "gpt35":
            results = LLMClassifier(client=client, model="gpt-3.5-turbo").classify(texts, category_list)
        elif classifier_type == "gpt4":
            results = LLMClassifier(client=client, model="gpt-4").classify(texts, category_list)
        else:  # "hybrid", and the fallback for unrecognised values
            results = _hybrid_classify(texts, category_list)

        result_df = df.copy()
        result_df["Category"] = [r["category"] for r in results]
        result_df["Confidence"] = [r["confidence"] for r in results]
        if show_explanations:
            result_df["Explanation"] = [r["explanation"] for r in results]

        # validate_results also receives the client, which is None when no key is set.
        validation_report = validate_results(result_df, text_columns, client)
        return result_df, validation_report
    except Exception as e:
        logging.exception("Classification failed")
        return None, f"Error: {str(e)}\n{traceback.format_exc()}"
def export_results(df, format_type, export_dir="temp_exports"):
    """Write *df* to a uniquely named CSV or Excel file and return its path.

    Args:
        df: DataFrame to export; ``None`` exports nothing.
        format_type: ``"excel"`` writes ``.xlsx``; any other value writes ``.csv``.
        export_dir: Directory for the file; created if missing.

    Returns:
        Path of the written file, or ``None`` when *df* is ``None``.
    """
    import uuid

    if df is None:
        return None
    os.makedirs(export_dir, exist_ok=True)
    # A second-resolution timestamp alone collides when two exports of the same
    # format happen within one second (e.g. two users of a shared app), and the
    # later write would overwrite the earlier file; the random suffix prevents it.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    filename = f"classification_results_{timestamp}_{uuid.uuid4().hex[:8]}"
    if format_type == "excel":
        file_path = os.path.join(export_dir, f"{filename}.xlsx")
        df.to_excel(file_path, index=False)
    else:
        file_path = os.path.join(export_dir, f"{filename}.csv")
        df.to_csv(file_path, index=False)
    return file_path
# Create Gradio interface
with gr.Blocks(title="Text Classification System") as demo:
    gr.Markdown("# Text Classification System")
    gr.Markdown("Upload your data file (Excel/CSV) and classify text using AI")

    # --- Setup tab: enter and verify the OpenAI API key ---
    with gr.Tab("Setup"):
        api_key_input = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your API key here",
            type="password",
            value=OPENAI_API_KEY,
        )
        api_key_button = gr.Button("Update API Key")
        api_key_message = gr.Textbox(label="Status", interactive=False)

        # Rendered once when the UI is built; this text is not refreshed
        # when the key is changed with the button above.
        if OPENAI_API_KEY:
            api_status = "API Key is set"
        else:
            api_status = "No API Key found. Please set one."
        gr.Markdown(f"**Current API Status**: {api_status}")

        api_key_button.click(update_api_key, inputs=[api_key_input], outputs=[api_key_message])

    # --- Classification tab: upload, configure, classify, review ---
    with gr.Tab("Classify Data"):
        with gr.Column():
            file_input = gr.File(label="Upload Excel/CSV File")
            # Column names of the loaded file
            available_columns = gr.State([])
            # Loads the file and proposes text columns and categories
            load_categories_button = gr.Button("Load File")
            # Preview of the uploaded data, hidden until a file is loaded
            original_df = gr.Dataframe(label="Original Data", interactive=False, visible=False)

            # Category editing (left) and classification settings (right);
            # everything starts hidden and is revealed once a file is loaded.
            with gr.Row():
                with gr.Column():
                    suggested_categories = gr.CheckboxGroup(
                        label="Suggested Categories",
                        choices=[],
                        value=[],
                        interactive=True,
                        visible=False,
                    )
                    new_category = gr.Textbox(
                        label="Add New Category",
                        placeholder="Enter a new category name",
                        visible=False,
                    )
                    with gr.Row():
                        add_category_button = gr.Button("Add Category", visible=False)
                        suggest_category_button = gr.Button("Suggest Category", visible=False)
                    # Hidden comma-separated copy of the selected categories
                    categories = gr.Textbox(visible=False)

                with gr.Column():
                    text_column = gr.CheckboxGroup(
                        label="Select Text Columns",
                        choices=[],
                        interactive=True,
                        visible=False,
                    )
                    classifier_type = gr.Dropdown(
                        choices=[
                            ("TF-IDF (Rapide, <1000 lignes)", "tfidf"),
                            ("LLM GPT-3.5 (Fiable, <1000 lignes)", "gpt35"),
                            ("LLM GPT-4 (Très fiable, <500 lignes)", "gpt4"),
                            ("TF-IDF + LLM (Hybride, >1000 lignes)", "hybrid"),
                        ],
                        label="Modèle de classification",
                        value="tfidf",
                        visible=False,
                    )
                    show_explanations = gr.Checkbox(label="Show Explanations", value=True, visible=False)

            # NOTE(review): confirm placement - full width below the settings
            # row rather than inside the right-hand column.
            process_button = gr.Button("Process and Classify", visible=False)
            results_df = gr.Dataframe(interactive=True, visible=False)

            # Chart and downloads (left), validation report (right)
            with gr.Row(visible=False) as results_row:
                with gr.Column():
                    visualization = gr.Plot(label="Classification Distribution")
                    with gr.Row():
                        csv_download = gr.File(label="Download CSV", visible=False)
                        excel_download = gr.File(label="Download Excel", visible=False)
                with gr.Column():
                    validation_output = gr.Textbox(label="Validation Report", interactive=False)
                    improve_button = gr.Button("Improve Classification with Report", visible=False)
# Function to load file and suggest categories
def _hidden_load_outputs():
    """Return the 11 "Load File" outputs in their reset, hidden state."""
    return (
        [],
        gr.CheckboxGroup(choices=[]),
        gr.CheckboxGroup(choices=[], visible=False),
        gr.Textbox(visible=False),
        gr.Button(visible=False),
        gr.Button(visible=False),
        gr.CheckboxGroup(choices=[], visible=False),
        gr.Dropdown(visible=False),
        gr.Checkbox(visible=False),
        gr.Button(visible=False),
        gr.Dataframe(visible=False),
    )


def load_file_and_suggest_categories(file):
    """Load the uploaded file, guess its text columns and propose categories.

    Args:
        file: Upload from ``gr.File``; either a path string or an object
            exposing the path as ``.name``.

    Returns:
        An 11-tuple in the order of the outputs wired to the "Load File"
        button: the column names, updates revealing the category/settings
        widgets, and the data preview. Everything is reset and hidden when
        *file* is empty or cannot be loaded.
    """
    if not file:
        return _hidden_load_outputs()
    try:
        df = load_data(file if isinstance(file, str) else file.name)
        columns = list(df.columns)

        # An object-dtype column counts as free text when more than 30% of
        # its first 100 non-null values contain a space.
        object_columns = [col for col in columns if df[col].dtype == 'object']
        suggested_text_columns = []
        for col in object_columns:
            sample = df[col].head(100).dropna()
            if len(sample) > 0 and sum(' ' in str(val) for val in sample) / len(sample) > 0.3:
                suggested_text_columns.append(col)
        if not suggested_text_columns:
            suggested_text_columns = object_columns

        sample_texts = []
        for col in suggested_text_columns:
            sample_texts.extend(df[col].dropna().head(5).tolist())

        default_cats = ["Positive", "Negative", "Neutral", "Mixed", "Other"]
        suggested_cats = default_cats
        if client:
            prompt = (
                "\n"
                "Based on these example texts, suggest 5 appropriate categories for classification:\n"
                f"{sample_texts[:5]}\n"
                "Return your answer as a comma-separated list of category names only.\n"
            )
            try:
                response = client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=[{"role": "user", "content": prompt}],
                    temperature=0.2,
                    max_tokens=100,
                )
                reply = response.choices[0].message.content.strip()
                # Drop blank and repeated names (e.g. from a trailing comma).
                parsed = list(dict.fromkeys(cat.strip() for cat in reply.split(",") if cat.strip()))
                suggested_cats = parsed or default_cats
            except Exception:
                logging.exception("Category suggestion failed; using default categories")

        return (
            columns,
            gr.CheckboxGroup(choices=columns, value=suggested_text_columns),
            gr.CheckboxGroup(choices=suggested_cats, value=suggested_cats, visible=True),
            gr.Textbox(visible=True),
            gr.Button(visible=True),
            gr.Button(visible=True),
            gr.CheckboxGroup(choices=columns, value=suggested_text_columns, visible=True),
            gr.Dropdown(visible=True),
            gr.Checkbox(visible=True),
            gr.Button(visible=True),
            gr.Dataframe(value=df, visible=True),
        )
    except Exception:
        logging.exception("Failed to load file for category suggestion")
        return _hidden_load_outputs()
# Function to add a new category
def add_new_category(current_categories, new_category):
    """Append *new_category* to the selected categories.

    Args:
        current_categories: Current value of the category checkbox group.
        new_category: Text typed by the user; surrounding whitespace is ignored.

    Returns:
        The selection unchanged when the input is blank or already selected;
        otherwise a ``gr.CheckboxGroup`` update whose choices and value are
        the selection plus the new category.
    """
    name = (new_category or "").strip()
    if not name or name in (current_categories or []):
        # Nothing to add: return the selection as-is rather than creating an
        # empty or duplicate checkbox.
        return current_categories
    new_categories = list(current_categories or []) + [name]
    return gr.CheckboxGroup(choices=new_categories, value=new_categories)
# Function to update categories textbox
def update_categories_textbox(selected_categories):
    """Serialize the selected categories as a ", "-separated string.

    ``None`` (nothing selected) yields an empty string instead of raising.
    """
    return ", ".join(selected_categories or [])
# Function to show results after processing
def show_results(df, validation_report):
    """Reveal the results area and provide CSV and Excel downloads of *df*.

    Args:
        df: Classified results. ``None`` or an empty frame means the previous
            step produced nothing to show, so the results area stays hidden.
        validation_report: Unused; accepted because the event wiring passes it.

    Returns:
        Updates for the results row, the CSV and Excel download widgets, and
        the results table.
    """
    if df is None or getattr(df, "empty", False):
        return gr.Row(visible=False), gr.File(visible=False), gr.File(visible=False), gr.Dataframe(visible=False)
    csv_path = export_results(df, "csv")
    excel_path = export_results(df, "excel")
    return (
        gr.Row(visible=True),
        gr.File(value=csv_path, visible=True),
        gr.File(value=excel_path, visible=True),
        gr.Dataframe(value=df, visible=True),
    )
# Function to suggest a new category
def suggest_new_category(file, current_categories, text_columns):
    """Ask the LLM for one extra category that fits the uploaded data.

    Args:
        file: Upload from ``gr.File``; either a path string or an object
            exposing the path as ``.name``.
        current_categories: Currently selected categories (not modified).
        text_columns: Columns to draw example texts from.

    Returns:
        A ``gr.CheckboxGroup`` update listing the categories, with the
        suggestion appended when one was obtained and is not already present.
        Failures are logged and leave the list unchanged.
    """
    categories = list(current_categories or [])
    if not file or not text_columns or not client:
        return gr.CheckboxGroup(choices=categories, value=categories)
    try:
        df = load_data(file if isinstance(file, str) else file.name)
        sample_texts = []
        for col in text_columns:
            sample_texts.extend(df[col].dropna().head(5).tolist())
        prompt = (
            "\n"
            f"Based on these example texts and the existing categories ({', '.join(categories)}),\n"
            "suggest one additional appropriate category for classification.\n"
            "Example texts:\n"
            f"{sample_texts[:5]}\n"
            "Return only the suggested category name, nothing else.\n"
        )
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=50,
        )
        # Trim quotes / a trailing period in case the model wraps the name.
        new_cat = response.choices[0].message.content.strip().strip("\"'.").strip()
        if new_cat and new_cat not in categories:
            categories.append(new_cat)
    except Exception:
        logging.exception("Category suggestion failed")
    return gr.CheckboxGroup(choices=categories, value=categories)
# Function to handle export and show download button
def handle_export(df, format_type):
    """Export *df* through ``export_results`` and return a download-widget update.

    A hidden ``gr.File`` is returned when there is no DataFrame to export.
    """
    if df is not None:
        return gr.File(value=export_results(df, format_type), visible=True)
    return gr.File(visible=False)
# Function to improve classification based on validation report
def _parse_json_reply(text):
    """Decode the JSON value in an LLM reply.

    Tolerates a Markdown code fence around it (```json ... ```), which
    ``json.loads`` alone would reject.
    """
    text = (text or "").strip()
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:]
    return json.loads(text)


def improve_classification(df, validation_report, text_columns, categories, classifier_type, show_explanations, file):
    """Re-run classification guided by an LLM reading of the validation report.

    GPT-4 analyses *validation_report* and answers in JSON. If that answer
    sets ``new_categories_needed``, GPT-4 is asked again, with example texts
    from the file, for extra categories, which are appended to the current
    ones. The file is then re-classified through ``process_file`` with the
    (possibly extended) category list.

    Args:
        df: Current results table.
        validation_report: Report produced by the previous classification.
        text_columns, categories, classifier_type, show_explanations, file:
            Same meaning as for ``process_file``.

    Returns:
        ``(results_df, validation_report, improve_button_update,
        categories_update)``. When there is nothing to improve, no API client,
        or any step fails, the current results and categories are returned
        unchanged.
    """
    current_categories = [cat.strip() for cat in (categories or "").split(",") if cat.strip()]

    def unchanged(button_visible=True):
        # Keep the user's categories rather than clearing them.
        return (
            df,
            validation_report,
            gr.Button(visible=button_visible),
            gr.CheckboxGroup(choices=current_categories, value=current_categories),
        )

    if df is None or getattr(df, "empty", False) or not validation_report:
        return unchanged(button_visible=False)
    if not client:
        return unchanged()

    try:
        prompt = (
            "\n"
            "Based on this validation report, analyze the current classification and suggest improvements:\n"
            f"{validation_report}\n"
            "Return your answer in JSON format with these fields:\n"
            f"- suggested_categories: list of improved category names (must be different from current categories: {categories})\n"
            "- confidence_threshold: a number between 0 and 100 for minimum confidence\n"
            "- focus_areas: list of specific aspects to focus on during classification\n"
            "- analysis: a brief analysis of what needs improvement\n"
            "- new_categories_needed: boolean indicating if new categories should be added\n"
            "JSON response:\n"
        )
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.2,
            max_tokens=300,
        )
        improvements = _parse_json_reply(response.choices[0].message.content)

        all_categories = list(current_categories)
        if isinstance(improvements, dict) and improvements.get("new_categories_needed", False):
            # Load the file once (not once per column) to collect example texts.
            data = load_data(file if isinstance(file, str) else file.name)
            sample_texts = []
            for col in text_columns:
                sample_texts.extend(data[col].dropna().head(5).tolist())
            category_prompt = (
                "\n"
                f"Based on these example texts and the current categories ({', '.join(current_categories)}),\n"
                "suggest new categories that would improve the classification. The validation report indicates:\n"
                f"{improvements.get('analysis', '')}\n"
                "Example texts:\n"
                f"{sample_texts[:5]}\n"
                "Return your answer as a comma-separated list of new category names only.\n"
            )
            category_response = client.chat.completions.create(
                model="gpt-4",
                messages=[{"role": "user", "content": category_prompt}],
                temperature=0.2,
                max_tokens=100,
            )
            reply = category_response.choices[0].message.content.strip()
            for cat in reply.split(","):
                cat = cat.strip()
                if cat and cat not in all_categories:
                    all_categories.append(cat)

        # Re-classify even when no categories were added, so the result is
        # actually returned instead of being discarded.
        improved_df, new_validation = process_file(
            file,
            text_columns,
            ",".join(all_categories),
            classifier_type,
            show_explanations,
        )
        return (
            improved_df,
            new_validation,
            gr.Button(visible=True),
            gr.CheckboxGroup(choices=all_categories, value=all_categories),
        )
    except Exception:
        logging.exception("Error in improvement process")
        return unchanged()
# Connect functions
load_categories_button.click(
    load_file_and_suggest_categories,
    inputs=[file_input],
    outputs=[
        available_columns,
        text_column,
        suggested_categories,
        new_category,
        add_category_button,
        suggest_category_button,
        text_column,
        classifier_type,
        show_explanations,
        process_button,
        original_df,
    ],
)
add_category_button.click(
    add_new_category,
    inputs=[suggested_categories, new_category],
    outputs=[suggested_categories],
)
# Keep the hidden comma-separated categories textbox in sync with the checkboxes.
suggested_categories.change(
    update_categories_textbox,
    inputs=[suggested_categories],
    outputs=[categories],
)
suggest_category_button.click(
    suggest_new_category,
    inputs=[file_input, suggested_categories, text_column],
    outputs=[suggested_categories],
)
# Classify -> reveal results and downloads -> draw chart -> show improve button.
process_button.click(
    process_file,
    inputs=[file_input, text_column, categories, classifier_type, show_explanations],
    outputs=[results_df, validation_output],
).then(
    show_results,
    inputs=[results_df, validation_output],
    outputs=[results_row, csv_download, excel_download, results_df],
).then(
    visualize_results,
    inputs=[results_df, text_column],
    outputs=[visualization],
).then(
    # Zero-argument callback: with inputs=[] Gradio calls it with no
    # arguments, so any required parameter would raise TypeError.
    lambda: gr.Button(visible=True),
    inputs=[],
    outputs=[improve_button],
)
# Improve -> refresh results and downloads -> redraw chart.
improve_button.click(
    improve_classification,
    inputs=[results_df, validation_output, text_column, categories, classifier_type, show_explanations, file_input],
    outputs=[results_df, validation_output, improve_button, suggested_categories],
).then(
    show_results,
    inputs=[results_df, validation_output],
    outputs=[results_row, csv_download, excel_download, results_df],
).then(
    visualize_results,
    inputs=[results_df, text_column],
    outputs=[visualization],
)
def create_example_data():
    """Create the demonstration input file via ``utils.create_example_file``.

    Returns:
        A message containing the path reported by ``create_example_file``.
    """
    from utils import create_example_file

    path = create_example_file()
    return "Example file created at: {}".format(path)
if __name__ == "__main__":
    # First run only: create the example data when there is no "examples" directory.
    examples_missing = not os.path.exists("examples")
    if examples_missing:
        create_example_data()
    # Start the Gradio server.
    demo.launch()