# NLP_Chatbot/chatbot.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline


class ChatBotManager():
    """Chooses, caches and queries the currently selected chatbot."""

    def __init__(self):
        self.available_model_types = ['SmolLM2-1.7B (non-conversational)',
                                      'CUSsupport-chat-t5-small (conversational)']
        self.model = None
        self.selected_model_type = None

    def get_available_model_types(self):
        return self.available_model_types

    def _create_chatbot(self, model_type, device='cpu'):
        if model_type == self.available_model_types[0]:
            return SmolLM2_1_7B_ChatBot(device)
        elif model_type == self.available_model_types[1]:
            return CUSsupport_chat_t5_small(device)
        raise ValueError(f"Unknown model type: {model_type}")

    def obtain_answer(self, message, history, model_type):
        # Instantiate a new chatbot only when the user switches model types.
        if self.selected_model_type != model_type:
            self.selected_model_type = model_type
            self.model = self._create_chatbot(self.selected_model_type)
        return self.model.get_answer(message, history)
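

# obtain_answer() follows the (message, history, *additional_inputs) callback
# signature used by messages-format chat UIs. A hedged wiring sketch, assuming
# the app serves this manager through Gradio (the label and layout below are
# illustrative, not taken from this repo):
#
#   import gradio as gr
#   manager = ChatBotManager()
#   gr.ChatInterface(manager.obtain_answer, type="messages",
#                    additional_inputs=[gr.Dropdown(manager.get_available_model_types(),
#                                                   label="Model")]).launch()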


# Abstract base class: subclasses must provide self.tokenizer and self.model,
# or override get_answer() entirely.
class BaseChatBot():
    def __init__(self, device='cpu'):
        self.device = device

    def _get_input_ids(self, input_text, device='cpu'):
        # Tokenize the prompt and move the resulting tensors to the target device.
        input_ids = self.tokenizer(input_text, return_tensors="pt")
        input_ids = input_ids.to(device)
        return input_ids

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            inputs = self._get_input_ids(message, self.device)
            new_ids_aux = self.model.generate(**inputs, repetition_penalty=1.5,
                                              max_new_tokens=1500)
            # Causal LMs echo the prompt, so decode only the newly generated tokens.
            new_tokens = new_ids_aux[0][inputs['input_ids'].shape[-1]:]
            new_text_aux = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message


class SmolLM2_1_7B_ChatBot(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        model_type = "HuggingFaceTB/SmolLM2-1.7B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_type, cache_dir='./cache')
        self.model = AutoModelForCausalLM.from_pretrained(model_type, cache_dir='./cache')
        self.model.to(device)


class CUSsupport_chat_t5_small(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        # The pipeline bundles the tokenizer and the T5 model in one object.
        self.t2t_pipeline = pipeline("text2text-generation",
                                     model="mrSoul7766/CUSsupport-chat-t5-small",
                                     device=torch.device(self.device))

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            # Prefix the question with "answer:", the prompt format this model expects.
            new_text_aux = self.t2t_pipeline(f"answer: {message}",
                                             max_length=512)[0]['generated_text']
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
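

# ---------------------------------------------------------------------------
# Usage sketch (not part of the app's call path): a minimal manual smoke test,
# assuming you run this module directly and accept the model download. The
# question string is purely illustrative.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = ChatBotManager()
    model_type = manager.get_available_model_types()[1]  # the T5 support bot
    answer = manager.obtain_answer("How do I reset my password?", [], model_type)
    print(answer['content'])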