Commit a4dec41 (1 parent: 4b2d9b2)
enable CUDA, CPU takes forever
app.py
CHANGED
@@ -40,16 +40,16 @@ client.add(collection_name="recipes",
 model_name = "LeoLM/leo-hessianai-13b-chat"
 
 last_messages = []
-
+@spaces.GPU
 def load_model():
     ankerbot_model = AutoModelForCausalLM.from_pretrained(
         model_name,
-        device_map="
+        device_map="cuda:0",
         torch_dtype=torch.float16,
         use_cache=True,
         offload_folder="../offload"
     )
-
+    ankerbot_model.gradient_checkpointing_enable()
     ankerbot_tokenizer = AutoTokenizer.from_pretrained(model_name,
         torch_dtype=torch.float16,
         truncation=True,
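The two functional changes in this hunk are the @spaces.GPU decorator and the explicit device_map="cuda:0", which together match the commit message: the Space otherwise falls back to CPU, where a 13B model is impractically slow. On ZeroGPU Spaces hardware, @spaces.GPU (from Hugging Face's spaces package) attaches a GPU only while the decorated function runs; on other hardware it is a no-op. A minimal sketch of the pattern, with ask as an illustrative name rather than code from this repo:

    import spaces
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_name = "LeoLM/leo-hessianai-13b-chat"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="cuda:0",        # place the weights on the first GPU
        torch_dtype=torch.float16,  # half precision halves memory vs. float32
    )

    @spaces.GPU  # hypothetical entry point: the GPU is held only while this runs
    def ask(prompt: str) -> str:
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda:0")
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=128)
        return tokenizer.decode(output[0], skip_special_tokens=True)

One caveat: gradient_checkpointing_enable() reduces memory only during backpropagation, so under no-grad inference the added call is harmless but is not what saves memory here.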
@@ -60,7 +60,7 @@ def load_model():
 
 _model_cache = None
 
-
+@spaces.GPU
 def get_model():
     global _model_cache
     if _model_cache is None:
@@ -69,7 +69,7 @@ def get_model():
         _model_cache = load_model()
     return _model_cache
 
-
+@spaces.GPU
 def generate_response(query, context, prompts, max_tokens, temperature, top_p, generator):
     system_message_support = f"""<|im_start|>system
     Rolle: Du bist der KI-Assistent für Kundenservice, der im Namen des Unternehmens und Gewürzmanufaktur Ankerkraut handelt und Antworten aus der Ich-Perspektive, basierend auf den bereitgestellten Informationen gibt.
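For reference, the German system prompt in the context lines translates to: "Role: You are the AI assistant for customer service, acting on behalf of the company and spice manufactory Ankerkraut, and giving answers in the first person based on the information provided."

These two hunks also show the lazy-singleton cache the decorator is applied to: get_model() loads the 13B model once and reuses it for every later request. A sketch of an equivalent using the standard library instead of a module-level global (an alternative, not what this commit does):

    from functools import lru_cache

    @lru_cache(maxsize=1)  # first call invokes load_model(); later calls reuse the result
    def get_model():
        return load_model()

Either form avoids re-reading roughly 26 GB of float16 weights per request; the explicit global in the diff makes the cache easy to inspect or reset by hand.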
@@ -154,7 +154,7 @@ def search_qdrant_with_context(query_text, collection_name, top_k=3):
     print("Retrieved Text ", retrieved_texts)
 
     return retrieved_texts
-
+@spaces.GPU
 def respond(
     query,
     history: list[tuple[str, str]],
@@ -186,6 +186,7 @@ def respond(
     if len(last_messages) > 5:
         last_messages.pop(0)
     last_messages.append(full_conv)
+    print(last_messages)
     return answer
 
 """
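The final hunk adds a debugging print of last_messages, the bounded history that respond() maintains by popping the oldest entry once the list grows past five items. A collections.deque with maxlen expresses the same bounded window without manual popping (an alternative sketch, not the commit's code; remember is a hypothetical helper):

    from collections import deque

    last_messages = deque(maxlen=5)  # appending to a full deque drops the oldest item

    def remember(full_conv: str) -> None:
        last_messages.append(full_conv)  # no explicit pop(0) required

list.pop(0) shifts every remaining element while deque appends are O(1) at both ends; at this size the difference is negligible, so the deque is a readability choice rather than a performance fix.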