abancp committed · Commit c78a3b6 · verified · 1 Parent(s): ad5e460

Update inference_fine_tune.py

Files changed (1):
  1. inference_fine_tune.py +3 -3
inference_fine_tune.py CHANGED
@@ -46,8 +46,8 @@ def generate_response(prompt:str):
     decoder_input = decoder_input.unsqueeze(0)
     temperature = 0.7
     top_k = 50
-
-    while decoder_input.shape[1] < 2000 :
+    i = 0
+    while i < 1024:
         # Apply causal mask based on current decoder_input length
         # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
         # Get model output
@@ -59,7 +59,7 @@ def generate_response(prompt:str):
         next_token = torch.multinomial(probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token)
         word += tokenizer.decode([next_token.item()])
-
+        i+=1
         decoder_input = torch.cat([decoder_input, next_token], dim=1)
         if decoder_input.shape[1] > config['seq_len']:
             decoder_input = decoder_input[:,-config['seq_len']:]
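
Why the new guard matters (my reading of the diff; the commit message states no rationale): the truncation at the bottom of the loop caps decoder_input.shape[1] at config['seq_len'], so whenever seq_len < 2000 the old condition `decoder_input.shape[1] < 2000` could never become false and generation would never stop. Counting emitted tokens with `i` bounds the loop regardless of the window size. A minimal runnable sketch of the length cap, with a made-up seq_len standing in for config['seq_len']:

import torch

seq_len = 512                                        # assumption: any value < 2000 exhibits the old bug
decoder_input = torch.zeros(1, 1, dtype=torch.long)  # stand-in for the tokenized prompt

for _ in range(3 * seq_len):                         # take more steps than the window holds
    next_token = torch.randint(0, 100, (1, 1))       # stand-in for a sampled token id
    decoder_input = torch.cat([decoder_input, next_token], dim=1)
    if decoder_input.shape[1] > seq_len:
        decoder_input = decoder_input[:, -seq_len:]  # sliding window caps the length

print(decoder_input.shape[1])  # 512: pinned at seq_len, so `< 2000` would stay true forever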
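
The unchanged context lines also show the two-step top-k sampling pattern the loop relies on: torch.multinomial draws a position within the k filtered candidates, and gather maps that position back to a real vocabulary id. A self-contained sketch of the same pattern (the helper name, vocabulary size, and default values here are placeholders, not taken from the repo):

import torch

def sample_top_k(logits: torch.Tensor, temperature: float = 0.7, top_k: int = 50) -> torch.Tensor:
    # Keep only the k most likely logits, scaled by temperature.
    top_k_logits, top_k_indices = torch.topk(logits / temperature, k=top_k, dim=-1)
    probs = torch.softmax(top_k_logits, dim=-1)
    # multinomial samples an index into the top-k subset (0..k-1)...
    pos = torch.multinomial(probs, num_samples=1)
    # ...so gather is needed to translate it back to a vocabulary id.
    return top_k_indices.gather(-1, pos)

logits = torch.randn(1, 32000)  # placeholder for one step of decoder output
print(sample_top_k(logits))     # a (1, 1) tensor holding one sampled token id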