Gemma2-2B-IT-Byte 🔢

Gemma2-2B transferred to byte-level tokenization via cross-tokenizer distillation.

🚧 This model is intended as a proof of concept that we can quickly and effectively transfer pretrained (subword-based) models to the byte level. It is not optimized for production use (in particular, it is not optimized for speed)! 🚧

Benchmarks

Gemma2-2B-IT-Byte performs competitively although it was trained on only 1.3B bytes (≈ 328M subword tokens total).

Model                           | MMLU | BoolQ | PiQA  | IFEval | ARC-C | Avg.
--------------------------------|------|-------|-------|--------|-------|-----
EvaByte-6.5B-SFT                | 49.5 | 79.5* | 74.1* | 60.2   | 64.6* | 65.6
Llama3.2-3B-Instruct (original) | 62.4 | 78.8  | 76.9  | 76.6   | 43.9  | 67.7
Gemma2-2B-IT (original)         | 56.9 | 83.8  | 79.6  | 62.5   | 50.4  | 66.6
Llama3.2-3B-IT-Byte             | 57.0 | 76.6  | 73.6  | 58.8   | 39.8  | 61.2
Gemma2-2B-IT-Byte (this model)  | 51.0 | 80.5  | 71.5  | 51.9   | 38.2  | 58.6

*Numbers from EvaByte-6.5B (Base) since they are not reported for the SFT model.

Usage

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("benjamin/Gemma2-2B-IT-Byte")
print("Vocab Size:", len(tokenizer))  # 256 bytes + some special tokens

device = "cuda"
model = AutoModelForCausalLM.from_pretrained(
    "benjamin/Gemma2-2B-IT-Byte", trust_remote_code=True
)
model = model.to(device)

tokens = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you doing?"}],
    add_generation_prompt=True,  # end the prompt with the model-turn prefix
    return_tensors="pt",
)
eot_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
out = model.generate(tokens.to(model.device), eos_token_id=eot_id, max_new_tokens=256)
print(tokenizer.decode(out[0]))
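
Because the tokenizer operates on bytes, the number of non-special tokens should match the UTF-8 byte length of the input. Continuing from the snippet above, a quick illustrative check (assuming each byte maps to exactly one token, which is the premise of this model):

text = "Hello, how are you doing?"
ids = tokenizer(text, add_special_tokens=False)["input_ids"]
print(len(ids), len(text.encode("utf-8")))  # both should print 25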

Training

This model has been trained using tokenkit with the following command:

python3 scripts/cross_tokenizer_distill.py \
    --config=configs/cross_tokenizer_distill.yaml \
    --overrides \
    losses=[sft,alm_unconstrained,alm_latents] \
    multitask_aggregation_fn=approx_gradmag_preserve_mag \
    alm_mode=merge_by_space_prob+append_space \
    tokenizer_pair_bias_threshold=0.1 \
    max_student_length=2048 \
    steps=20000 \
    eval_interval=20000 \
    save_interval=20000 \
    optimizer.learning_rate=3.e-5 \
    optimizer.weight_decay=0.0 \
    optimizer.max_grad_norm=null \
    optimizer.grad_acc_steps=1 \
    train_model_mode=full \
    expand_input_ids=true \
    output_embeddings_mode=untie \
    eval.tasks=[arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn] \
    data.batch_size=32 \
    student.pretrained_model_name_or_path=benjamin/gemma-2-2b-it-flax \
    student.tokenizer_name=google/gemma-2-2b-it:source=Gemma2 \
    target_tokenizer_name=google/gemma-2-2b-it:source=Gemma2:target=Gemma2:conversion=byte \
    n_model_parallel=4 \
    n_data_parallel=4 \
    data.num_workers=16 \
    num_workers=16 \
    name=gemma2_to_byte_20k

Training took ~10 hours on a TPUv4-32.
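
For intuition about the alm_* losses above: the teacher's subword tokenization and the student's byte tokenization of the same text are aligned at shared chunk boundaries (roughly whitespace, cf. alm_mode=merge_by_space_prob+append_space), and the student is trained so that the likelihood it assigns to each chunk matches the teacher's. The snippet below is only a toy sketch of this chunk-level likelihood matching with dummy numbers, not tokenkit's actual implementation:

import torch

def chunk_logprobs(token_logprobs, chunk_ids):
    # Sum per-token log-probs into per-chunk log-probs (log P(chunk | prefix)).
    out = torch.zeros(int(chunk_ids.max()) + 1)
    out.index_add_(0, chunk_ids, token_logprobs)
    return out

# "The cat sat" has 3 whitespace-delimited chunks. The teacher tokenizes it into
# 3 subword tokens (one per chunk), the student into 11 bytes (3 + 4 + 4).
teacher_lp = torch.tensor([-1.2, -0.7, -0.9])  # dummy teacher token log-probs
teacher_chunks = torch.tensor([0, 1, 2])
student_lp = -torch.rand(11)                   # dummy student per-byte log-probs
student_chunks = torch.tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])

# Distillation term: match chunk-level log-likelihoods across the two tokenizers.
loss = (chunk_logprobs(student_lp, student_chunks)
        - chunk_logprobs(teacher_lp, teacher_chunks)).abs().mean()
print(loss)

The actual objective in tokenkit is considerably more involved and is combined with the sft and alm_latents losses listed in the command above.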

Future Work

The current version of this model is trained for 20k steps with 32*2048 bytes per batch (= 1.3B bytes ≈ 328M subword tokens total). We did not expect it to perform as well as it does after such a short training run. We plan to train a new version for more steps (you can also do so yourself using tokenkit).
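
As a quick sanity check, the byte budget follows directly from the hyperparameters in the training command above; the ≈4 bytes per subword token ratio is what these numbers imply:

steps, batch_size, max_student_length = 20_000, 32, 2048
total_bytes = steps * batch_size * max_student_length
print(total_bytes)       # 1,310,720,000 -> ~1.3B bytes
print(total_bytes // 4)  # ~328M subword tokens at ~4 bytes per token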

To preserve efficiency, we would have to add (a combination of) BLT-style hierarchical processing, attention approximations, and self-speculative decoding.

Acknowledgments

Training was enabled by Cloud TPUs from Google's TPU Research Cloud (TRC).
