sandz7 committed
Commit 3f52b9b · 1 Parent(s): f2448b3

added context manager decor on llama_generation

Files changed (1)
  1. app.py +6 -8
app.py CHANGED
@@ -4,6 +4,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 from huggingface_hub import login
 import os
 import threading
+import contextlib
 import spaces
 from openai import OpenAI
 # import multiprocessing as mp
@@ -78,11 +79,9 @@ def gpt_generation(input: str,
 
     return stream
 
-# Global lock variable
-lock = threading.Lock()
-
 # Place just input pass and return generation output
 @spaces.GPU(duration=120)
+@contextlib.contextmanager
 def llama_generation(input_text: str,
                      history: list,
                      temperature: float,
@@ -115,16 +114,15 @@ def llama_generation(input_text: str,
     generate_kwargs["do_sample"] = False
 
     # Use a lock object to synchronize access to the llama_model
-    global lock
+    lock = threading.Lock()
 
-    # # Place the generation in a thread so we can access it.
-    # # place the function as target and place the kwargs next as the kwargs
-    def generation_llama(lock=lock):
+    def generate_llama():
         with lock:
-            # Generate response using Llama3
+            # Generate the response using the llama_model
             response = llama_model.generate(**generate_kwargs)
             return response
 
+
     # start the thread and wait for it to finish
     thread = threading.Thread(target=generation_llama)
     thread.start()
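
For context, a minimal sketch (not part of app.py) of the pattern this commit moves toward: a contextlib.contextmanager that serializes access to a shared model behind a threading.Lock, with generation run on a worker thread. The stub_generate function and the other helper names below are hypothetical stand-ins for the llama_model call, not code from the repository.

import contextlib
import threading

# Hypothetical stand-in for the shared Llama model's generate() call.
def stub_generate(**generate_kwargs):
    return f"generated with {generate_kwargs}"

# One lock guards the model so only one thread generates at a time.
model_lock = threading.Lock()

@contextlib.contextmanager
def exclusive_model():
    # Hold the lock for the lifetime of the with-block, then release it.
    with model_lock:
        yield stub_generate

def run_generation(results, **generate_kwargs):
    # Worker-thread body: take the lock, generate, store the result.
    with exclusive_model() as generate:
        results.append(generate(**generate_kwargs))

results = []
thread = threading.Thread(target=run_generation, args=(results,),
                          kwargs={"max_new_tokens": 128})
thread.start()
thread.join()
print(results[0])

Note that contextlib.contextmanager expects a generator function that yields exactly once, as exclusive_model does above; the lock is acquired before the yield and released afterwards.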