Update inference_fine_tune.py
Browse files- inference_fine_tune.py +3 -1
inference_fine_tune.py
CHANGED
@@ -47,6 +47,7 @@ def generate_response(prompt:str):
|
|
47 |
temperature = 0.7
|
48 |
top_k = 50
|
49 |
i = 0
|
|
|
50 |
while decoder_input.shape[1] < 2000:
|
51 |
# Apply causal mask based on current decoder_input length
|
52 |
# decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
|
@@ -59,11 +60,12 @@ def generate_response(prompt:str):
|
|
59 |
next_token = torch.multinomial(probs, num_samples=1)
|
60 |
next_token = top_k_indices.gather(-1, next_token)
|
61 |
word += tokenizer.decode([next_token.item()])
|
|
|
62 |
i+=1
|
63 |
decoder_input = torch.cat([decoder_input, next_token], dim=1)
|
64 |
if decoder_input.shape[1] > config['seq_len']:
|
65 |
decoder_input = decoder_input[:,-config['seq_len']:]
|
66 |
if next_token.item() == eos_token_id or i >= 1024:
|
67 |
break
|
68 |
-
print(
|
69 |
return word
|
|
|
47 |
temperature = 0.7
|
48 |
top_k = 50
|
49 |
i = 0
|
50 |
+
print("Output : ",end="")
|
51 |
while decoder_input.shape[1] < 2000:
|
52 |
# Apply causal mask based on current decoder_input length
|
53 |
# decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
|
|
|
60 |
next_token = torch.multinomial(probs, num_samples=1)
|
61 |
next_token = top_k_indices.gather(-1, next_token)
|
62 |
word += tokenizer.decode([next_token.item()])
|
63 |
+
print(word,end="")
|
64 |
i+=1
|
65 |
decoder_input = torch.cat([decoder_input, next_token], dim=1)
|
66 |
if decoder_input.shape[1] > config['seq_len']:
|
67 |
decoder_input = decoder_input[:,-config['seq_len']:]
|
68 |
if next_token.item() == eos_token_id or i >= 1024:
|
69 |
break
|
70 |
+
print()
|
71 |
return word
|