aideveloper24 commited on
Commit
cb536ff
·
verified ·
1 Parent(s): 4340cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -36
app.py CHANGED
@@ -1,42 +1,20 @@
1
  from flask import Flask, request, jsonify
 
2
  import torch
3
- from transformers import BertTokenizer, BertForSequenceClassification
4
  import os
5
 
6
  app = Flask(__name__)
7
 
8
- # Global variables to store model and tokenizer
9
- global_tokenizer = None
10
- global_model = None
11
-
12
- def load_model():
13
- """Load the model and tokenizer"""
14
- global global_tokenizer, global_model
15
- try:
16
- print("Loading model and tokenizer...")
17
- # Use a different model (bert-base-uncased)
18
- MODEL_NAME = "bert-base-uncased" # Switch to this model
19
-
20
- # Load tokenizer and model from Hugging Face Hub or a local path
21
- global_tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
22
- global_model = BertForSequenceClassification.from_pretrained(MODEL_NAME)
23
- global_model.eval()
24
-
25
- print("Model loaded successfully!")
26
- return True
27
- except Exception as e:
28
- print(f"Error loading model: {str(e)}")
29
- return False
30
-
31
- # Load model at startup
32
- model_loaded = load_model()
33
 
34
  @app.route('/', methods=['GET'])
35
  def home():
36
  """Home endpoint to check if API is running"""
37
  response = {
38
  'status': 'API is running',
39
- 'model_status': 'loaded' if model_loaded else 'not loaded',
40
  'usage': {
41
  'endpoint': '/classify',
42
  'method': 'POST',
@@ -48,16 +26,11 @@ def home():
48
  @app.route('/health', methods=['GET'])
49
  def health_check():
50
  """Health check endpoint"""
51
- if not model_loaded:
52
- return jsonify({'status': 'unhealthy', 'error': 'Model not loaded'}), 503
53
  return jsonify({'status': 'healthy'})
54
 
55
  @app.route('/classify', methods=['POST'])
56
  def classify_email():
57
  """Classify email subject"""
58
- if not model_loaded:
59
- return jsonify({'error': 'Model not loaded'}), 503
60
-
61
  try:
62
  # Get request data
63
  data = request.get_json()
@@ -71,11 +44,11 @@ def classify_email():
71
  subject = data['subject']
72
 
73
  # Tokenize
74
- inputs = global_tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
75
 
76
  # Predict
77
  with torch.no_grad():
78
- outputs = global_model(**inputs)
79
  logits = outputs.logits
80
 
81
  # Get probabilities
@@ -85,8 +58,8 @@ def classify_email():
85
 
86
  # Define custom categories (Modify this as needed)
87
  CUSTOM_LABELS = {
88
- 0: "Business/Professional",
89
- 1: "Personal/Casual"
90
  }
91
 
92
  result = {
 
1
  from flask import Flask, request, jsonify
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
  import torch
 
4
  import os
5
 
6
  app = Flask(__name__)
7
 
8
+ # Load the model and tokenizer directly
9
+ tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-uncased-finetuned-sst-2-english")
10
+ model = AutoModelForSequenceClassification.from_pretrained("distilbert/distilbert-base-uncased-finetuned-sst-2-english")
11
+ model.eval() # Set the model to evaluation mode
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  @app.route('/', methods=['GET'])
14
  def home():
15
  """Home endpoint to check if API is running"""
16
  response = {
17
  'status': 'API is running',
 
18
  'usage': {
19
  'endpoint': '/classify',
20
  'method': 'POST',
 
26
  @app.route('/health', methods=['GET'])
27
  def health_check():
28
  """Health check endpoint"""
 
 
29
  return jsonify({'status': 'healthy'})
30
 
31
  @app.route('/classify', methods=['POST'])
32
  def classify_email():
33
  """Classify email subject"""
 
 
 
34
  try:
35
  # Get request data
36
  data = request.get_json()
 
44
  subject = data['subject']
45
 
46
  # Tokenize
47
+ inputs = tokenizer(subject, return_tensors="pt", truncation=True, max_length=512)
48
 
49
  # Predict
50
  with torch.no_grad():
51
+ outputs = model(**inputs)
52
  logits = outputs.logits
53
 
54
  # Get probabilities
 
58
 
59
  # Define custom categories (Modify this as needed)
60
  CUSTOM_LABELS = {
61
+ 0: "Negative",
62
+ 1: "Positive"
63
  }
64
 
65
  result = {