from typing import Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import torch
import json
def generate_rag_prompt_message(context, question):
    # Finnish RAG prompt. In English: "You are an AI assistant that answers the user's
    # questions knowledgeably and kindly based on the given context", followed by the
    # context, the question, and "Answer the question above based on the given context."
    prompt = f'Olet tekoälyavustaja joka vastaa annetun kontekstin perusteella asiantuntevasti ja ystävällisesti käyttäjän kysymyksiin\n\nKonteksti: {context}\n\nKysymys: {question}\n\nVastaa yllä olevaan kysymykseen annetun kontekstin perusteella.'
    return [{'role': 'user', 'content': prompt}]
class EndpointHandler():
    def __init__(self, path=""):
        # Preload the model and tokenizer once at endpoint startup so that
        # __call__ only has to run inference.
        self.model = AutoModelForCausalLM.from_pretrained(
            "RASMUS/Ahma-3B-Instruct-RAG-v0.1", device_map="cuda:0", torch_dtype=torch.bfloat16
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained("RASMUS/Ahma-3B-Instruct-RAG-v0.1")
        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.convert_tokens_to_ids("</s>"),
        )
    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            inputs (:obj:`dict`): dict with "context" and "question" strings
        Return:
            A JSON string containing the generated answer (or the error message).
        """
        print(data)
        try:
            inputs = data.pop("inputs", None)
            context = inputs["context"]
            question = inputs["question"]
            messages = generate_rag_prompt_message(context, question)
            # Render the chat template to a plain string and tokenize it onto the GPU.
            inputs = self.tokenizer(
                [self.tokenizer.apply_chat_template(messages, tokenize=False)],
                return_tensors="pt",
            ).to("cuda")
            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    generation_config=self.generation_config,
                    temperature=0.1,
                    penalty_alpha=0.6,
                    min_p=0.5,
                    do_sample=True,
                    repetition_penalty=1.28,
                    min_length=10,
                    max_new_tokens=250,
                )
            generated_text = self.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True
            )[0]
            try:
                # The model's answer follows the closing [/INST] tag of the chat template.
                generated_answer = generated_text.split("[/INST]")[1].strip()
                return json.dumps({"answer": generated_answer})
            except Exception as e:
                return json.dumps({"answer": str(e)})
        except Exception as e:
            # Surface the error to the caller instead of silently returning None.
            print(e)
            return json.dumps({"error": str(e)})
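

# --- Local smoke test (illustrative sketch, not part of the endpoint contract) ---
# The payload layout ("inputs" wrapping "context" and "question") mirrors what __call__
# expects; the Finnish context/question strings are made-up examples, and running this
# requires a CUDA device with the model weights available.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "context": "Ahma-3B on suomenkielinen kielimalli.",  # "Ahma-3B is a Finnish language model."
            "question": "Millainen malli Ahma-3B on?",            # "What kind of model is Ahma-3B?"
        }
    }
    print(handler(payload))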