abancp commited on
Commit
aa1287d
·
verified ·
1 Parent(s): c78a3b6

Update inference_fine_tune.py

Browse files
Files changed (1) hide show
  1. inference_fine_tune.py +2 -2
inference_fine_tune.py CHANGED
@@ -47,7 +47,7 @@ def generate_response(prompt:str):
47
  temperature = 0.7
48
  top_k = 50
49
  i = 0
50
- while i < 1024:
51
  # Apply causal mask based on current decoder_input length
52
  # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
53
  # Get model output
@@ -63,7 +63,7 @@ def generate_response(prompt:str):
63
  decoder_input = torch.cat([decoder_input, next_token], dim=1)
64
  if decoder_input.shape[1] > config['seq_len']:
65
  decoder_input = decoder_input[:,-config['seq_len']:]
66
- if next_token.item() == eos_token_id:
67
  break
68
  print("Output : ",word)
69
  return word
 
47
  temperature = 0.7
48
  top_k = 50
49
  i = 0
50
+ while decoder_input.shape[1] < 2000:
51
  # Apply causal mask based on current decoder_input length
52
  # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
53
  # Get model output
 
63
  decoder_input = torch.cat([decoder_input, next_token], dim=1)
64
  if decoder_input.shape[1] > config['seq_len']:
65
  decoder_input = decoder_input[:,-config['seq_len']:]
66
+ if next_token.item() == eos_token_id or i >= 1024:
67
  break
68
  print("Output : ",word)
69
  return word