import gradio as gr
from huggingface_hub import snapshot_download
from accelerate.utils import BnbQuantizationConfig
from accelerate.utils import load_and_quantize_model
from accelerate import Accelerator
from accelerate import init_empty_weights
from mingpt.model import GPT  # minGPT GPT class, needed to build the empty gpt2-xl skeleton below
model_path="marcsun13/gpt2-xl-linear-sharded"
def quantize(model_path=model_path):
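    # Download the sharded gpt2-xl checkpoint, quantize it to 8-bit with bitsandbytes via
    # Accelerate, save the quantized weights, and reload them as a round-trip check.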
print("1")
weights_location = snapshot_download(repo_id=f"{model_path}")
print("2")
bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold = 6)
#bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
print("3")
#model_config = GPT.get_default_config()
#model_config.model_type = 'gpt2-xl'
#model_config.vocab_size = 50257
#model_config.block_size = 1024
print(weights_location)
print(weights_location.config)
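    # Instantiate the model on the meta device so no real memory is allocated, then load
    # the downloaded weights into it and quantize them to 8-bit.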
    with init_empty_weights():
        empty_model = GPT(model_config)
    quantized_model = load_and_quantize_model(empty_model, weights_location=weights_location, bnb_quantization_config=bnb_quantization_config, device_map="auto")
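    # Save the quantized model to disk, then reload it from the saved weights as a sanity check.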
print("4")
accelerate = Accelerator()
print("5")
new_weights_location = "./model"
print("6")
accelerate.save_model(quantized_model, new_weights_location)
print("7")
quantized_model_from_saved = load_and_quantize_model(empty_model, weights_location=new_weights_location, bnb_quantization_config=bnb_quantization_config, device_map = "auto")
print("Done")
with gr.Blocks() as app:
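    # Single button that runs quantization when clicked; progress is printed to the Space logs.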
    btn = gr.Button()
    btn.click(quantize, None, None)
app.launch()