# RTE Build / Deployment — commit a099612 (header metadata preserved as a comment so the module parses)
from transformers import AutoModelForCausalLM, AutoTokenizer
from sandbox.light_rag.utils import get_device
class HFLLM:
    """Thin wrapper around a Hugging Face causal LM for prompt-in, text-out generation."""

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` onto the selected device.

        Args:
            model_name: Hugging Face model identifier (hub repo id or local path).
        """
        self.device = get_device()
        self.model_name = model_name
        print("Loading HF model...")
        # Load the tokenizer and model from Hugging Face
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> list:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Prompt text. NOTE(review): the answer-extraction below splits on
                "<|end_of_role|>" / "<|end_of_text|>", which looks like a Granite-style
                chat template — presumably the caller supplies an already-templated
                prompt; verify against callers.
            max_new_tokens: Upper bound on newly generated tokens. Defaults to 1024,
                the value previously hard-coded here, so existing callers are unaffected.

        Returns:
            A single-element list ``[{"answer": <extracted answer text>}]``.
        """
        # Tokenize and move the input tensors to the model's device.
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)
        # skip_special_tokens=False is deliberate: the special markers must survive
        # decoding so the answer can be sliced out of the raw text below.
        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        # Answer = text after the LAST "<|end_of_role|>" and before the first
        # "<|end_of_text|>" that follows it.
        answer = generated_texts[0].split("<|end_of_role|>")[-1].split("<|end_of_text|>")[0]
        return [{"answer": answer}]