# Llama3-2-3B-IT-Byte
Llama3.2-3B transferred to byte-level tokenization via cross-tokenizer distillation.
This model is intended as a proof of concept that pretrained (subword-based) models can be quickly and effectively transferred to the byte level. It is not optimized for production use (in particular, it is not optimized for speed)!
## Benchmarks
Llama3-2-3B-IT-Byte performs competitively although it has been trained on only 1.3B bytes (≈328M subword tokens total).
| Model | MMLU | BoolQ | PiQA | IFEval | ARC-C | Avg. |
|---|---|---|---|---|---|---|
| EvaByte-6.5B-SFT | 49.5 | 79.5* | 74.1* | 60.2 | 64.6* | 65.6 |
| Llama3.2-3B-Instruct (original) | 62.4 | 78.8 | 76.9 | 76.6 | 43.9 | 67.7 |
| Gemma2-2B-IT (original) | 56.9 | 83.8 | 79.6 | 62.5 | 50.4 | 66.6 |
| Llama3-2-3B-IT-Byte (this model) | 57.0 | 76.6 | 73.6 | 58.8 | 39.8 | 61.2 |
| Gemma2-2B-IT-Byte | 51.0 | 80.5 | 71.5 | 51.9 | 38.2 | 58.6 |
*Numbers from EvaByte-6.5B (Base) since they are not reported for the SFT model.
## Usage
```python
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("benjamin/Llama3-2-3B-IT-Byte")
print("Vocab Size:", len(tokenizer))  # 256 bytes + some special tokens

device = "cuda"

model = AutoModelForCausalLM.from_pretrained(
    "benjamin/Llama3-2-3B-IT-Byte", trust_remote_code=True
)
model = model.to(device)

tokens = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello, how are you doing?"}], return_tensors="pt"
)
eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

out = model.generate(tokens.to(model.device), eos_token_id=eot_id)
print(tokenizer.decode(out[0]))
```
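Since the tokenizer operates on raw UTF-8 bytes, the token count of a string is roughly its byte count (plus any special tokens). Below is a minimal sketch to check this, reusing the `tokenizer` loaded above; it assumes each byte maps to exactly one token, as the 256-byte vocabulary suggests, and the exact ids depend on the special-token layout:

```python
# Each UTF-8 byte should become one token when special tokens are disabled.
text = "Hello!"
ids = tokenizer(text, add_special_tokens=False)["input_ids"]
print(len(text.encode("utf-8")), len(ids))  # should print: 6 6

# Multi-byte characters expand to several tokens, e.g. "é" is 2 UTF-8 bytes.
ids_e = tokenizer("é", add_special_tokens=False)["input_ids"]
print(len("é".encode("utf-8")), len(ids_e))  # should print: 2 2
```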
## Training
This model was trained using tokenkit with the following command:
```bash
python3 scripts/cross_tokenizer_distill.py \
    --config=configs/cross_tokenizer_distill.yaml \
    --overrides \
    losses=[sft,alm_unconstrained,alm_latents] \
    multitask_aggregation_fn=approx_gradmag_preserve_mag \
    alm_mode=merge_by_space_prob+append_space \
    tokenizer_pair_bias_threshold=0.1 \
    max_student_length=2048 \
    steps=20000 \
    eval_interval=20000 \
    save_interval=20000 \
    optimizer.learning_rate=3.e-5 \
    optimizer.weight_decay=0.0 \
    optimizer.max_grad_norm=null \
    optimizer.grad_acc_steps=1 \
    train_model_mode=full \
    expand_input_ids=true \
    output_embeddings_mode=untie \
    eval.tasks=[arc_easy,arc_challenge,piqa,boolq,arithmetic,mmlu,ifeval,agieval_en,agieval_cn] \
    data.batch_size=32 \
    student.pretrained_model_name_or_path=benjamin/Llama-3.2-3B-Instruct-flax \
    student.tokenizer_name=meta-llama/Llama-3.2-3B-Instruct:source=Llama3 \
    target_tokenizer_name=meta-llama/Llama-3.2-3B-Instruct:source=Llama3:target=Llama3:conversion=byte \
    n_model_parallel=4 \
    n_data_parallel=4 \
    data.num_workers=16 \
    num_workers=16 \
    name=llama3_to_byte_20k
```
Training took ~26 hours on a TPU v4-32.
## Future Work
The current version of this model was trained for 20k steps with 32×2048 bytes per batch (= 1.3B bytes ≈ 328M subword tokens total). It was unexpected that the model performs as well as it does given this very short training run. We plan to train a new version for more steps (you can also do so yourself using tokenkit).
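For reference, the 1.3B-byte figure follows directly from the training command above; here is a quick back-of-the-envelope check (the ≈4 bytes per subword token ratio is implied by the numbers on this card, not an independent measurement):

```python
# Byte budget implied by the training configuration above.
steps = 20_000              # steps=20000
batch_size = 32             # data.batch_size=32
bytes_per_sequence = 2_048  # max_student_length=2048

total_bytes = steps * batch_size * bytes_per_sequence
print(f"{total_bytes:,} bytes")  # 1,310,720,000 ≈ 1.3B bytes
print(total_bytes / 328e6)       # ≈ 4.0 bytes per original subword token
```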
To preserve efficiency, we would need to add some combination of BLT-style hierarchical processing, attention approximations, and self-speculative decoding.
## Acknowledgments
Training was enabled by Cloud TPUs from Google's TPU Research Cloud (TRC).