import torch
import gradio as gr
from transformers import AutoTokenizer, pipeline, logging
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

model_name_or_path = "TheBloke/WizardCoder-Guanaco-15B-V1.1-GPTQ"
model_basename = "gptq_model-4bit-128g"
use_triton = False

# Fall back to CPU when no GPU is available
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

# Load the 4-bit GPTQ-quantised model weights
model = AutoGPTQForCausalLM.from_quantized(model_name_or_path,
                                           model_basename=model_basename,
                                           use_safetensors=True,
                                           trust_remote_code=False,
                                           device=device,
                                           use_triton=use_triton,
                                           quantize_config=None,
                                           cache_dir="models/")
""" | |
To download from a specific branch, use the revision parameter, as in this example: | |
model = AutoGPTQForCausalLM.from_quantized(model_name_or_path, | |
revision="gptq-4bit-32g-actorder_True", | |
model_basename=model_basename, | |
use_safetensors=True, | |
trust_remote_code=False, | |
device="cuda:0", | |
quantize_config=None) | |
""" | |
def code_gen(text):
    # input_ids = tokenizer(text, return_tensors='pt').input_ids.to(device)
    # output = model.generate(
    #     inputs=input_ids, temperature=0.7, max_new_tokens=124)
    # print(tokenizer.decode(output[0]))

    # Inference can also be done using transformers' pipeline.
    # Prevent printing spurious transformers error when using pipeline with AutoGPTQ
    logging.set_verbosity(logging.CRITICAL)

    print("*** Pipeline:")
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=124,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.15
    )

    response = pipe(text)
    print(response)
    return response[0]['generated_text']
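
# The app currently feeds the raw textbox contents straight to the pipeline.
# WizardCoder-style models are instruction-tuned, so wrapping the input in an
# Alpaca-style prompt often gives better completions. This is only a sketch:
# the exact template should be verified against the model card, and
# build_prompt() is a hypothetical helper, not part of the original app.
def build_prompt(instruction):
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\n{instruction}\n\n### Response:"
    )
# Possible usage inside code_gen: response = pipe(build_prompt(text))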
# gr.inputs.Textbox was removed in recent Gradio releases; use gr.Textbox instead
iface = gr.Interface(fn=code_gen,
                     inputs=gr.Textbox(label="Input Source Code"),
                     outputs="text",
                     title="Code Generation")
iface.launch()
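
# Note (assumption, not stated in the original Space): to run this app locally,
# install the packages imported above (torch, transformers, auto-gptq, gradio)
# and start it with `python app.py`; a CUDA GPU is needed for practical speed,
# otherwise the model loads on CPU and generation will be very slow.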