import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


class ChatBotManager:
    """Tracks the available chatbot back-ends and lazily instantiates
    the one the user selects."""

    def __init__(self):
        self.available_model_types = ['SmolLM2-1.7B (non-conversational)',
                                      'CUSsupport-chat-t5-small (conversational)']
        self.model = None
        self.selected_model_type = None

    def get_available_model_types(self):
        return self.available_model_types

    def __create_chatbot__(self, model_type, device='cpu'):
        # Map the human-readable model type to its chatbot class.
        if model_type == self.available_model_types[0]:
            return SmolLM2_1_7B_ChatBot(device)
        elif model_type == self.available_model_types[1]:
            return CUSsupport_chat_t5_small(device)
        raise ValueError(f"Unknown model type: {model_type}")

    def obtain_answer(self, message, history, model_type):
        # Only (re)build the chatbot when the selection changes, so the
        # model weights are not reloaded on every message.
        if self.selected_model_type != model_type:
            self.selected_model_type = model_type
            self.model = self.__create_chatbot__(self.selected_model_type)
        return self.model.get_answer(message, history)


class BaseChatBot:
    """Abstract base class: subclasses must provide self.tokenizer and
    self.model, or override get_answer entirely."""

    def __init__(self, device='cpu'):
        self.device = device

    def __get_input_ids__(self, input_text, device='cpu'):
        # Tokenize the prompt and move the tensors to the target device.
        input_ids = self.tokenizer(input_text, return_tensors="pt")
        return input_ids.to(device)

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            input_ids = self.__get_input_ids__(message, self.device)
            new_ids_aux = self.model.generate(**input_ids,
                                              repetition_penalty=1.5,
                                              max_new_tokens=1500)
            # Causal LMs echo the prompt in their output, so decode only
            # the newly generated tokens and drop special tokens.
            prompt_length = input_ids['input_ids'].shape[-1]
            new_text_aux = self.tokenizer.decode(new_ids_aux[0][prompt_length:],
                                                 skip_special_tokens=True)
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message


class SmolLM2_1_7B_ChatBot(BaseChatBot):
    """Plain (non-instruction-tuned) causal LM, answered through the
    generic generate() path in BaseChatBot."""

    def __init__(self, device='cpu'):
        super().__init__(device)
        model_type = "HuggingFaceTB/SmolLM2-1.7B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_type, cache_dir='./cache')
        self.model = AutoModelForCausalLM.from_pretrained(model_type, cache_dir='./cache')
        self.model.to(device)


class CUSsupport_chat_t5_small(BaseChatBot):
    """Customer-support T5 model, wrapped in a text2text-generation
    pipeline instead of the base generate() path."""

    def __init__(self, device='cpu'):
        super().__init__(device)
        self.t2t_pipeline = pipeline("text2text-generation",
                                     model="mrSoul7766/CUSsupport-chat-t5-small",
                                     device=torch.device(self.device))

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            # This model expects prompts prefixed with "answer: ".
            new_text_aux = self.t2t_pipeline(f"answer: {message}",
                                             max_length=512)[0]['generated_text']
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
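
# Usage sketch (an assumption, not part of the original code): drive the
# manager directly from Python. The 'history' argument is not used by
# either back-end, so an empty list suffices; hooking obtain_answer into a
# chat UI whose callback receives (message, history, *extra) - e.g.
# Gradio's ChatInterface with an extra model-type input - is one likely
# setup, but is not shown in the source.
if __name__ == '__main__':
    manager = ChatBotManager()
    print(manager.get_available_model_types())
    reply = manager.obtain_answer("How can I reset my password?",
                                  history=[],
                                  model_type='CUSsupport-chat-t5-small (conversational)')
    print(reply['content'])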