Spaces:
Sleeping
Sleeping
import os | |
if os.environ.get("SPACES_ZERO_GPU") is not None:
    import spaces
else:
    import functools

    class spaces:
        """Local stand-in for the ZeroGPU ``spaces`` package, used when SPACES_ZERO_GPU is unset.

        Only ``GPU`` is provided; it runs the decorated function unchanged.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """Pass-through replacement for ``spaces.GPU``.

            Works both bare (``@spaces.GPU``) and with arguments
            (``@spaces.GPU(duration=60)``); keyword arguments are accepted and ignored.
            ``functools.wraps`` keeps the wrapped function's name, docstring and
            signature visible, since callers may inspect the decorated function.
            """
            def decorator(fn):
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)
                return wrapper

            if func is None:
                # Called with arguments only: return the real decorator.
                return decorator
            return decorator(func)
import gradio as gr | |
from pathlib import Path | |
import gc | |
import shutil | |
import torch | |
from utils import set_token, upload_repo, is_repo_exists, is_repo_name | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
from transformers import BitsAndBytesConfig | |
def fake_gpu():
    """Do nothing; takes no arguments and returns None.

    NOTE(review): nothing visible here calls or decorates this function; it is
    presumably meant to be wrapped with ``@spaces.GPU`` so ZeroGPU detects a GPU
    function at startup -- confirm.
    """
# Registry of selectable model classes: label -> [model loader class, tokenizer loader class].
# get_model() reads position 0 and get_tokenizer() reads position 1 of each entry.
MODEL_CLASS = dict(
    AutoModelForCausalLM=[AutoModelForCausalLM, AutoTokenizer],
)
DTYPE_DICT = { | |
"fp16": torch.float16, | |
"bf16": torch.bfloat16, | |
"fp32": torch.float32, | |
"fp8": torch.float8_e4m3fn | |
} | |
def get_model_class():
    """Return the registered model-class labels (MODEL_CLASS keys) as a new list, in insertion order."""
    return [*MODEL_CLASS]
def get_model(mclass: str):
    """Return the model loader class registered under *mclass*.

    Labels missing from MODEL_CLASS fall back to AutoModelForCausalLM.
    """
    fallback = [AutoModelForCausalLM, AutoTokenizer]
    loaders = MODEL_CLASS.get(mclass, fallback)
    return loaders[0]
def get_tokenizer(mclass: str):
    """Return the tokenizer loader class registered under *mclass*.

    Labels missing from MODEL_CLASS fall back to AutoTokenizer.
    """
    fallback = [AutoModelForCausalLM, AutoTokenizer]
    loaders = MODEL_CLASS.get(mclass, fallback)
    return loaders[1]
def get_dtype(dtype: str):
    """Translate a DTYPE_DICT label into a torch dtype.

    Any unknown label (quantize_repo passes e.g. "default") yields torch.bfloat16.
    """
    if dtype in DTYPE_DICT:
        return DTYPE_DICT[dtype]
    return torch.bfloat16
def save_readme_md(dir, repo_id):
    """Write a model-card README.md into *dir* marking its contents as quants of *repo_id*.

    The card has YAML front matter (license, language, library_name, base_model,
    tags) followed by one line linking back to the source repo on the Hub.
    Any existing README.md in *dir* is overwritten.
    """
    source_url = f"https://huggingface.co/{repo_id}/"
    front_matter = "\n".join([
        "---",
        "license: other",
        "language:",
        "- en",
        "library_name: transformers",
        f"base_model: {repo_id}",
        "tags:",
        "- transformers",
        "---",
    ])
    card = f"{front_matter}\nQuants of [{repo_id}]({source_url}).\n"
    with open(Path(dir) / "README.md", mode="w", encoding="utf-8") as fh:
        fh.write(card)
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Load a Hub model (optionally quantizing it to NF4) and save it to a local directory.

    Args:
        repo_id: Source repo on the Hugging Face Hub ("owner/name").
        dtype: DTYPE_DICT label passed as ``torch_dtype`` when loading and used as
            the bitsandbytes storage/compute dtype. "default" omits ``torch_dtype``
            (the NF4 config then gets get_dtype's fallback dtype).
        qtype: "nf4" loads with a 4-bit NF4 BitsAndBytesConfig (double quant);
            any other value loads without a quantization config.
        mclass: MODEL_CLASS label selecting the model/tokenizer loader classes.
        progress: Gradio progress tracker.

    Returns:
        The local output directory: the last path component of ``repo_id``.
        It holds the tokenizer, safetensors weights and a README.md.

    Raises:
        Whatever the loaders or ``save_pretrained`` raise. The partially written
        output directory is removed before the exception propagates.
    """
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]
    type_kwargs = {}
    if dtype != "default":
        type_kwargs["torch_dtype"] = get_dtype(dtype)
    quant_kwargs = {}
    if qtype == "nf4":
        # Only build the bitsandbytes config when it is actually used.
        quant_dtype = get_dtype(dtype)
        quant_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=quant_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=quant_dtype,
        )
    tokenizer = model = None
    try:
        progress(0.1, desc="Loading...")
        # `legacy=False` opts into the non-legacy behaviour on tokenizers that
        # support it (previously misspelled `legathy`, which had no effect).
        tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
        model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)
        progress(0.5, desc="Saving...")
        tokenizer.save_pretrained(out_dir)
        model.save_pretrained(out_dir, safe_serialization=True)
        if Path(out_dir).exists():
            save_readme_md(out_dir, repo_id)
    except Exception:
        # Do not leave a half-written checkpoint behind for a later upload.
        shutil.rmtree(out_dir, ignore_errors=True)
        raise
    finally:
        # Release the model even on failure. Collect first so tensors held in
        # reference cycles are freed before the CUDA cache is emptied.
        del tokenizer, model
        gc.collect()
        torch.cuda.empty_cache()
    progress(1, desc="Quantized.")
    return out_dir
