Update modeling_qwen2.py
Fix https://huggingface.co/bytedance-research/ChatTS-14B/discussions/4

modeling_qwen2.py (+9 -2)
@@ -1450,6 +1450,9 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
                 attention_mask=attention_mask
             )
 
+    def _extract_past_from_model_output(self, outputs: ModelOutput):
+        return "past_key_values", outputs.past_key_values
+
     def _update_model_kwargs_for_generation(
         self,
         outputs: ModelOutput,
@@ -1505,8 +1508,12 @@ class Qwen2TSForCausalLM(Qwen2PreTrainedModel):
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = past_key_values.seen_tokens
+                max_cache_length = (
+                    past_key_values.get_max_length()
+                    if hasattr(past_key_values, "get_max_length")
+                    else past_key_values.get_max_cache_shape()
+                )
             else:
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None