from transformers import AutoModelForCausalLM, AutoTokenizer

from sandbox.light_rag.utils import get_device


class HFLLM:
    def __init__(self, model_name: str):
        self.device = get_device()
        self.model_name = model_name
        print("Loading HF model...")
        # Load the tokenizer and model from Hugging Face
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def generate(self, prompt: str) -> list:
        # Tokenize the prompt and move the input tensors to the model's device
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=1024)
        # Keep special tokens so the role markers below can be used to isolate the answer
        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
        # Take the text after the last role marker and drop the trailing end-of-text token
        response = [
            {
                "answer": generated_texts[0]
                .split("<|end_of_role|>")[-1]
                .split("<|end_of_text|>")[0]
            }
        ]
        return response
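
# Example usage (a minimal sketch; the model name below is an assumption --
# the <|end_of_role|>/<|end_of_text|> markers parsed above match IBM Granite
# chat templates, so any model with that template should work):
#
#     llm = HFLLM("ibm-granite/granite-3.0-2b-instruct")
#     result = llm.generate(
#         "<|start_of_role|>user<|end_of_role|>What is retrieval-augmented generation?<|end_of_text|>"
#     )
#     print(result[0]["answer"])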