aideveloper24 commited on
Commit
1a0232d
·
verified ·
1 Parent(s): 7b3bba6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -14,10 +14,14 @@ def load_model():
14
  global global_tokenizer, global_model
15
  try:
16
  print("Loading model and tokenizer...")
17
- MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"
 
 
 
18
  global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
19
  global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
20
  global_model.eval()
 
21
  print("Model loaded successfully!")
22
  return True
23
  except Exception as e:
@@ -25,14 +29,14 @@ def load_model():
25
  return False
26
 
27
  # Load model at startup
28
- load_model()
29
 
30
  @app.route('/', methods=['GET'])
31
  def home():
32
  """Home endpoint to check if API is running"""
33
  response = {
34
  'status': 'API is running',
35
- 'model_status': 'loaded' if global_model is not None else 'not loaded',
36
  'usage': {
37
  'endpoint': '/classify',
38
  'method': 'POST',
@@ -44,14 +48,14 @@ def home():
44
  @app.route('/health', methods=['GET'])
45
  def health_check():
46
  """Health check endpoint"""
47
- if global_model is None or global_tokenizer is None:
48
  return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
49
  return jsonify({'status': 'healthy'})
50
 
51
  @app.route('/classify', methods=['POST'])
52
  def classify_email():
53
  """Classify email subject"""
54
- if global_model is None or global_tokenizer is None:
55
  return jsonify({'error': 'Model not loaded'}), 503
56
 
57
  try:
@@ -101,6 +105,6 @@ def classify_email():
101
  return jsonify({'error': str(e)}), 500
102
 
103
  if __name__ == '__main__':
104
- # Use port 7860 for Hugging Face Spaces
105
  port = int(os.environ.get('PORT', 7860))
106
- app.run(host='0.0.0.0', port=port)
 
14
  global global_tokenizer, global_model
15
  try:
16
  print("Loading model and tokenizer...")
17
+ # Replace this path with your model's directory if using a custom model
18
+ MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english" # Can be a custom path if using your own model
19
+
20
+ # Load tokenizer and model from Hugging Face Hub or a local path
21
  global_tokenizer = DistilBertTokenizer.from_pretrained(MODEL_NAME)
22
  global_model = DistilBertForSequenceClassification.from_pretrained(MODEL_NAME)
23
  global_model.eval()
24
+
25
  print("Model loaded successfully!")
26
  return True
27
  except Exception as e:
 
29
  return False
30
 
31
  # Load model at startup
32
+ model_loaded = load_model()
33
 
34
  @app.route('/', methods=['GET'])
35
  def home():
36
  """Home endpoint to check if API is running"""
37
  response = {
38
  'status': 'API is running',
39
+ 'model_status': 'loaded' if model_loaded else 'not loaded',
40
  'usage': {
41
  'endpoint': '/classify',
42
  'method': 'POST',
 
48
  @app.route('/health', methods=['GET'])
49
  def health_check():
50
  """Health check endpoint"""
51
+ if not model_loaded:
52
  return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
53
  return jsonify({'status': 'healthy'})
54
 
55
  @app.route('/classify', methods=['POST'])
56
  def classify_email():
57
  """Classify email subject"""
58
+ if not model_loaded:
59
  return jsonify({'error': 'Model not loaded'}), 503
60
 
61
  try:
 
105
  return jsonify({'error': str(e)}), 500
106
 
107
  if __name__ == '__main__':
108
+ # Use port 7860 for Hugging Face Spaces or any other port for local testing
109
  port = int(os.environ.get('PORT', 7860))
110
+ app.run(host='0.0.0.0', port=port)