simondh committed
Commit 522275f · Parent: a2b53c6

add server

Files changed (4)
  1. classifiers.py +0 -267
  2. requirements.txt +95 -9
  3. server.py +64 -0
  4. test_server.py +43 -0
classifiers.py DELETED
@@ -1,267 +0,0 @@
- import numpy as np
- import pandas as pd
- from sklearn.feature_extraction.text import TfidfVectorizer
- from sklearn.cluster import KMeans
- from sklearn.metrics.pairwise import cosine_similarity
- import random
- import json
- from concurrent.futures import ThreadPoolExecutor, as_completed
- from typing import List, Dict, Any, Optional
- from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT
-
-
- class BaseClassifier:
-     """Base class for text classifiers"""
-
-     def __init__(self):
-         pass
-
-     def classify(self, texts, categories=None):
-         """
-         Classify a list of texts into categories
-
-         Args:
-             texts (list): List of text strings to classify
-             categories (list, optional): List of category names. If None, categories will be auto-detected
-
-         Returns:
-             list: List of classification results with categories, confidence scores, and explanations
-         """
-         raise NotImplementedError("Subclasses must implement this method")
-
-     def _generate_default_categories(self, texts, num_clusters=5):
-         """
-         Generate default categories based on text clustering
-
-         Args:
-             texts (list): List of text strings
-             num_clusters (int): Number of clusters to generate
-
-         Returns:
-             list: List of category names
-         """
-         # Simple implementation - in real system this would be more sophisticated
-         default_categories = [f"Category {i+1}" for i in range(num_clusters)]
-         return default_categories
-
-
- class TFIDFClassifier(BaseClassifier):
-     """Classifier using TF-IDF and clustering for fast classification"""
-
-     def __init__(self):
-         super().__init__()
-         self.vectorizer = TfidfVectorizer(
-             max_features=1000, stop_words="english", ngram_range=(1, 2)
-         )
-         self.model = None
-         self.feature_names = None
-         self.categories = None
-         self.centroids = None
-
-     def classify(self, texts, categories=None):
-         """Classify texts using TF-IDF and clustering"""
-         # Vectorize the texts
-         X = self.vectorizer.fit_transform(texts)
-         self.feature_names = self.vectorizer.get_feature_names_out()
-
-         # Auto-detect categories if not provided
-         if not categories:
-             num_clusters = min(5, len(texts))  # Don't create more clusters than texts
-             self.categories = self._generate_default_categories(texts, num_clusters)
-         else:
-             self.categories = categories
-             num_clusters = len(categories)
-
-         # Cluster the texts
-         self.model = KMeans(n_clusters=num_clusters, random_state=42)
-         clusters = self.model.fit_predict(X)
-         self.centroids = self.model.cluster_centers_
-
-         # Calculate distances to centroids for confidence
-         distances = self._calculate_distances(X)
-
-         # Prepare results
-         results = []
-         for i, text in enumerate(texts):
-             cluster_idx = clusters[i]
-
-             # Calculate confidence (inverse of distance, normalized)
-             confidence = self._calculate_confidence(distances[i])
-
-             # Create explanation
-             explanation = self._generate_explanation(X[i], cluster_idx)
-
-             results.append(
-                 {
-                     "category": self.categories[cluster_idx],
-                     "confidence": confidence,
-                     "explanation": explanation,
-                 }
-             )
-
-         return results
-
-     def _calculate_distances(self, X):
-         """Calculate distances from each point to each centroid"""
-         return np.sqrt(
-             (
-                 (X.toarray()[:, np.newaxis, :] - self.centroids[np.newaxis, :, :]) ** 2
-             ).sum(axis=2)
-         )
-
-     def _calculate_confidence(self, distances):
-         """Convert distances to confidence scores (0-100)"""
-         min_dist = np.min(distances)
-         max_dist = np.max(distances)
-
-         # Normalize and invert (smaller distance = higher confidence)
-         if max_dist == min_dist:
-             return 70  # Default mid-range confidence when all distances are equal
-
-         normalized_dist = (distances - min_dist) / (max_dist - min_dist)
-         min_normalized = np.min(normalized_dist)
-
-         # Invert and scale to 50-100 range (TF-IDF is never 100% confident)
-         confidence = 100 - (min_normalized * 50)
-         return round(confidence, 1)
-
-     def _generate_explanation(self, text_vector, cluster_idx):
-         """Generate an explanation for the classification"""
-         # Get the most important features for this cluster
-         centroid = self.centroids[cluster_idx]
-
-         # Get indices of top features for this text
-         text_array = text_vector.toarray()[0]
-         top_indices = text_array.argsort()[-5:][::-1]
-
-         # Get the feature names for these indices
-         top_features = [self.feature_names[i] for i in top_indices if text_array[i] > 0]
-
-         if not top_features:
-             return "No significant features identified for this classification."
-
-         explanation = f"Classification based on key terms: {', '.join(top_features)}"
-         return explanation
-
-
- class LLMClassifier(BaseClassifier):
-     """Classifier using a Large Language Model for more accurate but slower classification"""
-
-     def __init__(self, client, model="gpt-3.5-turbo"):
-         super().__init__()
-         self.client = client
-         self.model = model
-
-     def classify(
-         self, texts: List[str], categories: Optional[List[str]] = None
-     ) -> List[Dict[str, Any]]:
-         """Classify texts using an LLM with parallel processing"""
-         if not categories:
-             # First, use LLM to generate appropriate categories
-             categories = self._suggest_categories(texts)
-
-         # Process texts in parallel
-         with ThreadPoolExecutor(max_workers=10) as executor:
-             # Submit all tasks with their original indices
-             future_to_index = {
-                 executor.submit(self._classify_text, text, categories): idx
-                 for idx, text in enumerate(texts)
-             }
-
-             # Initialize results list with None values
-             results = [None] * len(texts)
-
-             # Collect results as they complete
-             for future in as_completed(future_to_index):
-                 original_idx = future_to_index[future]
-                 try:
-                     result = future.result()
-                     results[original_idx] = result
-                 except Exception as e:
-                     print(f"Error processing text: {str(e)}")
-                     results[original_idx] = {
-                         "category": categories[0],
-                         "confidence": 50,
-                         "explanation": f"Error during classification: {str(e)}",
-                     }
-
-         return results
-
-     def _suggest_categories(self, texts: List[str], sample_size: int = 20) -> List[str]:
-         """Use LLM to suggest appropriate categories for the dataset"""
-         # Take a sample of texts to avoid token limitations
-         if len(texts) > sample_size:
-             sample_texts = random.sample(texts, sample_size)
-         else:
-             sample_texts = texts
-
-         prompt = CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample_texts))
-
-         try:
-             response = self.client.chat.completions.create(
-                 model=self.model,
-                 messages=[{"role": "user", "content": prompt}],
-                 temperature=0.2,
-                 max_tokens=100,
-             )
-
-             # Parse response to get categories
-             categories_text = response.choices[0].message.content.strip()
-             categories = [cat.strip() for cat in categories_text.split(",")]
-
-             return categories
-         except Exception as e:
-             # Fallback to default categories on error
-             print(f"Error suggesting categories: {str(e)}")
-             return self._generate_default_categories(texts)
-
-     def _classify_text(self, text: str, categories: List[str]) -> Dict[str, Any]:
-         """Use LLM to classify a single text"""
-         prompt = TEXT_CLASSIFICATION_PROMPT.format(
-             categories=", ".join(categories), text=text
-         )
-
-         try:
-             response = self.client.chat.completions.create(
-                 model=self.model,
-                 messages=[{"role": "user", "content": prompt}],
-                 temperature=0,
-                 max_tokens=200,
-             )
-
-             # Parse JSON response
-             response_text = response.choices[0].message.content.strip()
-
-             result = json.loads(response_text)
-             # Ensure all required fields are present
-             if not all(k in result for k in ["category", "confidence", "explanation"]):
-                 raise ValueError("Missing required fields in LLM response")
-
-             # Validate category is in the list
-             if result["category"] not in categories:
-                 result["category"] = categories[
-                     0
-                 ]  # Default to first category if invalid
-
-             # Validate confidence is a number between 0 and 100
-             try:
-                 result["confidence"] = float(result["confidence"])
-                 if not 0 <= result["confidence"] <= 100:
-                     result["confidence"] = 50
-             except:
-                 result["confidence"] = 50
-
-             return result
-         except json.JSONDecodeError:
-             # Fall back to simple parsing if JSON fails
-             category = categories[0]  # Default
-             for cat in categories:
-                 if cat.lower() in response_text.lower():
-                     category = cat
-                     break
-
-             return {
-                 "category": category,
-                 "confidence": 50,
-                 "explanation": f"Classification based on language model analysis. (Note: Structured response parsing failed)",
-             }
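For context, the removed module exposed a synchronous classify() API on both classifiers. A minimal usage sketch of the deleted TFIDFClassifier follows; the sample texts and category names are illustrative only and not taken from the repository:

# Illustrative only: how the now-deleted classifiers.py was typically driven.
from classifiers import TFIDFClassifier  # module removed in this commit

texts = [
    "The new GPU accelerates model training dramatically.",
    "The championship match went to extra time.",
]

clf = TFIDFClassifier()
# categories is optional; when omitted, generic "Category N" labels are generated
results = clf.classify(texts, categories=["Technology", "Sports"])
for r in results:
    print(r["category"], r["confidence"], r["explanation"])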
requirements.txt CHANGED
@@ -1,9 +1,95 @@
- gradio>=4.0.0
- litellm>=1.10.0
- pandas>=2.0.0
- numpy>=1.24.0
- scikit-learn>=1.2.0
- openpyxl>=3.1.0
- torch>=2.0.0
- transformers>=4.30.0
- matplotlib>=3.7.0
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.11.16
+ aiosignal==1.3.2
+ annotated-types==0.7.0
+ anyio==4.9.0
+ attrs==25.3.0
+ audioop-lts==0.2.1
+ certifi==2025.1.31
+ charset-normalizer==3.4.1
+ click==8.1.8
+ contourpy==1.3.2
+ cycler==0.12.1
+ distro==1.9.0
+ et-xmlfile==2.0.0
+ fastapi==0.115.12
+ ffmpy==0.5.0
+ filelock==3.18.0
+ fonttools==4.57.0
+ frozenlist==1.5.0
+ fsspec==2025.3.2
+ gradio==5.25.1
+ gradio-client==1.8.0
+ groovy==0.1.2
+ h11==0.14.0
+ httpcore==1.0.8
+ httpx==0.28.1
+ huggingface-hub==0.30.2
+ idna==3.10
+ importlib-metadata==8.6.1
+ jinja2==3.1.6
+ jiter==0.9.0
+ joblib==1.4.2
+ jsonschema==4.23.0
+ jsonschema-specifications==2024.10.1
+ kiwisolver==1.4.8
+ litellm==1.66.1
+ markdown-it-py==3.0.0
+ markupsafe==3.0.2
+ matplotlib==3.10.1
+ mdurl==0.1.2
+ mpmath==1.3.0
+ multidict==6.4.3
+ networkx==3.4.2
+ numpy==2.2.4
+ openai==1.74.0
+ openpyxl==3.1.5
+ orjson==3.10.16
+ packaging==24.2
+ pandas==2.2.3
+ pillow==11.2.1
+ propcache==0.3.1
+ pydantic==2.11.3
+ pydantic-core==2.33.1
+ pydub==0.25.1
+ pygments==2.19.1
+ pyparsing==3.2.3
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ python-multipart==0.0.20
+ pytz==2025.2
+ pyyaml==6.0.2
+ referencing==0.36.2
+ regex==2024.11.6
+ requests==2.32.3
+ rich==14.0.0
+ rpds-py==0.24.0
+ ruff==0.11.5
+ safehttpx==0.1.6
+ safetensors==0.5.3
+ scikit-learn==1.6.1
+ scipy==1.15.2
+ semantic-version==2.10.0
+ setuptools==78.1.0
+ shellingham==1.5.4
+ six==1.17.0
+ sniffio==1.3.1
+ starlette==0.46.2
+ sympy==1.13.1
+ threadpoolctl==3.6.0
+ tiktoken==0.9.0
+ tokenizers==0.21.1
+ tomlkit==0.13.2
+ torch==2.6.0
+ tqdm==4.67.1
+ transformers==4.51.3
+ typer==0.15.2
+ typing-extensions==4.13.2
+ typing-inspection==0.4.0
+ tzdata==2025.2
+ urllib3==2.4.0
+ uvicorn==0.34.1
+ websockets==15.0.1
+ yarl==1.19.0
+ zipp==3.21.0
server.py ADDED
@@ -0,0 +1,64 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import List, Optional
+ import json
+ from classifiers.llm import LLMClassifier
+ from litellm import completion
+ import asyncio
+
+ app = FastAPI()
+
+ # Configure CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # In production, replace with specific origins
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Initialize the LLM classifier
+ classifier = LLMClassifier(client=completion, model="gpt-3.5-turbo")
+
+ class TextInput(BaseModel):
+     text: str
+     categories: Optional[List[str]] = None
+
+ class ClassificationResponse(BaseModel):
+     category: str
+     confidence: float
+     explanation: str
+
+ class CategorySuggestionResponse(BaseModel):
+     categories: List[str]
+
+ @app.post("/classify", response_model=ClassificationResponse)
+ async def classify_text(text_input: TextInput):
+     try:
+         # Use async classification
+         results = await classifier.classify_async(
+             [text_input.text],
+             text_input.categories
+         )
+         result = results[0]  # Get first result since we're classifying one text
+
+         return ClassificationResponse(
+             category=result["category"],
+             confidence=result["confidence"],
+             explanation=result["explanation"]
+         )
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ @app.post("/suggest-categories", response_model=CategorySuggestionResponse)
+ async def suggest_categories(texts: List[str]):
+     try:
+         categories = await classifier._suggest_categories_async(texts)
+         return CategorySuggestionResponse(categories=categories)
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=str(e))
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
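Note that server.py imports LLMClassifier from classifiers.llm and awaits classify_async() and _suggest_categories_async(), but no classifiers/llm.py is included in this commit (the deleted classifiers.py above only had synchronous methods). Below is a minimal sketch of what such a module could look like, assuming litellm's async acompletion() and the same prompt templates the old module used; every name and detail here is an assumption, not the commit's actual implementation.

# Hypothetical classifiers/llm.py -- NOT part of this commit; sketch only.
# Assumes litellm.acompletion and the prompt templates used by the old classifiers.py.
import asyncio
import json
import random
from typing import Any, Dict, List, Optional

from litellm import acompletion

from prompts import CATEGORY_SUGGESTION_PROMPT, TEXT_CLASSIFICATION_PROMPT


class LLMClassifier:
    def __init__(self, client=None, model="gpt-3.5-turbo"):
        self.client = client  # kept only for interface parity with server.py
        self.model = model

    async def classify_async(
        self, texts: List[str], categories: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        # Suggest categories first if none were supplied, then classify all texts concurrently.
        if not categories:
            categories = await self._suggest_categories_async(texts)
        return await asyncio.gather(
            *(self._classify_one(text, categories) for text in texts)
        )

    async def _suggest_categories_async(
        self, texts: List[str], sample_size: int = 20
    ) -> List[str]:
        # Sample the dataset to keep the prompt small, mirroring the old synchronous code.
        sample = random.sample(texts, sample_size) if len(texts) > sample_size else texts
        response = await acompletion(
            model=self.model,
            messages=[{
                "role": "user",
                "content": CATEGORY_SUGGESTION_PROMPT.format("\n---\n".join(sample)),
            }],
            temperature=0.2,
        )
        return [c.strip() for c in response.choices[0].message.content.split(",")]

    async def _classify_one(self, text: str, categories: List[str]) -> Dict[str, Any]:
        prompt = TEXT_CLASSIFICATION_PROMPT.format(categories=", ".join(categories), text=text)
        response = await acompletion(
            model=self.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        try:
            return json.loads(response.choices[0].message.content)
        except json.JSONDecodeError:
            # Same fallback shape the old synchronous classifier used.
            return {
                "category": categories[0],
                "confidence": 50,
                "explanation": "Structured response parsing failed.",
            }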
test_server.py ADDED
@@ -0,0 +1,43 @@
+ import requests
+ import json
+
+ BASE_URL = "http://localhost:8000"
+
+ def test_classify_text():
+     # Test with default categories
+     response = requests.post(
+         f"{BASE_URL}/classify",
+         json={"text": "This is a sample text about technology and innovation."}
+     )
+     print("Classification with default categories:")
+     print(json.dumps(response.json(), indent=2))
+
+     # Test with custom categories
+     response = requests.post(
+         f"{BASE_URL}/classify",
+         json={
+             "text": "This is a sample text about technology and innovation.",
+             "categories": ["Technology", "Business", "Science", "Sports"]
+         }
+     )
+     print("\nClassification with custom categories:")
+     print(json.dumps(response.json(), indent=2))
+
+ def test_suggest_categories():
+     texts = [
+         "This is a text about artificial intelligence and machine learning.",
+         "A new breakthrough in quantum computing has been announced.",
+         "The latest smartphone features innovative camera technology."
+     ]
+
+     response = requests.post(
+         f"{BASE_URL}/suggest-categories",
+         json=texts
+     )
+     print("\nSuggested categories:")
+     print(json.dumps(response.json(), indent=2))
+
+ if __name__ == "__main__":
+     print("Testing FastAPI server endpoints...")
+     test_classify_text()
+     test_suggest_categories()
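These are plain scripts rather than pytest tests: they assume the API from server.py is already running locally (python server.py starts uvicorn on port 8000) and simply print whatever the endpoints return. As an optional extra, a small schema check grounded in the ClassificationResponse model could look like the sketch below; the 0-100 confidence range is an assumption carried over from the deleted classifiers.

# Optional sanity check for a /classify response; assumes the server is running on localhost:8000.
import requests

resp = requests.post(
    "http://localhost:8000/classify",
    json={"text": "A short sample about technology."},
    timeout=30,
)
resp.raise_for_status()
body = resp.json()
# ClassificationResponse guarantees exactly these three fields.
assert set(body) == {"category", "confidence", "explanation"}
# 0-100 confidence matches the convention used by the previous classifiers (assumption).
assert 0 <= float(body["confidence"]) <= 100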