Omnibus commited on
Commit
29a3f51
·
verified ·
1 Parent(s): 28a7aef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -2
app.py CHANGED
@@ -1,4 +1,30 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def quantize():
4
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from huggingface_hub import snapshot_download
3
+ from accelerate.utils import BnbQuantizationConfig
4
+ from accelerate.utils import load_and_quantize_model
5
+ from accelerate import Accelerator
6
 
7
+ model_path="marcsun13/gpt2-xl-linear-sharded"
8
+
9
+ def quantize(model_path=model_path):
10
+ print("1")
11
+ weights_location = snapshot_download(repo_id=f"{model_path}")
12
+ print("2")
13
+ bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold = 6)
14
+ #bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4")
15
+ print("3")
16
+ quantized_model = load_and_quantize_model(empty_model, weights_location=weights_location, bnb_quantization_config=bnb_quantization_config, device_map = "auto")
17
+ print("4")
18
+ accelerate = Accelerator()
19
+ print("5")
20
+ new_weights_location = "./model"
21
+ print("6")
22
+ accelerate.save_model(quantized_model, new_weights_location)
23
+ print("7")
24
+ quantized_model_from_saved = load_and_quantize_model(empty_model, weights_location=new_weights_location, bnb_quantization_config=bnb_quantization_config, device_map = "auto")
25
+ print("Done")
26
+
27
+ with gr.Blocks() as app:
28
+ btn=gr.Button()
29
+ btn.click(quantize,None,None)
30
+ app.launch()