import json
import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

logger = logging.getLogger(__name__)

# Hugging Face Hub repository the handler loads. Note: the ``path`` argument
# passed to EndpointHandler is not used for loading.
MODEL_ID = "RASMUS/Ahma-3B-Instruct-RAG-v0.1"


def generate_rag_prompt_message(context: str, question: str) -> List[Dict[str, str]]:
    """Build a single-turn chat conversation asking the model to answer from context.

    The (Finnish) instruction tells the assistant to answer the user's question
    expertly and kindly, based on the given context.

    Args:
        context: Text the answer should be based on.
        question: The user's question.

    Returns:
        A one-element list of chat messages (``role``/``content`` dicts) in the
        format accepted by ``tokenizer.apply_chat_template``.
    """
    prompt = f'Olet tekoälyavustaja joka vastaa annetun kontekstin perusteella asiantuntevasti ja ystävällisesti käyttäjän kysymyksiin\n\nKonteksti: {context}\n\nKysymys: {question}\n\nVastaa yllä olevaan kysymykseen annetun kontekstin perusteella.'
    return [{'role': 'user', 'content': prompt}]


class EndpointHandler:
    """Custom inference-endpoint handler that answers questions from a given context.

    The model and tokenizer are loaded once in ``__init__``; each call builds a
    RAG prompt from ``inputs["context"]`` and ``inputs["question"]`` and returns
    the generated answer as a JSON string.
    """

    def __init__(self, path: str = ""):
        """Load the model, tokenizer and generation settings.

        Args:
            path: Model directory supplied by the endpoint runtime. Currently
                unused; the model is always loaded from ``MODEL_ID``.
        """
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, device_map="cuda:0", torch_dtype=torch.bfloat16
        ).eval()
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        eos_id = self.tokenizer.eos_token_id
        self.generation_config = GenerationConfig(
            # Padding reuses the EOS id, as in the original configuration.
            pad_token_id=eos_id,
            # Previously convert_tokens_to_ids(""): an empty string is not a
            # real token, so generation was not stopped by the model's EOS.
            eos_token_id=eos_id,
            do_sample=True,
            temperature=0.1,
            min_p=0.5,
            repetition_penalty=1.28,
            # min_length counts prompt tokens too, so with any realistic
            # prompt it does not force extra output; kept for parity.
            min_length=10,
            max_new_tokens=250,
            # penalty_alpha (contrastive search) was dropped: it has no
            # effect when do_sample=True.
        )

    def __call__(self, data: Dict[str, Any]) -> str:
        """Answer a question from the supplied context.

        Args:
            data: Request payload, expected to look like
                ``{"inputs": {"context": ..., "question": ...}}``.
                The payload is not modified.

        Returns:
            A JSON string: ``{"answer": "<generated text>"}`` on success, or
            ``{"error": "<message>"}`` when the payload is malformed or
            generation raises.
        """
        inputs = data.get("inputs") if isinstance(data, dict) else None
        if not isinstance(inputs, dict) or "context" not in inputs or "question" not in inputs:
            return json.dumps(
                {"error": "Expected 'inputs' to be an object with 'context' and 'question' keys."}
            )
        try:
            answer = self._generate_answer(inputs["context"], inputs["question"])
        except Exception as exc:  # endpoint boundary: report instead of returning None
            logger.exception("Answer generation failed")
            return json.dumps({"error": str(exc)})
        return json.dumps({"answer": answer})

    def _generate_answer(self, context: str, question: str) -> str:
        """Run the model on the RAG prompt and return only the newly generated text."""
        messages = generate_rag_prompt_message(context, question)
        prompt_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The rendered chat template already carries its special tokens;
        # letting the tokenizer add them again would duplicate BOS.
        encoded = self.tokenizer(
            prompt_text, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids=encoded["input_ids"],
                attention_mask=encoded["attention_mask"],
                generation_config=self.generation_config,
            )
        # Decode only the continuation, so the answer does not depend on a
        # '[/INST]' marker in the decoded prompt and excludes special tokens.
        prompt_len = encoded["input_ids"].shape[-1]
        return self.tokenizer.decode(
            generated_ids[0, prompt_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        ).strip()