File size: 1,000 Bytes
a099612
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers import AutoModelForCausalLM, AutoTokenizer

from sandbox.light_rag.utils import get_device


class HFLLM:
    """Thin wrapper around a Hugging Face causal language model for single-prompt generation.

    The tokenizer and model are loaded with ``from_pretrained`` and the model is
    moved to the device returned by ``get_device()``.
    """

    # Special-token markers used to cut the answer out of the decoded completion.
    # NOTE(review): these look like IBM Granite chat-template tokens — confirm they
    # match the chat template of any other model passed in.
    _END_OF_ROLE = "<|end_of_role|>"
    _END_OF_TEXT = "<|end_of_text|>"

    def __init__(self, model_name: str):
        """Load the tokenizer and model for ``model_name`` onto the selected device.

        Args:
            model_name: Hugging Face Hub id or local path passed to ``from_pretrained``.
        """
        self.device = get_device()
        self.model_name = model_name
        print("Loading HF model...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

    def generate(self, prompt: str, max_new_tokens: int = 1024) -> list:
        """Generate a completion for ``prompt`` and return the extracted answer.

        The prompt is tokenized as-is (no chat template is applied here), so any
        role markers must already be present in ``prompt``.

        Args:
            prompt: Text to feed to the model.
            max_new_tokens: Upper bound on the number of tokens to generate.

        Returns:
            A single-element list ``[{"answer": <str>}]``, where the answer is the
            completion text after the last end-of-role marker and before the first
            end-of-text marker.
        """
        model_inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens)

        # A causal LM's ``generate`` output starts with the prompt tokens; drop them so
        # the prompt text can never leak into the answer (e.g. prompts without markers).
        prompt_length = model_inputs["input_ids"].shape[1]
        completion_ids = generated_ids[:, prompt_length:]

        # Keep special tokens: the role/end markers are needed to locate the answer.
        generated_texts = self.tokenizer.batch_decode(completion_ids, skip_special_tokens=False)
        return [{"answer": self._extract_answer(generated_texts[0])}]

    @classmethod
    def _extract_answer(cls, text: str) -> str:
        """Return the text after the last end-of-role marker and before the first end-of-text marker."""
        after_role = text.split(cls._END_OF_ROLE)[-1]
        return after_role.split(cls._END_OF_TEXT)[0]