Commit 6432c3a · fix arguments
Parent(s): a0c700c
app.py CHANGED
```diff
@@ -60,7 +60,7 @@ def load_model():
     generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, attn_implementation="eager", use_flash_attention_2=False) # True for flash-attn2 else False
     return (generator, generator_mini)
 
-
+_model_cache = None
 
 @spaces.GPU
 def get_model():
```
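The only change here swaps a blank line for a module-level `_model_cache` directly above the `@spaces.GPU`-decorated `get_model()`. The body of `get_model()` is outside the hunk, but the usual ZeroGPU pattern is to load once and reuse on later calls; a minimal sketch, assuming `load_model()` returns the `(generator, generator_mini)` tuple shown above:

```python
import spaces  # ZeroGPU decorator available on Hugging Face Spaces

_model_cache = None  # populated on the first request, reused afterwards

@spaces.GPU
def get_model():
    # Hypothetical body: the diff shows only the decorator and signature.
    global _model_cache
    if _model_cache is None:
        _model_cache = load_model()  # (generator, generator_mini) per the hunk above
    return _model_cache
```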
```diff
@@ -159,7 +159,6 @@ def search_qdrant_with_context(query_text, collection_name, top_k=3):
 def respond(
     query,
     history: list[tuple[str, str]],
-    system_message,
     max_tokens,
     temperature,
     top_p,
```
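This hunk removes `system_message` from `respond()`'s parameters. `gr.ChatInterface` invokes its function with the message and history followed by one positional argument per `additional_inputs` component, so the parameter list must mirror that component list exactly; dropping the parameter only works if the matching input component was removed as well. A sketch of the wiring, with the slider ranges and defaults as placeholder assumptions:

```python
import gradio as gr

# Each additional_inputs entry maps, in order, onto a respond() parameter
# after (query, history). Ranges and defaults here are illustrative only.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(1, 2048, value=512, label="Max new tokens"),              # -> max_tokens
        gr.Slider(0.1, 2.0, value=0.7, label="Temperature"),                # -> temperature
        gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)"),  # -> top_p
    ],
)
```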
```diff
@@ -184,7 +183,7 @@ def respond(
     colleciton_name = "products"
 
     context = search_qdrant_with_context(query + " " + refined_context[0]["generated_text"].split("assistant\n").pop(), collection_name)
-    answer = generate_response(query, context, max_tokens, temperature, top_p, generator[0])
+    answer = generate_response(query, context, last_messages, max_tokens, temperature, top_p, generator[0])
     full_conv = f"Nutzer:{query};Assistent:{answer}"
     if len(last_messages) > 5:
         last_messages.pop(0)
```
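This is the fix the commit message refers to: the call site now passes `last_messages`, giving `generate_response()` the retained conversation turns (trimmed by the `pop(0)` below it). Two side notes: `colleciton_name = "products"` looks like a typo, since the call on the next line reads `collection_name`, and `generate_response()` itself is not part of the diff. A hypothetical signature matching the new call site:

```python
def generate_response(query, context, last_messages, max_tokens, temperature, top_p, generator):
    # Hypothetical: fold the stored "Nutzer:...;Assistent:..." turns into the prompt.
    history_block = "\n".join(last_messages)
    prompt = f"{history_block}\nKontext: {context}\nNutzer: {query}\nAssistent:"
    out = generator(prompt, max_new_tokens=max_tokens, temperature=temperature,
                    top_p=top_p, do_sample=True)
    # transformers text-generation pipelines return [{"generated_text": ...}]
    return out[0]["generated_text"].split("Assistent:").pop().strip()
```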
```diff
@@ -207,6 +206,7 @@ demo = gr.ChatInterface(
             label="Top-p (nucleus sampling)",
         ),
     ],
+    chatbot=gr.Chatbot(type="messages"),
 )
```
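The final hunk switches the chat window to Gradio's message-dict format. Worth flagging: with `type="messages"`, the `history` handed to `respond()` is a list of role/content dicts rather than the `list[tuple[str, str]]` its annotation still claims, so that annotation is now stale. A sketch of the shape that arrives (example content invented):

```python
# history as delivered once gr.Chatbot(type="messages") is active
history = [
    {"role": "user", "content": "Which anchor suits a 30 ft boat?"},
    {"role": "assistant", "content": "A 15 kg plow anchor is a common choice."},
]
```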