Spaces: Running on Zero
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from sandbox.light_rag.utils import get_device | |
class HFLLM:
    """Thin wrapper around a Hugging Face causal LM for single-prompt generation.

    Loads the tokenizer and model weights once at construction time and exposes
    a single `generate` method that returns the extracted answer text.
    """

    def __init__(self, model_name: str):
        # Device selection is delegated to the project helper (presumably
        # CUDA / MPS / CPU — confirm in sandbox.light_rag.utils).
        self.device = get_device()
        self.model_name = model_name
        print("Loading HF model...")
        # Fetch tokenizer and weights from the Hugging Face hub, then move the
        # model onto the selected device.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def generate(self, prompt: str) -> list:
        """Generate a completion for `prompt`.

        Returns a single-element list of the form ``[{"answer": <text>}]``.
        """
        # Encode the prompt and place the input tensors on the model's device.
        encoded = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        output_ids = self.model.generate(**encoded, max_new_tokens=1024)
        # Special tokens are kept on purpose: the answer is carved out between
        # the chat-template markers below (Granite-style role/end tokens).
        decoded = self.tokenizer.batch_decode(output_ids, skip_special_tokens=False)
        full_text = decoded[0]
        # Take everything after the last role marker, truncated at end-of-text.
        answer = full_text.split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]
        return [{"answer": answer}]