Update inference_fine_tune.py
Browse files- inference_fine_tune.py +3 -3
inference_fine_tune.py
CHANGED
@@ -46,8 +46,8 @@ def generate_response(prompt:str):
|
|
46 |
decoder_input = decoder_input.unsqueeze(0)
|
47 |
temperature = 0.7
|
48 |
top_k = 50
|
49 |
-
|
50 |
-
while
|
51 |
# Apply causal mask based on current decoder_input length
|
52 |
# decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
|
53 |
# Get model output
|
@@ -59,7 +59,7 @@ def generate_response(prompt:str):
|
|
59 |
next_token = torch.multinomial(probs, num_samples=1)
|
60 |
next_token = top_k_indices.gather(-1, next_token)
|
61 |
word += tokenizer.decode([next_token.item()])
|
62 |
-
|
63 |
decoder_input = torch.cat([decoder_input, next_token], dim=1)
|
64 |
if decoder_input.shape[1] > config['seq_len']:
|
65 |
decoder_input = decoder_input[:,-config['seq_len']:]
|
|
|
46 |
decoder_input = decoder_input.unsqueeze(0)
|
47 |
temperature = 0.7
|
48 |
top_k = 50
|
49 |
+
i = 0
|
50 |
+
while i < 1024:
|
51 |
# Apply causal mask based on current decoder_input length
|
52 |
# decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
|
53 |
# Get model output
|
|
|
59 |
next_token = torch.multinomial(probs, num_samples=1)
|
60 |
next_token = top_k_indices.gather(-1, next_token)
|
61 |
word += tokenizer.decode([next_token.item()])
|
62 |
+
i+=1
|
63 |
decoder_input = torch.cat([decoder_input, next_token], dim=1)
|
64 |
if decoder_input.shape[1] > config['seq_len']:
|
65 |
decoder_input = decoder_input[:,-config['seq_len']:]
|