|
import json
from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

def generate_rag_prompt_message(context, question):
    """Wrap the retrieved context and the user's question into a single-turn
    chat message. The instruction text is in Finnish, matching the language
    the model was fine-tuned on."""
    prompt = (
        'Olet tekoälyavustaja joka vastaa annetun kontekstin perusteella '
        'asiantuntevasti ja ystävällisesti käyttäjän kysymyksiin'
        f'\n\nKonteksti: {context}'
        f'\n\nKysymys: {question}'
        '\n\nVastaa yllä olevaan kysymykseen annetun kontekstin perusteella.'
    )
    return [{'role': 'user', 'content': prompt}]
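
# Illustrative example of the message structure the helper returns (the
# context and question values here are made up):
#
#   >>> generate_rag_prompt_message('Ahma on näätäeläin.', 'Mikä ahma on?')
#   [{'role': 'user', 'content': 'Olet tekoälyavustaja ... Mikä ahma on? ...'}]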
|
|
|
|
|
class EndpointHandler():
    def __init__(self, path=""):
        # Load the fine-tuned Finnish RAG model and its tokenizer onto the GPU
        # in bfloat16. The repository id is hard-coded, so the `path` argument
        # supplied by the Inference Endpoints runtime is not used.
        self.model = AutoModelForCausalLM.from_pretrained(
            "RASMUS/Ahma-3B-Instruct-RAG-v0.1",
            device_map='cuda:0',
            torch_dtype=torch.bfloat16,
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained("RASMUS/Ahma-3B-Instruct-RAG-v0.1")
        # Stop (and pad) generation at the </s> end-of-sequence token.
        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.eos_token_id,
            eos_token_id=self.tokenizer.convert_tokens_to_ids("</s>"),
        )
|
|
|
    def __call__(self, data: Dict[str, Any]) -> str:
        """
        data args:
            inputs (:obj:`dict`): must contain "context" and "question" strings
        Return:
            :obj:`str`: a JSON-serialized dict containing the generated answer
        """
|
        print(data)
        try:
            inputs = data.pop("inputs", None)
            context = inputs["context"]
            question = inputs["question"]

            messages = generate_rag_prompt_message(context, question)

            # Render the chat template to plain text, then tokenize and move
            # the tensors to the GPU.
            inputs = self.tokenizer(
                [self.tokenizer.apply_chat_template(messages, tokenize=False)],
                return_tensors="pt",
            ).to("cuda")

            with torch.no_grad():
                generated_ids = self.model.generate(
                    input_ids=inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    generation_config=self.generation_config,
                    temperature=0.1,
                    penalty_alpha=0.6,
                    min_p=0.5,
                    do_sample=True,
                    repetition_penalty=1.28,
                    min_length=10,
                    max_new_tokens=250,
                )

            generated_text = self.tokenizer.batch_decode(
                generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True
            )[0]
            try:
                # The model's answer follows the [/INST] tag of the chat template.
                generated_answer = generated_text.split('[/INST]')[1].strip()
                return json.dumps({"answer": generated_answer})
            except Exception as e:
                return json.dumps({"answer": str(e)})
        except Exception as e:
            # Report unexpected failures to the caller as JSON rather than
            # returning None implicitly.
            print(e)
            return json.dumps({"error": str(e)})
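

# Minimal local smoke test: a sketch that assumes a CUDA-capable machine and
# network access to download the model weights. The payload mirrors the shape
# the Inference Endpoints runtime passes to __call__; the context and question
# values are made-up examples.
if __name__ == "__main__":
    handler = EndpointHandler()
    payload = {
        "inputs": {
            "context": "Ahma-3B-Instruct-RAG-v0.1 on suomenkielinen kielimalli.",
            "question": "Millainen malli Ahma-3B-Instruct-RAG-v0.1 on?",
        }
    }
    print(handler(payload))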