dar-tau committed on
Commit
8555522
·
verified ·
1 Parent(s): 5bd57b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -54,7 +54,7 @@ def get_past_key_values(system_prompt):
54
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
55
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
56
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
57
- return model(tokenized.to(model.device)).past_key_values
58
 
59
 
60
  @spaces.GPU
@@ -64,7 +64,7 @@ def generate(text, past_key_values):
64
  {'role': 'user', 'content': text}
65
  ]
66
  response = pipe(messages,
67
- past_key_values=past_key_values,
68
  **generate_kwargs)[0]['generated_text']
69
  return response[-1]['content']
70
 
 
54
  test_messages = [*start_messages, {'role': 'user', 'content': 'Hello World!'}]
55
  tokenized_test = tokenizer.apply_chat_template(test_messages, return_tensors='pt')
56
  assert (tokenized_test[:, :tokenized.shape[1]] == tokenized).all().cpu().item()
57
+ return model(tokenized.to(model.device)).past_key_values.cpu().detach()
58
 
59
 
60
  @spaces.GPU
 
64
  {'role': 'user', 'content': text}
65
  ]
66
  response = pipe(messages,
67
+ past_key_values=past_key_values.to(model.device),
68
  **generate_kwargs)[0]['generated_text']
69
  return response[-1]['content']
70