Ankerkraut committed on
Commit
a4dec41
·
1 Parent(s): 4b2d9b2

enable cuda, cpu takes forever

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -40,16 +40,16 @@ client.add(collection_name="recipes",
40
  model_name = "LeoLM/leo-hessianai-13b-chat"
41
 
42
  last_messages = []
43
- #@spaces.GPU
44
  def load_model():
45
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
46
  model_name,
47
- device_map="cpu",
48
  torch_dtype=torch.float16,
49
  use_cache=True,
50
  offload_folder="../offload"
51
  )
52
-
53
  ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name,
54
  torch_dtype=torch.float16,
55
  truncation=True,
@@ -60,7 +60,7 @@ def load_model():
60
 
61
  _model_cache = None
62
 
63
- #@spaces.GPU
64
  def get_model():
65
  global _model_cache
66
  if _model_cache is None:
@@ -69,7 +69,7 @@ def get_model():
69
  _model_cache = load_model()
70
  return _model_cache
71
 
72
- #@spaces.GPU
73
  def generate_response(query, context, prompts, max_tokens, temperature, top_p, generator):
74
  system_message_support = f"""<|im_start|>system
75
  Rolle: Du bist der KI-Assistent für Kundenservice, der im Namen des Unternehmens und Gewürzmanufaktur Ankerkraut handelt und Antworten aus der Ich-Perspektive, basierend auf den bereitgestellten Informationen gibt.
@@ -154,7 +154,7 @@ def search_qdrant_with_context(query_text, collection_name, top_k=3):
154
  print("Retrieved Text ", retrieved_texts)
155
 
156
  return retrieved_texts
157
- #@spaces.GPU
158
  def respond(
159
  query,
160
  history: list[tuple[str, str]],
@@ -186,6 +186,7 @@ def respond(
186
  if len(last_messages) > 5:
187
  last_messages.pop(0)
188
  last_messages.append(full_conv)
 
189
  return answer
190
 
191
  """
 
40
  model_name = "LeoLM/leo-hessianai-13b-chat"
41
 
42
  last_messages = []
43
+ @spaces.GPU
44
  def load_model():
45
  ankerbot_model = AutoModelForCausalLM.from_pretrained(
46
  model_name,
47
+ device_map="cuda:0",
48
  torch_dtype=torch.float16,
49
  use_cache=True,
50
  offload_folder="../offload"
51
  )
52
+ ankerbot_model.gradient_checkpointing_enable()
53
  ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name,
54
  torch_dtype=torch.float16,
55
  truncation=True,
 
60
 
61
  _model_cache = None
62
 
63
+ @spaces.GPU
64
  def get_model():
65
  global _model_cache
66
  if _model_cache is None:
 
69
  _model_cache = load_model()
70
  return _model_cache
71
 
72
+ @spaces.GPU
73
  def generate_response(query, context, prompts, max_tokens, temperature, top_p, generator):
74
  system_message_support = f"""<|im_start|>system
75
  Rolle: Du bist der KI-Assistent für Kundenservice, der im Namen des Unternehmens und Gewürzmanufaktur Ankerkraut handelt und Antworten aus der Ich-Perspektive, basierend auf den bereitgestellten Informationen gibt.
 
154
  print("Retrieved Text ", retrieved_texts)
155
 
156
  return retrieved_texts
157
+ @spaces.GPU
158
  def respond(
159
  query,
160
  history: list[tuple[str, str]],
 
186
  if len(last_messages) > 5:
187
  last_messages.pop(0)
188
  last_messages.append(full_conv)
189
+ print(last_messages)
190
  return answer
191
 
192
  """