Update app.py
Browse files
app.py
CHANGED
@@ -14,10 +14,14 @@ def load_model():
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
-
|
|
|
|
|
|
|
18 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
19 |
global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
|
20 |
global_model.eval()
|
|
|
21 |
print("Model loaded successfully!")
|
22 |
return True
|
23 |
except Exception as e:
|
@@ -25,14 +29,14 @@ def load_model():
|
|
25 |
return False
|
26 |
|
27 |
# Load model at startup
|
28 |
-
load_model()
|
29 |
|
30 |
@app.route('/', methods=['GET'])
|
31 |
def home():
|
32 |
"""Home endpoint to check if API is running"""
|
33 |
response = {
|
34 |
'status': 'API is running',
|
35 |
-
'model_status': 'loaded' if
|
36 |
'usage': {
|
37 |
'endpoint': '/classify',
|
38 |
'method': 'POST',
|
@@ -44,14 +48,14 @@ def home():
|
|
44 |
@app.route('/health', methods=['GET'])
|
45 |
def health_check():
|
46 |
"""Health check endpoint"""
|
47 |
-
if
|
48 |
return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
|
49 |
return jsonify({'status': 'healthy'})
|
50 |
|
51 |
@app.route('/classify', methods=['POST'])
|
52 |
def classify_email():
|
53 |
"""Classify email subject"""
|
54 |
-
if
|
55 |
return jsonify({'error': 'Model not loaded'}), 503
|
56 |
|
57 |
try:
|
@@ -101,6 +105,6 @@ def classify_email():
|
|
101 |
return jsonify({'error': str(e)}), 500
|
102 |
|
103 |
if __name__ == '__main__':
|
104 |
-
# Use port 7860 for Hugging Face Spaces
|
105 |
port = int(os.environ.get('PORT', 7860))
|
106 |
-
app.run(host='0.0.0.0', port=port)
|
|
|
14 |
global global_tokenizer, global_model
|
15 |
try:
|
16 |
print("Loading model and tokenizer...")
|
17 |
+
# Replace this path with your model's directory if using a custom model
|
18 |
+
MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" # Can be a custom path if using your own model
|
19 |
+
|
20 |
+
# Load tokenizer and model from Hugging Face Hub or a local path
|
21 |
global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
|
22 |
global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
|
23 |
global_model.eval()
|
24 |
+
|
25 |
print("Model loaded successfully!")
|
26 |
return True
|
27 |
except Exception as e:
|
|
|
29 |
return False
|
30 |
|
31 |
# Load model at startup
|
32 |
+
model_loaded = load_model()
|
33 |
|
34 |
@app.route('/', methods=['GET'])
|
35 |
def home():
|
36 |
"""Home endpoint to check if API is running"""
|
37 |
response = {
|
38 |
'status': 'API is running',
|
39 |
+
'model_status': 'loaded' if model_loaded else 'not loaded',
|
40 |
'usage': {
|
41 |
'endpoint': '/classify',
|
42 |
'method': 'POST',
|
|
|
48 |
@app.route('/health', methods=['GET'])
|
49 |
def health_check():
|
50 |
"""Health check endpoint"""
|
51 |
+
if not model_loaded:
|
52 |
return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
|
53 |
return jsonify({'status': 'healthy'})
|
54 |
|
55 |
@app.route('/classify', methods=['POST'])
|
56 |
def classify_email():
|
57 |
"""Classify email subject"""
|
58 |
+
if not model_loaded:
|
59 |
return jsonify({'error': 'Model not loaded'}), 503
|
60 |
|
61 |
try:
|
|
|
105 |
return jsonify({'error': str(e)}), 500
|
106 |
|
107 |
if __name__ == '__main__':
|
108 |
+
# Use port 7860 for Hugging Face Spaces or any other port for local testing
|
109 |
port = int(os.environ.get('PORT', 7860))
|
110 |
+
app.run(host='0.0.0.0', port=port)
|