# coding=utf-8
# Quantizes fine-tuned models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ).
# Usage: python quantize.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
#                           --max_length 1024 --max_samples 1024
# dataset format: instruction (string), input (string), output (string), history (optional, list of [query, response] string pairs)
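#
# An illustrative record (field values are made up; "history" holds prior turns, oldest first):
# {"instruction": "What is GPTQ?", "input": "", "output": "GPTQ is a post-training quantization method.",
#  "history": [["Hello!", "Hi! How can I help you?"]]}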


import fire
from datasets import load_dataset
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig


def quantize(input_dir: str, output_dir: str, data_file: str, max_length: int = 1024, max_samples: int = 1024):
    # LLaMA ships a sentencepiece tokenizer; the slow implementation avoids fast-tokenizer conversion issues
    tokenizer = AutoTokenizer.from_pretrained(input_dir, use_fast=False, padding_side="left")

    def format_example(examples):
        # Build one chat-style calibration prompt per record, then tokenize the whole batch.
        prefix = ("A chat between a curious user and an artificial intelligence assistant. "
                  "The assistant gives helpful, detailed, and polite answers to the user's questions.")
        texts = []
        for i in range(len(examples["instruction"])):
            prompt = prefix + "\n"
            # Replay prior turns first; records without a history yield None under the json loader.
            history = examples["history"][i] if "history" in examples else None
            for user_query, bot_resp in history or []:
                prompt += "Human: {}\nAssistant: {}\n".format(user_query, bot_resp)
            query = examples["instruction"][i]
            if examples["input"][i]:  # skip the extra newline when "input" is empty
                query += "\n" + examples["input"][i]
            prompt += "Human: {}\nAssistant: {}".format(query, examples["output"][i])
            texts.append(prompt)
        return tokenizer(texts, truncation=True, max_length=max_length)

    dataset = load_dataset("json", data_files=data_file)["train"]
    column_names = list(dataset.column_names)
    dataset = dataset.select(range(min(len(dataset), max_samples)))  # cap the calibration set size
    dataset = dataset.map(format_example, batched=True, remove_columns=column_names)
    dataset = dataset.shuffle()

    quantize_config = BaseQuantizeConfig(
        bits=4,          # quantize weights to 4 bits
        group_size=128,  # one scale/zero-point per group of 128 weights
        desc_act=False   # disable act-order for faster inference, at a small accuracy cost
    )

    model = AutoGPTQForCausalLM.from_pretrained(input_dir, quantize_config, trust_remote_code=True)
    model.quantize(dataset)  # run GPTQ calibration on the tokenized examples
    model.save_quantized(output_dir)
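    # Also save the tokenizer, so output_dir is self-contained for inference.
    tokenizer.save_pretrained(output_dir)
    # Minimal reload sketch (assumes a CUDA device is available):
    #   model = AutoGPTQForCausalLM.from_quantized(output_dir, device="cuda:0")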


if __name__ == "__main__":
    fire.Fire(quantize)