Ankerkraut committed
Commit 6432c3a · Parent: a0c700c

fix arguments

Files changed (1): app.py (+3, -3)
app.py CHANGED
@@ -60,7 +60,7 @@ def load_model():
     generator_mini = pipeline(task="text-generation", model=ankerbot_model, tokenizer=ankerbot_tokenizer, torch_dtype=torch.float16, attn_implementation="eager", use_flash_attention_2=False) # True for flash-attn2 else False
     return (generator, generator_mini)
 
-model_cache = None
+_model_cache = None
 
 @spaces.GPU
 def get_model():
@@ -159,7 +159,6 @@ def search_qdrant_with_context(query_text, collection_name, top_k=3):
 def respond(
     query,
     history: list[tuple[str, str]],
-    system_message,
     max_tokens,
     temperature,
     top_p,
@@ -184,7 +183,7 @@ def respond(
         colleciton_name = "products"
 
     context = search_qdrant_with_context(query + " " + refined_context[0]["generated_text"].split("assistant\n").pop(), collection_name)
-    answer = generate_response(query, context, max_tokens, temperature, top_p, generator[0])
+    answer = generate_response(query, context, last_messages, max_tokens, temperature, top_p, generator[0])
     full_conv = f"Nutzer:{query};Assistent:{answer}"
     if len(last_messages) > 5:
         last_messages.pop(0)
@@ -207,6 +206,7 @@ demo = gr.ChatInterface(
             label="Top-p (nucleus sampling)",
        ),
    ],
+    chatbot=gr.Chatbot(type="messages"),
 )
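The first hunk renames the module-level cache from model_cache to _model_cache, marking it as private to the module. The commit shows only the declaration and the get_model() signature; below is a minimal sketch of the lazy-load pattern they imply (the body of get_model() is an assumption, not part of this diff):

import spaces  # HF Spaces package; @spaces.GPU runs the decorated call on a GPU worker

_model_cache = None  # module-level singleton: build the two pipelines once per process

@spaces.GPU
def get_model():
    global _model_cache
    if _model_cache is None:
        # load_model() is the Space's own loader shown above; it returns
        # (generator, generator_mini)
        _model_cache = load_model()
    return _model_cache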
 
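The other three hunks are a single argument fix. gr.ChatInterface invokes its function as fn(message, history, *additional_inputs), so the parameter list of respond() must mirror the additional inputs in order; the leftover system_message parameter had no matching input and presumably shifted every slider value by one position. The commit removes it, threads last_messages into generate_response() explicitly, and adds chatbot=gr.Chatbot(type="messages") to switch the history to openai-style role/content dicts. A self-contained sketch of that call contract (the respond body and slider ranges here are illustrative stand-ins, not the Space's values):

import gradio as gr

def respond(query, history, max_tokens, temperature, top_p):
    # Stand-in body: echo the sampling parameters that Gradio passed in
    # positionally from the sliders declared below.
    return f"[max_tokens={max_tokens}, temperature={temperature}, top_p={top_p}] {query}"

demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
    chatbot=gr.Chatbot(type="messages"),  # history arrives as {"role", "content"} dicts
)

if __name__ == "__main__":
    demo.launch()

With type="messages", the history argument is a list of role/content dicts rather than (user, bot) tuples, so the list[tuple[str, str]] annotation on respond() is now only documentation.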