classify async

- app.py +37 -17
- classifiers/llm.py +62 -74
- client.py +37 -0
- process.py +16 -27
app.py
CHANGED

@@ -1,7 +1,7 @@
 import os
 import gradio as gr
+import asyncio
 
-from litellm import OpenAI
 import json
 from sklearn.cluster import KMeans
 from sklearn.decomposition import PCA
@@ -9,7 +9,9 @@ import matplotlib.pyplot as plt
 
 import logging
 from dotenv import load_dotenv
-from process import update_api_key, process_file, export_results
+from process import update_api_key, process_file_async, export_results
+from client import get_client, initialize_client
+
 # Load environment variables from .env file
 load_dotenv()
 
@@ -30,16 +32,13 @@ logging.basicConfig(
 # Initialize API key from environment variable
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 
-# Initialize OpenAI client
-client = None
+# Initialize client if API key is available
 if OPENAI_API_KEY:
-    try:
-        client = OpenAI(api_key=OPENAI_API_KEY)
+    success, message = initialize_client(OPENAI_API_KEY)
+    if success:
         logging.info("OpenAI client initialized successfully")
-    except Exception as e:
-        logging.error(f"Failed to initialize OpenAI client: {str(e)}")
-        client = None
-
+    else:
+        logging.error(f"Failed to initialize OpenAI client: {message}")
 
 # Create Gradio interface
 with gr.Blocks(title="Text Classification System") as demo:
@@ -57,9 +56,8 @@ with gr.Blocks(title="Text Classification System") as demo:
     api_key_message = gr.Textbox(label="Status", interactive=False)
 
     # Display current API status
-    api_status = (
-        "API Key is set" if client else "No API Key found. Please set one."
-    )
+    client = get_client()
+    api_status = "API Key is set" if client else "No API Key found. Please set one."
    gr.Markdown(f"**Current API Status**: {api_status}")
 
     api_key_button.click(
@@ -344,7 +342,7 @@ with gr.Blocks(title="Text Classification System") as demo:
         return gr.File(value=file_path, visible=True)
 
     # Function to improve classification based on validation report
-    def improve_classification(
+    async def improve_classification_async(
         df,
         validation_report,
         text_columns,
@@ -353,7 +351,7 @@ with gr.Blocks(title="Text Classification System") as demo:
         show_explanations,
         file,
     ):
-        """Improve classification based on validation report"""
+        """Async version of improve_classification"""
        if df is None or not validation_report:
             return (
                 df,
@@ -420,7 +418,7 @@ with gr.Blocks(title="Text Classification System") as demo:
         categories = ",".join(all_categories)
 
         # Process with improved parameters
-        improved_df, new_validation = process_file(
+        improved_df, new_validation = await process_file_async(
             file,
             text_columns,
             categories,
@@ -466,6 +464,28 @@ with gr.Blocks(title="Text Classification System") as demo:
             ),
         )
 
+    def improve_classification(
+        df,
+        validation_report,
+        text_columns,
+        categories,
+        classifier_type,
+        show_explanations,
+        file,
+    ):
+        """Synchronous wrapper for improve_classification_async"""
+        return asyncio.run(
+            improve_classification_async(
+                df,
+                validation_report,
+                text_columns,
+                categories,
+                classifier_type,
+                show_explanations,
+                file,
+            )
+        )
+
     # Connect functions
     load_categories_button.click(
         load_file_and_suggest_categories,
@@ -506,7 +526,7 @@ with gr.Blocks(title="Text Classification System") as demo:
     process_button.click(
         lambda: gr.Dataframe(visible=True), inputs=[], outputs=[results_df]
     ).then(
-        process_file,
+        process_file_async,
         inputs=[
             file_input,
             text_column,
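Note: the recurring pattern in this commit is to keep the async implementation and expose a synchronous wrapper built on asyncio.run, while passing the async variant directly to Gradio events (Gradio can call async handlers itself). A minimal sketch of the wrapper pattern, with hypothetical names:

import asyncio

async def improve_async(x):
    # stand-in for awaited work such as an LLM request
    await asyncio.sleep(0)
    return x * 2

def improve(x):
    # asyncio.run creates a fresh event loop, runs the coroutine, and closes
    # the loop; it raises RuntimeError if called from an already-running loop
    return asyncio.run(improve_async(x))

print(improve(21))  # 42

That RuntimeError is why the sync wrappers are only safe for non-async callers; inside Gradio's own event loop the async function must be used directly, as process_button.click(...).then(process_file_async, ...) does above.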
classifiers/llm.py
CHANGED

@@ -1,4 +1,3 @@
-from concurrent.futures import ThreadPoolExecutor, as_completed
 import numpy as np
 import pandas as pd
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -6,14 +5,13 @@ from sklearn.cluster import KMeans
 from sklearn.metrics.pairwise import cosine_similarity
 import random
 import json
-
+import asyncio
 from typing import List, Dict, Any, Optional
 from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
 
 from .base import BaseClassifier
 
 
-
 class LLMClassifier(BaseClassifier):
     """Classifier using a Large Language Model for more accurate but slower classification"""
 
@@ -22,77 +20,15 @@ class LLMClassifier(BaseClassifier):
         self.client = client
         self.model = model
 
-    def classify(
-        self, texts: List[str], categories: Optional[List[str]] = None
-    ) -> List[Dict[str, Any]]:
-        """Classify texts using an LLM with parallel processing"""
-        if not categories:
-            # First, use LLM to generate appropriate categories
-            categories = self._suggest_categories(texts)
-
-        # Process texts in parallel
-        with ThreadPoolExecutor(max_workers=10) as executor:
-            # Submit all tasks with their original indices
-            future_to_index = {
-                executor.submit(self._classify_text, text, categories): idx
-                for idx, text in enumerate(texts)
-            }
-
-            # Initialize results list with None values
-            results = [None] * len(texts)
-
-            # Collect results as they complete
-            for future in as_completed(future_to_index):
-                original_idx = future_to_index[future]
-                try:
-                    result = future.result()
-                    results[original_idx] = result
-                except Exception as e:
-                    print(f"Error processing text: {str(e)}")
-                    results[original_idx] = {
-                        "category": categories[0],
-                        "confidence": 50,
-                        "explanation": f"Error during classification: {str(e)}",
-                    }
-
-        return results
-
-    def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
-        """Use LLM to suggest appropriate categories for the dataset"""
-        # Take a sample of texts to avoid token limitations
-        if len(texts) > sample_size:
-            sample_texts = random.sample(texts, sample_size)
-        else:
-            sample_texts = texts
-
-        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
-
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=[{"role": "user", "content": prompt}],
-                temperature=0.2,
-                max_tokens=100,
-            )
-
-            # Parse response to get categories
-            categories_text = response.choices[0].message.content.strip()
-            categories = [cat.strip() for cat in categories_text.split(",")]
-
-            return categories
-        except Exception as e:
-            # Fallback to default categories on error
-            print(f"Error suggesting categories: {str(e)}")
-            return self._generate_default_categories(texts)
-
-    def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
-        """Use LLM to classify a single text"""
+    async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
+        """Async version of text classification"""
         prompt = TEXT_CLASSIFICATION_PROMPT.format(
-            categories=", ".join(categories), text=text
+            categories=", ".join(categories),
+            text=text
         )
 
         try:
-            response = self.client.chat.completions.create(
+            response = await self.client.chat.completions.create(
                 model=self.model,
                 messages=[{"role": "user", "content": prompt}],
                 temperature=0,
@@ -101,17 +37,15 @@ class LLMClassifier(BaseClassifier):
 
             # Parse JSON response
             response_text = response.choices[0].message.content.strip()
-
             result = json.loads(response_text)
+
             # Ensure all required fields are present
             if not all(k in result for k in ["category", "confidence", "explanation"]):
                 raise ValueError("Missing required fields in LLM response")
 
             # Validate category is in the list
             if result["category"] not in categories:
-                result["category"] = categories[
-                    0
-                ]  # Default to first category if invalid
+                result["category"] = categories[0]  # Default to first category if invalid
 
             # Validate confidence is a number between 0 and 100
             try:
@@ -135,3 +69,57 @@ class LLMClassifier(BaseClassifier):
                 "confidence": 50,
                 "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)",
             }
+        except Exception as e:
+            return {
+                "category": categories[0],
+                "confidence": 50,
+                "explanation": f"Error during classification: {str(e)}",
+            }
+
+    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
+        """Async version of category suggestion"""
+        # Take a sample of texts to avoid token limitations
+        if len(texts) > sample_size:
+            sample_texts = random.sample(texts, sample_size)
+        else:
+            sample_texts = texts
+
+        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.2,
+                max_tokens=100,
+            )
+
+            # Parse response to get categories
+            categories_text = response.choices[0].message.content.strip()
+            categories = [cat.strip() for cat in categories_text.split(",")]
+
+            return categories
+        except Exception as e:
+            # Fallback to default categories on error
+            print(f"Error suggesting categories: {str(e)}")
+            return self._generate_default_categories(texts)
+
+    async def classify_async(
+        self, texts: List[str], categories: Optional[List[str]] = None
+    ) -> List[Dict[str, Any]]:
+        """Async method to classify texts"""
+        if not categories:
+            categories = await self._suggest_categories_async(texts)
+
+        # Create tasks for all texts
+        tasks = [self._classify_text_async(text, categories) for text in texts]
+
+        # Gather all results
+        results = await asyncio.gather(*tasks)
+        return results
+
+    def classify(
+        self, texts: List[str], categories: Optional[List[str]] = None
+    ) -> List[Dict[str, Any]]:
+        """Synchronous wrapper for backwards compatibility"""
+        return asyncio.run(self.classify_async(texts, categories))
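The substantive change here swaps ThreadPoolExecutor/as_completed fan-out for asyncio.gather. Because gather returns results in the order its awaitables were passed, the manual future_to_index bookkeeping of the old version disappears. A minimal sketch, with classify_one standing in for the awaited LLM call:

import asyncio
from typing import Any, Dict, List

async def classify_one(text: str) -> Dict[str, Any]:
    await asyncio.sleep(0.1)  # stand-in for an awaited LLM request
    return {"text": text, "category": "demo"}

async def classify_many(texts: List[str]) -> List[Dict[str, Any]]:
    # asyncio.gather preserves argument order, so results[i] matches texts[i];
    # the ThreadPoolExecutor version needed an index map to guarantee this
    return await asyncio.gather(*(classify_one(t) for t in texts))

print(asyncio.run(classify_many(["a", "b", "c"])))

One assumption baked into the new methods: awaiting self.client.chat.completions.create(...) only works if the injected client exposes async methods; if client.py hands in a synchronous OpenAI-style client, the await would fail at runtime.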
client.py
ADDED

@@ -0,0 +1,37 @@
+from litellm import OpenAI
+import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Initialize client as None
+client = None
+
+def get_client():
+    """Get the OpenAI client instance"""
+    global client
+    return client
+
+def initialize_client(api_key=None):
+    """Initialize the OpenAI client with an API key"""
+    global client
+
+    # Use provided API key or get from environment
+    api_key = api_key or os.environ.get("OPENAI_API_KEY")
+
+    if not api_key:
+        return False, "No API key provided"
+
+    try:
+        client = OpenAI(api_key=api_key)
+        # Test the connection with a simple request
+        response = client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "test"}],
+            max_tokens=5,
+        )
+        return True, "API Key updated and verified successfully"
+    except Exception as e:
+        client = None
+        return False, f"Failed to initialize client: {str(e)}"
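The new module centralizes the client in a module-level singleton, so app.py and process.py no longer construct their own instances. A usage sketch, using only the two functions defined above:

from client import initialize_client, get_client

# initialize once at startup; with no argument it falls back to OPENAI_API_KEY
success, message = initialize_client()
if success:
    client = get_client()  # every importer sees the same instance
else:
    print(message)

The tiny test request in initialize_client costs a few tokens but surfaces a bad key at startup rather than midway through a classification run.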
process.py
CHANGED

@@ -1,41 +1,22 @@
-
-
 import logging
 import time
 import traceback
+import asyncio
 from sklearn.feature_extraction.text import TfidfVectorizer
 
-from litellm import OpenAI
 from classifiers import TFIDFClassifier, LLMClassifier
 from utils import load_data, validate_results
+from client import get_client
 
 
 def update_api_key(api_key):
     """Update the OpenAI API key"""
-    global client
-
-    if not api_key:
-        return "API Key cannot be empty"
-
-    OPENAI_API_KEY = api_key
-
-    try:
-        client = OpenAI(api_key=api_key)
-        # Test the connection with a simple request
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "test"}],
-            max_tokens=5,
-        )
-        return f"API Key updated and verified successfully"
-    except Exception as e:
-        error_msg = str(e)
-        logging.error(f"API key update failed: {error_msg}")
-        return f"Failed to update API Key: {error_msg}"
+    from client import initialize_client
+    return initialize_client(api_key)
 
 
-def process_file(file, text_columns, categories, classifier_type, show_explanations):
-    """Process the file and classify texts"""
+async def process_file_async(file, text_columns, categories, classifier_type, show_explanations):
+    """Async version of process_file"""
     # Initialize result_df and validation_report
     result_df = None
     validation_report = None
@@ -83,6 +64,9 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
     else:
         classifier_type = "tfidf"
 
+    # Get the client instance
+    client = get_client()
+
     # Initialize appropriate classifier
     if classifier_type == "tfidf":
         classifier = TFIDFClassifier()
@@ -95,7 +79,7 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
            )
        model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
        classifier = LLMClassifier(client=client, model=model)
-        results = classifier.classify(texts, category_list)
+        results = await classifier.classify_async(texts, category_list)
    else: # hybrid
        if client is None:
            return (
@@ -121,7 +105,7 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
            results.append(tfidf_result)
 
        if low_confidence_texts:
-            llm_results = llm_classifier.classify(
+            llm_results = await llm_classifier.classify_async(
                low_confidence_texts, category_list
            )
            for idx, llm_result in zip(low_confidence_indices, llm_results):
@@ -145,6 +129,11 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
        return None, f"Error: {str(e)}\n{error_traceback}"
 
 
+def process_file(file, text_columns, categories, classifier_type, show_explanations):
+    """Synchronous wrapper for process_file_async"""
+    return asyncio.run(process_file_async(file, text_columns, categories, classifier_type, show_explanations))
+
+
 def export_results(df, format_type):
     """Export results to a file and return the file path for download"""
     if df is None:
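One behavioral difference worth noting across the whole commit: the old ThreadPoolExecutor capped concurrency at 10 workers, while asyncio.gather starts every request at once. If provider rate limits become a problem, a semaphore restores the cap; a sketch, assuming a per-text coroutine like the classify_one above:

import asyncio

async def classify_bounded(texts, classify_one, max_concurrency=10):
    # mirror the old max_workers=10 cap that plain asyncio.gather drops
    sem = asyncio.Semaphore(max_concurrency)

    async def guarded(text):
        async with sem:
            return await classify_one(text)

    return await asyncio.gather(*(guarded(t) for t in texts))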