Update inference_fine_tune.py
inference_fine_tune.py CHANGED (+2 -2)
@@ -47,7 +47,7 @@ def generate_response(prompt:str):
     temperature = 0.7
     top_k = 50
     i = 0
-    while
+    while decoder_input.shape[1] < 2000:
         # Apply causal mask based on current decoder_input length
         # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
         # Get model output
@@ -63,7 +63,7 @@ def generate_response(prompt:str):
         decoder_input = torch.cat([decoder_input, next_token], dim=1)
         if decoder_input.shape[1] > config['seq_len']:
             decoder_input = decoder_input[:,-config['seq_len']:]
-        if next_token.item() == eos_token_id:
+        if next_token.item() == eos_token_id or i >= 1024:
             break
     print("Output : ",word)
     return word
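Both hunks elide the sampling step itself (original lines 54-62), where the model's output is turned into next_token using the temperature and top_k values set above. Below is a minimal sketch of what a temperature-scaled top-k sampling step typically looks like in PyTorch; the function name sample_next_token and the (1, vocab_size) logits shape are assumptions for illustration, not taken from this repository.

import torch

def sample_next_token(logits: torch.Tensor,
                      temperature: float = 0.7,
                      top_k: int = 50) -> torch.Tensor:
    # logits: (1, vocab_size) scores for the next position (assumed shape).
    # Temperature scaling sharpens (<1) or flattens (>1) the distribution,
    # then sampling is restricted to the top_k highest-scoring tokens.
    scaled = logits / temperature
    top_values, top_indices = torch.topk(scaled, k=top_k, dim=-1)
    probs = torch.softmax(top_values, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1)  # index into the top-k set
    return top_indices.gather(-1, sampled)             # map back to vocab ids, shape (1, 1)

# Illustrative use with random logits over a 32000-token vocabulary:
logits = torch.randn(1, 32000)
next_token = sample_next_token(logits)  # shape (1, 1), ready for the torch.cat above

One caveat about the new stopping conditions: because the sliding-window truncation keeps decoder_input.shape[1] at or below config['seq_len'], the new while decoder_input.shape[1] < 2000 test can only end the loop when seq_len is at least 2000; otherwise the added i >= 1024 guard (assuming i is incremented in the elided body) is what actually bounds generation.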