Cannot generate with BS > 1

#25
by chenjiel - opened

Hi Llama 4 developers, how can we load Llama 4 Scout with HF transformers and generate with batch size (BS) > 1 inputs?

For example, the following test code hits a shape error:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)

input_texts = ["What's the age of the earth?", "What's the age of the sun?"]
# Tokenize both prompts as one padded batch (batch size 2).
inputs = tokenizer(
    input_texts,
    return_tensors="pt",
    padding=True,
    truncation=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Greedy decoding; this call raises the shape error below.
outputs = model.generate(**inputs.to(model.device), do_sample=False, max_new_tokens=200)
outputs = tokenizer.batch_decode(outputs)
print(outputs)

Error:

  File "/workspace/.local/lib/python3.12/site-packages/transformers/models/llama4/modeling_llama4.py", line 359, in forward
    attn_scales = attn_scales.view((*input_shape, 1, 1))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[2, 7, 1, 1]' is invalid for input of size 7
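
For anyone reading the traceback: the message suggests that attn_scales holds one value per sequence position (7 elements), while the target view also includes the batch dimension (2 x 7 = 14 elements). Below is a minimal sketch of that mismatch. It assumes a stand-in tensor of ones rather than the real values computed in modeling_llama4.py, and the view-then-expand step is only one way to make the shapes compatible, not necessarily what the upstream patch does.

```python
import torch

batch_size, seq_len = 2, 7  # shapes taken from the error message
input_shape = (batch_size, seq_len)

# Stand-in (assumption) for the per-position scales: one value per token
# position and no batch dimension, so 7 elements in total.
attn_scales = torch.ones(seq_len)

# Reproduce the failure: the view asks for 2 * 7 * 1 * 1 = 14 elements,
# but the tensor only has 7.
try:
    attn_scales.view((*input_shape, 1, 1))
    view_failed = False
except RuntimeError as err:
    view_failed = True
    print(f"view failed: {err}")

# One possible fix: view with a singleton batch dimension, then broadcast
# it across the batch without copying data.
fixed = attn_scales.view(1, seq_len, 1, 1).expand(batch_size, seq_len, 1, 1)
print(fixed.shape)
```

With batch size 1 the original view happens to work because 1 x 7 equals the 7 available elements, which would explain why the error only shows up for BS > 1.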

Stack:

  File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/generation/utils.py", line 2460, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/generation/utils.py", line 3426, in _sample
    outputs = self(**model_inputs, return_dict=True)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/models/llama4/modeling_llama4.py", line 1015, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/models/llama4/modeling_llama4.py", line 700, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/models/llama4/modeling_llama4.py", line 435, in forward
    attention_states, self_attn_weights = self.self_attn(
                                          ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1740, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1751, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/accelerate/hooks.py", line 176, in new_forward
    output = module._old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/.local/lib/python3.12/site-packages/transformers/models/llama4/modeling_llama4.py", line 359, in forward
    attn_scales = attn_scales.view((*input_shape, 1, 1))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[2, 7, 1, 1]' is invalid for input of size 7
Meta Llama org

Fixed on main, patching today!
