simondh committed
Commit 36183d4 · 1 Parent(s): 720c911

classify async

Files changed (4):
  1. app.py +37 -17
  2. classifiers/llm.py +62 -74
  3. client.py +37 -0
  4. process.py +16 -27
app.py CHANGED
@@ -1,7 +1,7 @@
 import os
 import gradio as gr
+import asyncio
 
-from litellm import OpenAI
 import json
 from sklearn.cluster import KMeans
 from sklearn.decomposition import PCA
@@ -9,7 +9,9 @@ import matplotlib.pyplot as plt
 
 import logging
 from dotenv import load_dotenv
-from process import update_api_key, process_file, export_results
+from process import update_api_key, process_file_async, export_results
+from client import get_client, initialize_client
+
 # Load environment variables from .env file
 load_dotenv()
 
@@ -30,16 +32,13 @@ logging.basicConfig(
 # Initialize API key from environment variable
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
 
-# Only initialize client if API key is available
-client = None
+# Initialize client if API key is available
 if OPENAI_API_KEY:
-    try:
-        client = OpenAI(api_key=OPENAI_API_KEY)
+    success, message = initialize_client(OPENAI_API_KEY)
+    if success:
         logging.info("OpenAI client initialized successfully")
-    except Exception as e:
-        logging.error(f"Failed to initialize OpenAI client: {str(e)}")
-
-
+    else:
+        logging.error(f"Failed to initialize OpenAI client: {message}")
 
 # Create Gradio interface
 with gr.Blocks(title="Text Classification System") as demo:
@@ -57,9 +56,8 @@ with gr.Blocks(title="Text Classification System") as demo:
         api_key_message = gr.Textbox(label="Status", interactive=False)
 
         # Display current API status
-        api_status = (
-            "API Key is set" if OPENAI_API_KEY else "No API Key found. Please set one."
-        )
+        client = get_client()
+        api_status = "API Key is set" if client else "No API Key found. Please set one."
         gr.Markdown(f"**Current API Status**: {api_status}")
 
         api_key_button.click(
@@ -344,7 +342,7 @@ with gr.Blocks(title="Text Classification System") as demo:
         return gr.File(value=file_path, visible=True)
 
     # Function to improve classification based on validation report
-    def improve_classification(
+    async def improve_classification_async(
        df,
        validation_report,
        text_columns,
@@ -353,7 +351,7 @@ with gr.Blocks(title="Text Classification System") as demo:
        show_explanations,
        file,
    ):
-        """Improve classification based on validation report"""
+        """Async version of improve_classification"""
        if df is None or not validation_report:
            return (
                df,
@@ -420,7 +418,7 @@ with gr.Blocks(title="Text Classification System") as demo:
        categories = ",".join(all_categories)
 
        # Process with improved parameters
-        improved_df, new_validation = process_file(
+        improved_df, new_validation = await process_file_async(
            file,
            text_columns,
            categories,
@@ -466,6 +464,28 @@ with gr.Blocks(title="Text Classification System") as demo:
            ),
        )
 
+    def improve_classification(
+        df,
+        validation_report,
+        text_columns,
+        categories,
+        classifier_type,
+        show_explanations,
+        file,
+    ):
+        """Synchronous wrapper for improve_classification_async"""
+        return asyncio.run(
+            improve_classification_async(
+                df,
+                validation_report,
+                text_columns,
+                categories,
+                classifier_type,
+                show_explanations,
+                file,
+            )
+        )
+
     # Connect functions
     load_categories_button.click(
        load_file_and_suggest_categories,
@@ -506,7 +526,7 @@ with gr.Blocks(title="Text Classification System") as demo:
     process_button.click(
        lambda: gr.Dataframe(visible=True), inputs=[], outputs=[results_df]
    ).then(
-        process_file,
+        process_file_async,
        inputs=[
            file_input,
            text_column,
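
The app.py side of the change leans on Gradio's support for coroutine event handlers: `process_file_async` is wired straight into `.then()`, while `improve_classification` keeps a synchronous entry point by handing the coroutine to `asyncio.run`. A minimal sketch of the async-handler pattern, with illustrative component and function names (not the ones in app.py):

import asyncio
import gradio as gr

async def slow_echo(text):
    # Stands in for the awaited LLM calls inside process_file_async
    await asyncio.sleep(1)
    return f"processed: {text}"

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    run = gr.Button("Run")
    # Gradio accepts a coroutine function directly; no asyncio.run needed here
    run.click(slow_echo, inputs=[box], outputs=[out])

if __name__ == "__main__":
    demo.launch()

The `asyncio.run` wrapper is only needed where the caller is plain synchronous code; it starts its own event loop, so it should not be invoked from inside an already running loop.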
classifiers/llm.py CHANGED
@@ -1,4 +1,3 @@
-
 import numpy as np
 import pandas as pd
 from sklearn.feature_extraction.text import TfidfVectorizer
@@ -6,14 +5,13 @@ from sklearn.cluster import KMeans
 from sklearn.metrics.pairwise import cosine_similarity
 import random
 import json
-from concurrent.futures import ThreadPoolExecutor, as_completed
+import asyncio
 from typing import List, Dict, Any, Optional
 from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
 
 from .base import BaseClassifier
 
 
-
 class LLMClassifier(BaseClassifier):
     """Classifier using a Large Language Model for more accurate but slower classification"""
 
@@ -22,77 +20,15 @@ class LLMClassifier(BaseClassifier):
         self.client = client
         self.model = model
 
-    def classify(
-        self, texts: List[str], categories: Optional[List[str]] = None
-    ) -> List[Dict[str, Any]]:
-        """Classify texts using an LLM with parallel processing"""
-        if not categories:
-            # First, use LLM to generate appropriate categories
-            categories = self._suggest_categories(texts)
-
-        # Process texts in parallel
-        with ThreadPoolExecutor(max_workers=10) as executor:
-            # Submit all tasks with their original indices
-            future_to_index = {
-                executor.submit(self._classify_text, text, categories): idx
-                for idx, text in enumerate(texts)
-            }
-
-            # Initialize results list with None values
-            results = [None] * len(texts)
-
-            # Collect results as they complete
-            for future in as_completed(future_to_index):
-                original_idx = future_to_index[future]
-                try:
-                    result = future.result()
-                    results[original_idx] = result
-                except Exception as e:
-                    print(f"Error processing text: {str(e)}")
-                    results[original_idx] = {
-                        "category": categories[0],
-                        "confidence": 50,
-                        "explanation": f"Error during classification: {str(e)}",
-                    }
-
-        return results
-
-    def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
-        """Use LLM to suggest appropriate categories for the dataset"""
-        # Take a sample of texts to avoid token limitations
-        if len(texts) > sample_size:
-            sample_texts = random.sample(texts, sample_size)
-        else:
-            sample_texts = texts
-
-        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
-
-        try:
-            response = self.client.chat.completions.create(
-                model=self.model,
-                messages=[{"role": "user", "content": prompt}],
-                temperature=0.2,
-                max_tokens=100,
-            )
-
-            # Parse response to get categories
-            categories_text = response.choices[0].message.content.strip()
-            categories = [cat.strip() for cat in categories_text.split(",")]
-
-            return categories
-        except Exception as e:
-            # Fallback to default categories on error
-            print(f"Error suggesting categories: {str(e)}")
-            return self._generate_default_categories(texts)
-
-    def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
-        """Use LLM to classify a single text"""
+    async def _classify_text_async(self, text: str, categories: List[str]) -> Dict[str, Any]:
+        """Async version of text classification"""
         prompt = TEXT_CLASSIFICATION_PROMPT.format(
-            categories=", ".join(categories), text=text
+            categories=", ".join(categories),
+            text=text
         )
 
         try:
-            response = self.client.chat.completions.create(
+            response = await self.client.chat.completions.create(
                 model=self.model,
                 messages=[{"role": "user", "content": prompt}],
                 temperature=0,
@@ -101,17 +37,15 @@ class LLMClassifier(BaseClassifier):
 
             # Parse JSON response
             response_text = response.choices[0].message.content.strip()
-
             result = json.loads(response_text)
+
             # Ensure all required fields are present
             if not all(k in result for k in ["category", "confidence", "explanation"]):
                 raise ValueError("Missing required fields in LLM response")
 
             # Validate category is in the list
             if result["category"] not in categories:
-                result["category"] = categories[
-                    0
-                ] # Default to first category if invalid
+                result["category"] = categories[0] # Default to first category if invalid
 
             # Validate confidence is a number between 0 and 100
             try:
@@ -135,3 +69,57 @@ class LLMClassifier(BaseClassifier):
                 "confidence": 50,
                 "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)",
             }
+        except Exception as e:
+            return {
+                "category": categories[0],
+                "confidence": 50,
+                "explanation": f"Error during classification: {str(e)}",
+            }
+
+    async def _suggest_categories_async(self, texts: List[str], sample_size: int = 20) -> List[str]:
+        """Async version of category suggestion"""
+        # Take a sample of texts to avoid token limitations
+        if len(texts) > sample_size:
+            sample_texts = random.sample(texts, sample_size)
+        else:
+            sample_texts = texts
+
+        prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=self.model,
+                messages=[{"role": "user", "content": prompt}],
+                temperature=0.2,
+                max_tokens=100,
+            )
+
+            # Parse response to get categories
+            categories_text = response.choices[0].message.content.strip()
+            categories = [cat.strip() for cat in categories_text.split(",")]
+
+            return categories
+        except Exception as e:
+            # Fallback to default categories on error
+            print(f"Error suggesting categories: {str(e)}")
+            return self._generate_default_categories(texts)
+
+    async def classify_async(
+        self, texts: List[str], categories: Optional[List[str]] = None
+    ) -> List[Dict[str, Any]]:
+        """Async method to classify texts"""
+        if not categories:
+            categories = await self._suggest_categories_async(texts)
+
+        # Create tasks for all texts
+        tasks = [self._classify_text_async(text, categories) for text in texts]
+
+        # Gather all results
+        results = await asyncio.gather(*tasks)
+        return results
+
+    def classify(
+        self, texts: List[str], categories: Optional[List[str]] = None
+    ) -> List[Dict[str, Any]]:
+        """Synchronous wrapper for backwards compatibility"""
+        return asyncio.run(self.classify_async(texts, categories))
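
The classifier swaps the `ThreadPoolExecutor`/`as_completed` fan-out for `asyncio.gather`, which also removes the index bookkeeping because `gather` returns results in the order the coroutines were passed in; awaiting `self.client.chat.completions.create(...)` assumes the injected client exposes awaitable methods. A rough, self-contained sketch of the same fan-out, with a stubbed-in call standing in for the real LLM client:

import asyncio
from typing import Any, Dict, List

async def classify_one(text: str, categories: List[str]) -> Dict[str, Any]:
    # Stub for one awaited LLM call; real code would await the client here
    try:
        await asyncio.sleep(0.05)  # simulated network latency
        return {"category": categories[0], "confidence": 90, "explanation": text}
    except Exception as e:
        # Per-task fallback, mirroring the error branch in _classify_text_async
        return {"category": categories[0], "confidence": 50, "explanation": str(e)}

async def classify_all(texts: List[str], categories: List[str]) -> List[Dict[str, Any]]:
    tasks = [classify_one(t, categories) for t in texts]
    return await asyncio.gather(*tasks)  # results come back in input order

print(asyncio.run(classify_all(["a", "b", "c"], ["news", "spam"])))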
client.py ADDED
@@ -0,0 +1,37 @@
+from litellm import OpenAI
+import os
+from dotenv import load_dotenv
+
+# Load environment variables
+load_dotenv()
+
+# Initialize client as None
+client = None
+
+def get_client():
+    """Get the OpenAI client instance"""
+    global client
+    return client
+
+def initialize_client(api_key=None):
+    """Initialize the OpenAI client with an API key"""
+    global client
+
+    # Use provided API key or get from environment
+    api_key = api_key or os.environ.get("OPENAI_API_KEY")
+
+    if not api_key:
+        return False, "No API key provided"
+
+    try:
+        client = OpenAI(api_key=api_key)
+        # Test the connection with a simple request
+        response = client.chat.completions.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "test"}],
+            max_tokens=5,
+        )
+        return True, "API Key updated and verified successfully"
+    except Exception as e:
+        client = None
+        return False, f"Failed to initialize client: {str(e)}"
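
client.py centralizes the client in a module-level singleton: `initialize_client` builds and smoke-tests it, and `get_client` hands it out to app.py and process.py. A hypothetical caller, assuming only what the diff shows (the `(bool, str)` return tuple and the `LLMClassifier(client=..., model=...)` constructor):

from client import initialize_client, get_client
from classifiers import LLMClassifier

# Falls back to OPENAI_API_KEY from the environment when no key is passed
ok, message = initialize_client()
print(message)

client = get_client()
if client is not None:
    classifier = LLMClassifier(client=client, model="gpt-3.5-turbo")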
process.py CHANGED
@@ -1,41 +1,22 @@
-
-
 import logging
 import time
 import traceback
+import asyncio
 from sklearn.feature_extraction.text import TfidfVectorizer
 
-from litellm import OpenAI
 from classifiers import TFIDFClassifier, LLMClassifier
 from utils import load_data, validate_results
+from client import get_client
 
 
 def update_api_key(api_key):
     """Update the OpenAI API key"""
-    global OPENAI_API_KEY, client
-
-    if not api_key:
-        return "API Key cannot be empty"
-
-    OPENAI_API_KEY = api_key
-
-    try:
-        client = OpenAI(api_key=api_key)
-        # Test the connection with a simple request
-        response = client.chat.completions.create(
-            model="gpt-3.5-turbo",
-            messages=[{"role": "user", "content": "test"}],
-            max_tokens=5,
-        )
-        return f"API Key updated and verified successfully"
-    except Exception as e:
-        error_msg = str(e)
-        logging.error(f"API key update failed: {error_msg}")
-        return f"Failed to update API Key: {error_msg}"
+    from client import initialize_client
+    return initialize_client(api_key)
 
 
-def process_file(file, text_columns, categories, classifier_type, show_explanations):
-    """Process the uploaded file and classify text data"""
+async def process_file_async(file, text_columns, categories, classifier_type, show_explanations):
+    """Async version of process_file"""
     # Initialize result_df and validation_report
     result_df = None
     validation_report = None
@@ -83,6 +64,9 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
     else:
         classifier_type = "tfidf"
 
+    # Get the client instance
+    client = get_client()
+
     # Initialize appropriate classifier
     if classifier_type == "tfidf":
         classifier = TFIDFClassifier()
@@ -95,7 +79,7 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
         )
         model = "gpt-3.5-turbo" if classifier_type == "gpt35" else "gpt-4"
         classifier = LLMClassifier(client=client, model=model)
-        results = classifier.classify(texts, category_list)
+        results = await classifier.classify_async(texts, category_list)
     else: # hybrid
         if client is None:
             return (
@@ -121,7 +105,7 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
             results.append(tfidf_result)
 
         if low_confidence_texts:
-            llm_results = llm_classifier.classify(
+            llm_results = await llm_classifier.classify_async(
                 low_confidence_texts, category_list
             )
             for idx, llm_result in zip(low_confidence_indices, llm_results):
@@ -145,6 +129,11 @@ def process_file(file, text_columns, categories, classifier_type, show_explanations):
         return None, f"Error: {str(e)}\n{error_traceback}"
 
 
+def process_file(file, text_columns, categories, classifier_type, show_explanations):
+    """Synchronous wrapper for process_file_async"""
+    return asyncio.run(process_file_async(file, text_columns, categories, classifier_type, show_explanations))
+
+
 def export_results(df, format_type):
     """Export results to a file and return the file path for download"""
     if df is None:
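
process.py ends up with two entry points: the coroutine `process_file_async` and a thin `process_file` wrapper built on `asyncio.run` for callers that are not already inside an event loop. A sketch of both call styles; the file path, column, and category values below are made up for the example, and how `file` is loaded ultimately depends on `load_data`:

import asyncio
from process import process_file, process_file_async

# Plain synchronous call: the wrapper spins up its own event loop
df, report = process_file("reviews.csv", ["text"], "news,spam,other", "gpt35", True)

# From async code, await the coroutine directly instead of using the wrapper
async def main():
    return await process_file_async("reviews.csv", ["text"], "news,spam,other", "gpt35", True)

df, report = asyncio.run(main())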