def _repo_links_md(urls: list[str]) -> str:
    """Render *urls* as the "Your new repo" Markdown block, one "owner/name" link per URL."""
    links = []
    for u in urls:
        # Link text is the last two path segments; tolerate a trailing slash.
        owner, name = str(u).rstrip("/").split("/")[-2:]
        links.append(f"[{owner}/{name}]({str(u)})<br>")
    return "### Your new repo:\n" + "".join(links)


def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
                dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: quantize *repo_id*, upload the result to *newrepo_id*, and list the repo links.

    Args:
        repo_id: Source model repo ("owner/name").
        hf_token: Hub write token; falls back to the HF_TOKEN environment variable.
        urls: URLs of previously created repos (may be empty or None). The list
            passed in is not modified; an extended copy is returned.
        newrepo_id: Destination repo; falls back to the HF_OUTPUT_REPO
            environment variable.
        is_private: Forwarded to upload_repo.
        is_overwrite: When False, refuse to upload into a repo that already exists.
        dtype, qtype, mclass: Forwarded to quantize_repo.
        progress: Gradio progress tracker.

    Returns:
        Two ``gr.update`` objects: the URL list (as both value and choices) and
        the Markdown link listing.

    Raises:
        gr.Error: missing token, invalid source/destination repo name, or an
            existing destination repo while ``is_overwrite`` is False.
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")  # default huggingface token
    if not hf_token:
        raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)
    if not newrepo_id:
        newrepo_id = os.environ.get("HF_OUTPUT_REPO")  # default repo id
    if not is_repo_name(repo_id):
        raise gr.Error(f"Invalid repo name: {repo_id}")
    # An unset HF_OUTPUT_REPO leaves None here; reject it before is_repo_name sees it.
    if not newrepo_id or not is_repo_name(newrepo_id):
        raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id):
        raise gr.Error(f"Repo already exists: {newrepo_id}")
    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path:
        # Keep the same output arity as the success path; leave both components unchanged.
        return gr.update(), gr.update()
    urls = list(urls) if urls else []  # copy: never mutate the caller's list
    try:
        progress(0.5, desc="Start uploading...")
        repo_url = upload_repo(newrepo_id, new_path, is_private)
        progress(1, desc="Processing...")
    finally:
        # Remove the local checkpoint whether or not the upload succeeded.
        shutil.rmtree(new_path, ignore_errors=True)
    urls.append(repo_url)
    md = _repo_links_md(urls)
    gc.collect()
    torch.cuda.empty_cache()
    return gr.update(value=urls, choices=urls), gr.update(value=md)