import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, pipeline
# Creates chatbots on demand and caches the one currently selected.
class ChatBotManager:
    def __init__(self):
        self.available_model_types = ['SmolLM2-1.7B (non-conversational)',
                                      'CUSsupport-chat-t5-small (conversational)']
        self.model = None
        self.selected_model_type = None

    def get_available_model_types(self):
        return self.available_model_types

    def _create_chatbot(self, model_type, device='cpu'):
        if model_type == self.available_model_types[0]:
            return SmolLM2_1_7B_ChatBot(device)
        elif model_type == self.available_model_types[1]:
            return CUSsupport_chat_t5_small(device)
        raise ValueError(f'Unknown model type: {model_type}')

    def obtain_answer(self, message, history, model_type):
        # Lazily (re)build the chatbot only when the selection changes.
        if self.selected_model_type != model_type:
            self.selected_model_type = model_type
            self.model = self._create_chatbot(self.selected_model_type)
        return self.model.get_answer(message, history)
# Abstract base class: subclasses must provide self.tokenizer and self.model.
class BaseChatBot:
    def __init__(self, device='cpu'):
        self.device = device

    def _get_input_ids(self, input_text, device='cpu'):
        # Tokenize and move the encoded inputs to the target device.
        input_ids = self.tokenizer(input_text, return_tensors="pt")
        input_ids = input_ids.to(device)
        return input_ids

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            input_ids = self._get_input_ids(message, self.device)
            new_ids_aux = self.model.generate(**input_ids, repetition_penalty=1.5,
                                              max_new_tokens=1500)
            # Drop the prompt tokens and special tokens, keeping only the reply.
            new_text_aux = self.tokenizer.decode(
                new_ids_aux[0][input_ids['input_ids'].shape[1]:],
                skip_special_tokens=True)
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
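    # The TextIteratorStreamer import above is otherwise unused; here is a
    # minimal sketch of how it could be wired in (hypothetical helper, not
    # part of the original app): run generate() on a background thread and
    # yield the growing reply as decoded chunks arrive.
    def get_answer_streaming(self, message, history):
        import threading
        if len(message.strip()) == 0:
            return
        inputs = self._get_input_ids(message, self.device)
        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True,
                                        skip_special_tokens=True)
        generation_kwargs = dict(**inputs, streamer=streamer,
                                 repetition_penalty=1.5, max_new_tokens=1500)
        thread = threading.Thread(target=self.model.generate,
                                  kwargs=generation_kwargs)
        thread.start()
        partial = ""
        for chunk in streamer:  # yields text as soon as it is generated
            partial += chunk
            yield {'role': 'assistant', 'content': partial}
        thread.join()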
class SmolLM2_1_7B_ChatBot(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        # Base (non-instruct) checkpoint: it continues text rather than chatting.
        model_id = "HuggingFaceTB/SmolLM2-1.7B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir='./cache')
        self.model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir='./cache')
        self.model.to(device)
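
# The base checkpoint above just continues the prompt. A hedged sketch of a
# conversational variant, assuming the "HuggingFaceTB/SmolLM2-1.7B-Instruct"
# checkpoint and its chat template (illustrative, not part of the original app):
class SmolLM2_1_7B_Instruct_ChatBot(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir='./cache')
        self.model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir='./cache')
        self.model.to(device)

    def get_answer(self, message, history):
        # Keep only the role/content keys the chat template expects.
        messages = [{'role': m['role'], 'content': m['content']}
                    for m in (history or [])]
        messages.append({'role': 'user', 'content': message})
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True,
            return_tensors="pt").to(self.device)
        output_ids = self.model.generate(input_ids, max_new_tokens=512)
        reply = self.tokenizer.decode(output_ids[0][input_ids.shape[1]:],
                                      skip_special_tokens=True)
        return {'role': 'assistant', 'content': reply}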
class CUSsupport_chat_t5_small(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        # The text2text pipeline wraps tokenization, generation and decoding.
        self.t2t_pipeline = pipeline("text2text-generation",
                                     model="mrSoul7766/CUSsupport-chat-t5-small",
                                     device=torch.device(self.device))

    def get_answer(self, message, history):
        # Overrides the base implementation: the pipeline handles everything.
        output_message = None
        if len(message.strip()) > 0:
            new_text_aux = self.t2t_pipeline(f"answer: {message}",
                                             max_length=512)[0]['generated_text']
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
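
# Minimal usage sketch (assumption: this module backs a Gradio Space; gradio
# is not imported in the original listing). gr.ChatInterface calls its fn as
# fn(message, history, *additional_inputs), which matches obtain_answer.
if __name__ == "__main__":
    import gradio as gr
    manager = ChatBotManager()
    demo = gr.ChatInterface(
        fn=manager.obtain_answer,
        type="messages",  # history and replies are {'role', 'content'} dicts
        additional_inputs=[gr.Dropdown(choices=manager.get_available_model_types(),
                                       value=manager.get_available_model_types()[0],
                                       label="Model")],
    )
    demo.launch()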