lighten app file
Moves update_api_key, process_file, and export_results out of app.py into a new process.py module.
- app.py +4 -170
- process.py +172 -0
app.py
CHANGED
@@ -1,25 +1,20 @@
 import os
 import gradio as gr
-
-import numpy as np
+
 from litellm import OpenAI
 import json
-from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.cluster import KMeans
 from sklearn.decomposition import PCA
 import matplotlib.pyplot as plt
-
-import torch
-import traceback
+
 import logging
 from dotenv import load_dotenv
-
+from process import update_api_key, process_file, export_results
 # Load environment variables from .env file
 load_dotenv()
 
 # Import local modules
-from classifiers import TFIDFClassifier, LLMClassifier
-from utils import load_data, export_data, visualize_results, validate_results
+from utils import load_data, visualize_results
 from prompts import (
     CATEGORY_SUGGESTION_PROMPT,
     ADDITIONAL_CATEGORY_PROMPT,
@@ -45,167 +40,6 @@ if OPENAI_API_KEY:
         logging.error(f"Failed to initialize OpenAI client: {str(e)}")
 
 
-def update_api_key(api_key):
-    """Update the OpenAI API key"""
-    global OPENAI_API_KEY, client
-
-    if not api_key:
-        return "API Key cannot be empty"
-
-    OPENAI_API_KEY = api_key
-
-    try:
-        client = OpenAI(api_key=api_key)
-        # Test the connection with a simple request
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "test"}],
-            max_tokens=5,
-        )
-        return f"API Key updated and verified successfully"
-    except Exception as e:
-        error_msg = str(e)
-        logging.error(f"API key update failed: {error_msg}")
-        return f"Failed to update API Key: {error_msg}"
-
-
-def process_file(file, text_columns, categories, classifier_type, show_explanations):
-    """Process the uploaded file and classify text data"""
-    # Initialize result_df and validation_report
-    result_df = None
-    validation_report = None
-
-    try:
-        # Load data from file
-        if isinstance(file, str):
-            df = load_data(file)
-        else:
-            df = load_data(file.name)
-
-        if not text_columns:
-            return None, "Please select at least one text column"
-
-        # Check if all selected columns exist
-        missing_columns = [col for col in text_columns if col not in df.columns]
-        if missing_columns:
-            return (
-                None,
-                f"Columns not found in the file: {', '.join(missing_columns)}. Available columns: {', '.join(df.columns)}",
-            )
-
-        # Combine text from selected columns
-        texts = []
-        for _, row in df.iterrows():
-            combined_text = " ".join(str(row[col]) for col in text_columns)
-            texts.append(combined_text)
-
-        # Parse categories if provided
-        category_list = []
-        if categories:
-            category_list = [cat.strip() for cat in categories.split(",")]
-
-        # Select classifier based on data size and user choice
-        num_texts = len(texts)
-
-        # If no specific model is chosen, select the most appropriate one
-        if classifier_type == "auto":
-            if num_texts <= 500:
-                classifier_type = "gpt4"
-            elif num_texts <= 1000:
-                classifier_type = "gpt35"
-            elif num_texts <= 5000:
-                classifier_type = "hybrid"
-            else:
-                classifier_type = "tfidf"
-
-        # Initialize appropriate classifier
-        if classifier_type == "tfidf":
-            classifier = TFIDFClassifier()
-            results = classifier.classify(texts, category_list)
-        elif classifier_type in ["gpt35", "gpt4"]:
-            if client is None:
-                return (
-                    None,
-                    "Error: The API client is not initialized. Please configure a valid API key in the 'Setup' tab.",
-                )
-            model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
-            classifier = LLMClassifier(client=client, model=model)
-            results = classifier.classify(texts, category_list)
-        else:  # hybrid
-            if client is None:
-                return (
-                    None,
-                    "Error: The API client is not initialized. Please configure a valid API key in the 'Setup' tab.",
-                )
-            # First pass with TF-IDF
-            tfidf_classifier = TFIDFClassifier()
-            tfidf_results = tfidf_classifier.classify(texts, category_list)
-
-            # Second pass with LLM for low confidence results
-            llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
-            results = []
-            low_confidence_texts = []
-            low_confidence_indices = []
-
-            for i, (text, tfidf_result) in enumerate(zip(texts, tfidf_results)):
-                if tfidf_result["confidence"] < 70:  # If confidence is below 70%
-                    low_confidence_texts.append(text)
-                    low_confidence_indices.append(i)
-                    results.append(None)  # Placeholder
-                else:
-                    results.append(tfidf_result)
-
-            if low_confidence_texts:
-                llm_results = llm_classifier.classify(
-                    low_confidence_texts, category_list
-                )
-                for idx, llm_result in zip(low_confidence_indices, llm_results):
-                    results[idx] = llm_result
-
-        # Create results dataframe
-        result_df = df.copy()
-        result_df["Category"] = [r["category"] for r in results]
-        result_df["Confidence"] = [r["confidence"] for r in results]
-
-        if show_explanations:
-            result_df["Explanation"] = [r["explanation"] for r in results]
-
-        # Validate results using LLM
-        validation_report = validate_results(result_df, text_columns, client)
-
-        return result_df, validation_report
-
-    except Exception as e:
-        error_traceback = traceback.format_exc()
-        return None, f"Error: {str(e)}\n{error_traceback}"
-
-
-def export_results(df, format_type):
-    """Export results to a file and return the file path for download"""
-    if df is None:
-        return None
-
-    # Create a temporary file
-    import tempfile
-    import os
-
-    # Create a temporary directory if it doesn't exist
-    temp_dir = "temp_exports"
-    os.makedirs(temp_dir, exist_ok=True)
-
-    # Generate a unique filename
-    timestamp = time.strftime("%Y%m%d-%H%M%S")
-    filename = f"classification_results_{timestamp}"
-
-    if format_type == "excel":
-        file_path = os.path.join(temp_dir, f"{filename}.xlsx")
-        df.to_excel(file_path, index=False)
-    else:
-        file_path = os.path.join(temp_dir, f"{filename}.csv")
-        df.to_csv(file_path, index=False)
-
-    return file_path
-
 
 # Create Gradio interface
 with gr.Blocks(title="Text Classification System") as demo:
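With this change app.py keeps only the UI layer and pulls its handlers from process. The gr.Blocks body itself is outside this diff, so the following is a hypothetical sketch of how the imported handlers could be wired to Gradio components; every component name below is an assumption, and only update_api_key, process_file, and export_results come from the commit.

import gradio as gr

from process import update_api_key, process_file, export_results

with gr.Blocks(title="Text Classification System") as demo:
    # Setup tab: update_api_key returns a status string
    api_key_box = gr.Textbox(label="OpenAI API Key", type="password")
    api_status = gr.Markdown()
    api_key_box.submit(update_api_key, inputs=api_key_box, outputs=api_status)

    # Classification tab: process_file returns (result_df, validation_report)
    file_input = gr.File(label="Data file (CSV/Excel)")
    columns = gr.Dropdown(label="Text columns", multiselect=True)  # presumably populated after upload
    categories = gr.Textbox(label="Categories (comma-separated)")
    classifier = gr.Radio(["auto", "tfidf", "gpt35", "gpt4", "hybrid"], value="auto")
    explanations = gr.Checkbox(label="Show explanations")
    results = gr.Dataframe()
    report = gr.Textbox(label="Validation report")
    gr.Button("Classify").click(
        process_file,
        inputs=[file_input, columns, categories, classifier, explanations],
        outputs=[results, report],
    )

demo.launch()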
process.py
ADDED
@@ -0,0 +1,172 @@
+
+
+import logging
+import time
+import traceback
+from sklearn.feature_extraction.text import TfidfVectorizer
+
+from litellm import OpenAI
+from classifiers import TFIDFClassifier, LLMClassifier
+from utils import load_data, validate_results
+
+
+def update_api_key(api_key):
+    """Update the OpenAI API key"""
+    global OPENAI_API_KEY, client
+
+    if not api_key:
+        return "API Key cannot be empty"
+
+    OPENAI_API_KEY = api_key
+
+    try:
+        client = OpenAI(api_key=api_key)
+        # Test the connection with a simple request
+        response = client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "test"}],
+            max_tokens=5,
+        )
+        return f"API Key updated and verified successfully"
+    except Exception as e:
+        error_msg = str(e)
+        logging.error(f"API key update failed: {error_msg}")
+        return f"Failed to update API Key: {error_msg}"
+
+
+def process_file(file, text_columns, categories, classifier_type, show_explanations):
+    """Process the uploaded file and classify text data"""
+    # Initialize result_df and validation_report
+    result_df = None
+    validation_report = None
+
+    try:
+        # Load data from file
+        if isinstance(file, str):
+            df = load_data(file)
+        else:
+            df = load_data(file.name)
+
+        if not text_columns:
+            return None, "Please select at least one text column"
+
+        # Check if all selected columns exist
+        missing_columns = [col for col in text_columns if col not in df.columns]
+        if missing_columns:
+            return (
+                None,
+                f"Columns not found in the file: {', '.join(missing_columns)}. Available columns: {', '.join(df.columns)}",
+            )
+
+        # Combine text from selected columns
+        texts = []
+        for _, row in df.iterrows():
+            combined_text = " ".join(str(row[col]) for col in text_columns)
+            texts.append(combined_text)
+
+        # Parse categories if provided
+        category_list = []
+        if categories:
+            category_list = [cat.strip() for cat in categories.split(",")]
+
+        # Select classifier based on data size and user choice
+        num_texts = len(texts)
+
+        # If no specific model is chosen, select the most appropriate one
+        if classifier_type == "auto":
+            if num_texts <= 500:
+                classifier_type = "gpt4"
+            elif num_texts <= 1000:
+                classifier_type = "gpt35"
+            elif num_texts <= 5000:
+                classifier_type = "hybrid"
+            else:
+                classifier_type = "tfidf"
+
+        # Initialize appropriate classifier
+        if classifier_type == "tfidf":
+            classifier = TFIDFClassifier()
+            results = classifier.classify(texts, category_list)
+        elif classifier_type in ["gpt35", "gpt4"]:
+            if client is None:
+                return (
+                    None,
+                    "Error: The API client is not initialized. Please configure a valid API key in the 'Setup' tab.",
+                )
+            model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
+            classifier = LLMClassifier(client=client, model=model)
+            results = classifier.classify(texts, category_list)
+        else:  # hybrid
+            if client is None:
+                return (
+                    None,
+                    "Error: The API client is not initialized. Please configure a valid API key in the 'Setup' tab.",
+                )
+            # First pass with TF-IDF
+            tfidf_classifier = TFIDFClassifier()
+            tfidf_results = tfidf_classifier.classify(texts, category_list)
+
+            # Second pass with LLM for low confidence results
+            llm_classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")
+            results = []
+            low_confidence_texts = []
+            low_confidence_indices = []
+
+            for i, (text, tfidf_result) in enumerate(zip(texts, tfidf_results)):
+                if tfidf_result["confidence"] < 70:  # If confidence is below 70%
+                    low_confidence_texts.append(text)
+                    low_confidence_indices.append(i)
+                    results.append(None)  # Placeholder
+                else:
+                    results.append(tfidf_result)
+
+            if low_confidence_texts:
+                llm_results = llm_classifier.classify(
+                    low_confidence_texts, category_list
+                )
+                for idx, llm_result in zip(low_confidence_indices, llm_results):
+                    results[idx] = llm_result
+
+        # Create results dataframe
+        result_df = df.copy()
+        result_df["Category"] = [r["category"] for r in results]
+        result_df["Confidence"] = [r["confidence"] for r in results]
+
+        if show_explanations:
+            result_df["Explanation"] = [r["explanation"] for r in results]
+
+        # Validate results using LLM
+        validation_report = validate_results(result_df, text_columns, client)
+
+        return result_df, validation_report
+
+    except Exception as e:
+        error_traceback = traceback.format_exc()
+        return None, f"Error: {str(e)}\n{error_traceback}"
+
+
+def export_results(df, format_type):
+    """Export results to a file and return the file path for download"""
+    if df is None:
+        return None
+
+    # Create a temporary file
+    import tempfile
+    import os
+
+    # Create a temporary directory if it doesn't exist
+    temp_dir = "temp_exports"
+    os.makedirs(temp_dir, exist_ok=True)
+
+    # Generate a unique filename
+    timestamp = time.strftime("%Y%m%d-%H%M%S")
+    filename = f"classification_results_{timestamp}"
+
+    if format_type == "excel":
+        file_path = os.path.join(temp_dir, f"{filename}.xlsx")
+        df.to_excel(file_path, index=False)
+    else:
+        file_path = os.path.join(temp_dir, f"{filename}.csv")
+        df.to_csv(file_path, index=False)
+
+    return file_path
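One caveat carries over from the split: update_api_key declares global OPENAI_API_KEY, client, but process.py never defines those names at module level, so calling process_file before update_api_key raises NameError instead of returning the "API client is not initialized" message; declaring client = None near the top of process.py would restore the intended guard. A minimal usage sketch under that calling order (the file name, column, and categories below are illustrative, not from the commit):

import os

from process import update_api_key, process_file, export_results

# Configure the key first: this creates process.client for the LLM-backed paths
print(update_api_key(os.environ["OPENAI_API_KEY"]))

df, report = process_file(
    "sample.csv",            # a string path; file-like objects with a .name attribute also work
    text_columns=["text"],
    categories="billing, support, sales",
    classifier_type="auto",  # picks gpt4/gpt35/hybrid/tfidf by dataset size
    show_explanations=False,
)
if df is not None:
    print(report)
    print(export_results(df, "csv"))  # writes temp_exports/classification_results_<timestamp>.csv
else:
    print("Processing failed:", report)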