kwabs22
Testing Suggested Code Fix
acfdbf2
raw
history blame contribute delete
2.33 kB
import spaces
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import subprocess
import os
_DEFAULT_CUDA_TOOLKIT_URL = (
    "https://developer.download.nvidia.com/compute/cuda/12.2.0/"
    "local_installers/cuda_12.2.0_535.54.03_linux.run"
)


def install_cuda_toolkit(url=_DEFAULT_CUDA_TOOLKIT_URL):
    """Download and silently install the CUDA toolkit, then export its paths.

    Workaround (swiftly provided by https://huggingface.co/John6666) for
    ``OSError: CUDA_HOME environment variable is not set. Please set it to
    your CUDA install root.``

    Args:
        url: NVIDIA runfile installer to fetch. Defaults to CUDA 12.2.0; the
            previously used CUDA 11.8.0 installer is
            https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda_11.8.0_520.61.05_linux.run

    Raises:
        subprocess.CalledProcessError: if the download or the installer exits
            with a non-zero status.
        FileNotFoundError: if ``wget`` is not on PATH.

    Side effects:
        Sets ``CUDA_HOME`` and ``TORCH_CUDA_ARCH_LIST`` and prepends the toolkit
        directories to ``PATH`` and ``LD_LIBRARY_PATH`` in ``os.environ``.
    """

    def _prepend_path(var, entry):
        # Only join with a separator when there is an existing value: an empty
        # component in a search path is interpreted as the current directory.
        current = os.environ.get(var)
        os.environ[var] = entry if not current else f"{entry}{os.pathsep}{current}"

    installer = os.path.join("/tmp", os.path.basename(url))
    try:
        # check=True surfaces a failed download immediately instead of later
        # trying to execute an empty/partial file.
        subprocess.run(["wget", "-q", url, "-O", installer], check=True)
        os.chmod(installer, 0o755)
        subprocess.run([installer, "--silent", "--toolkit"], check=True)
    finally:
        # The installer is not needed after it has run; free the /tmp space.
        if os.path.exists(installer):
            os.remove(installer)

    cuda_home = "/usr/local/cuda"
    os.environ["CUDA_HOME"] = cuda_home
    _prepend_path("PATH", os.path.join(cuda_home, "bin"))
    # The toolkit's shared libraries are installed under lib64, not lib.
    _prepend_path("LD_LIBRARY_PATH", os.path.join(cuda_home, "lib64"))
    # Fix: arch_list[-1] += '+PTX'; IndexError: list index out of range
    os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6"
# Install the toolkit first so CUDA_HOME is set before any CUDA extension is
# loaded (swiftly provided by https://huggingface.co/John6666 to fix
# "OSError: CUDA_HOME environment variable is not set").
install_cuda_toolkit()

MODEL_ID = "ISTA-DASLab/Meta-Llama-3.1-70B-AQLM-PV-2Bit-1x16"

# Use the GPU whenever torch reports one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# NOTE(review): device_map="auto" already places the weights, and the trailing
# .to(device) then moves the (possibly dispatched) model again; accelerate may
# warn about moving a dispatched model. Confirm which placement is intended.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype='auto',
    device_map='auto',
).to(device)
@spaces.GPU
def generate_text(prompt, max_new_tokens=100):
    """Generate a continuation of ``prompt`` with the loaded model.

    Args:
        prompt: text entered by the user.
        max_new_tokens: maximum number of tokens to generate, not counting the
            prompt. Without an explicit limit ``generate`` falls back to the
            generation config; if that sets no length, transformers' default
            ``max_length=20`` applies and it counts the prompt tokens, so longer
            prompts produced little or no new text.

    Returns:
        The decoded prompt followed by the generated continuation, with
        special tokens removed.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Pass the tokenizer's attention mask explicitly so generate does not have
    # to infer it from input_ids.
    # NOTE(review): @spaces.GPU runs with its default time budget; a 70B model
    # may need a larger duration= for long generations — confirm.
    outputs = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Minimal text-in / text-out web UI: every submission calls generate_text with
# the entered prompt and shows the returned string.
interface = gr.Interface(
    fn=generate_text,
    inputs="text",
    outputs="text",
    title="Meta-Llama-3.1-70B-AQLM-PV-2Bit-1x16 Text Generation",
    description=(
        "Enter a prompt and generate text using "
        "Meta-Llama-3.1-70B-AQLM-PV-2Bit-1x16. Responses may differ slightly "
        "from those of the original Meta-Llama-3.1-70B."
    ),
)
interface.launch()