Nitin00043 committed on
Commit 1ee9cdc · verified · 1 Parent(s): ad55826

Update app.py

Files changed (1):
  1. app.py +37 -31
app.py CHANGED
@@ -4,48 +4,55 @@ import torch
 from concurrent.futures import ThreadPoolExecutor
 from threading import Lock
 
-# Global cache settings and lock for thread-safety
+# Global cache and thread lock for thread-safe caching
 CACHE_SIZE = 100
 prediction_cache = {}
 cache_lock = Lock()
 
-# Function to load models with 8-bit quantization
-def load_quantized_model(model_name):
-    try:
-        model = AutoModelForSequenceClassification.from_pretrained(model_name, load_in_8bit=True)
+def load_model(model_name):
+    """
+    Loads the model with 8-bit quantization if a GPU is available.
+    On CPU, it loads the full model.
+    """
+    if torch.cuda.is_available():
+        # Use 8-bit quantization and auto device mapping for GPU inference.
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name, load_in_8bit=True, device_map="auto"
+        )
         tokenizer = AutoTokenizer.from_pretrained(model_name)
-        device = 0 if torch.cuda.is_available() else -1
-        pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)
-        print(f"Loaded model: {model_name}")
-        return pipe
-    except Exception as e:
-        print(f"Error loading model '{model_name}': {e}")
-        raise e
+        device = 0  # GPU index
+    else:
+        # CPU fallback: do not use quantization.
+        model = AutoModelForSequenceClassification.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        device = -1
+
+    return pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)
+
+# Load both models concurrently at startup.
 
-# Load both models concurrently at startup
 with ThreadPoolExecutor() as executor:
-    sentiment_future = executor.submit(load_quantized_model, "cardiffnlp/twitter-roberta-base-sentiment")
-    emotion_future = executor.submit(load_quantized_model, "bhadresh-savani/bert-base-uncased-emotion")
+    sentiment_future = executor.submit(load_model, "cardiffnlp/twitter-roberta-base-sentiment")
+    emotion_future = executor.submit(load_model, "bhadresh-savani/bert-base-uncased-emotion")
 
 sentiment_pipeline = sentiment_future.result()
 emotion_pipeline = emotion_future.result()
 
 def analyze_text(text):
-    # Check cache first (using lock for thread-safety)
+    # Check cache first (thread-safe)
     with cache_lock:
         if text in prediction_cache:
             return prediction_cache[text]
 
     try:
-        # Execute both model inferences in parallel
+        # Run both model inferences in parallel.
         with ThreadPoolExecutor() as executor:
-            sentiment_future = executor.submit(sentiment_pipeline, text)
-            emotion_future = executor.submit(emotion_pipeline, text)
-
-            sentiment_result = sentiment_future.result()[0]
-            emotion_result = emotion_future.result()[0]
+            future_sentiment = executor.submit(sentiment_pipeline, text)
+            future_emotion = executor.submit(emotion_pipeline, text)
+            sentiment_result = future_sentiment.result()[0]
+            emotion_result = future_emotion.result()[0]
 
-        # Prepare a clear, rounded output
+        # Format the output with rounded scores.
         result = {
             "Sentiment": {sentiment_result['label']: round(sentiment_result['score'], 4)},
             "Emotion": {emotion_result['label']: round(emotion_result['score'], 4)}
@@ -53,7 +60,7 @@ def analyze_text(text):
     except Exception as e:
         result = {"error": str(e)}
 
-    # Update cache with lock protection
+    # Update cache with lock protection.
     with cache_lock:
         if len(prediction_cache) >= CACHE_SIZE:
             prediction_cache.pop(next(iter(prediction_cache)))
@@ -61,15 +68,13 @@ def analyze_text(text):
 
     return result
 
-# Gradio interface: using gr.JSON to display structured output
-
-
-demo = gr.Interface(
+# Define the Gradio interface.
+demo = gr.Interface(
     fn=analyze_text,
     inputs=gr.Textbox(placeholder="Enter your text here...", label="Input Text"),
     outputs=gr.JSON(label="Analysis Results"),
     title="🚀 Fast Sentiment & Emotion Analysis",
-    description="An optimized application using 8-bit quantized models and parallel processing for fast inference.",
+    description="An optimized application using quantized models (when available) and parallel processing for fast inference.",
     examples=[
         ["I'm thrilled to start this new adventure!"],
         ["This situation is making me really frustrated."],
@@ -79,8 +84,9 @@ demo = gr.Interface(
     allow_flagging="never"
 )
 
-# Warm up the models with a sample input to reduce first-call latency
+# Warm up the models to reduce first-call latency.
 _ = analyze_text("Warming up models...")
 
 if __name__ == "__main__":
-    demo.launch()
+    # In Spaces, binding to 0.0.0.0 is required.
+    demo.launch(server_name="0.0.0.0", server_port=7860)
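For context (not part of the commit): a minimal sketch of how the updated endpoint could be exercised once this app.py is running, assuming the gradio_client package is installed and the app is reachable at a local URL; the URL and example text below are assumptions, and gr.Interface exposes its function under Gradio's default "/predict" route.

# Hypothetical client-side check; adjust the URL to your deployment.
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")
# Sends one text through analyze_text and prints the JSON-style result,
# e.g. {"Sentiment": {...}, "Emotion": {...}}.
result = client.predict("I'm thrilled to start this new adventure!", api_name="/predict")
print(result)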