abancp committed · verified
Commit dd1b76c · Parent(s): aa1287d

Update inference_fine_tune.py

Files changed (1):
  1. inference_fine_tune.py +3 -1
inference_fine_tune.py CHANGED
@@ -47,6 +47,7 @@ def generate_response(prompt:str):
     temperature = 0.7
     top_k = 50
     i = 0
+    print("Output : ",end="")
     while decoder_input.shape[1] < 2000:
         # Apply causal mask based on current decoder_input length
         # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
@@ -59,11 +60,12 @@ def generate_response(prompt:str):
         next_token = torch.multinomial(probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token)
         word += tokenizer.decode([next_token.item()])
+        print(word,end="")
         i+=1
         decoder_input = torch.cat([decoder_input, next_token], dim=1)
         if decoder_input.shape[1] > config['seq_len']:
             decoder_input = decoder_input[:,-config['seq_len']:]
         if next_token.item() == eos_token_id or i >= 1024:
             break
-    print("Output : ",word)
+    print()
     return word
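
One caveat worth noting in this change: print(word, end="") emits the whole accumulated string on every iteration, so each new token re-prints everything generated so far. Below is a minimal sketch of the intended streaming loop that prints only the newly decoded piece and flushes stdout. The model callable, its (batch, seq, vocab) logits shape, the tokenizer's .decode() method, and all parameter names here are assumptions for illustration, not this repo's actual API.

# A minimal sketch of a streaming top-k decode loop, assuming a decoder-only
# `model` returning logits of shape (batch, seq, vocab) and a `tokenizer`
# with a .decode() method; both are hypothetical stand-ins.
import torch

def generate_stream(model, tokenizer, decoder_input, eos_token_id,
                    seq_len=512, max_new_tokens=1024,
                    temperature=0.7, top_k=50):
    word = ""
    print("Output : ", end="", flush=True)
    for _ in range(max_new_tokens):
        # Sample from the top-k of the last position's temperature-scaled logits.
        logits = model(decoder_input)[:, -1, :] / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k, dim=-1)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = top_k_indices.gather(-1, torch.multinomial(probs, num_samples=1))
        # Print only the newly decoded piece, not the accumulated string.
        piece = tokenizer.decode([next_token.item()])
        print(piece, end="", flush=True)
        word += piece
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        # Keep a sliding window so the context never exceeds seq_len.
        if decoder_input.shape[1] > seq_len:
            decoder_input = decoder_input[:, -seq_len:]
        if next_token.item() == eos_token_id:
            break
    print()
    return word

With flush=True the tokens appear as they are sampled rather than whenever stdout happens to flush, which is the usual reason a streaming print like this seems to "hang" until the end.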