import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Creates and caches the chatbot matching the model type selected in the UI.
class ChatBotManager:
    def __init__(self):
        self.available_model_types = ['SmolLM2-1.7B (non-conversational)',
                                      'CUSsupport-chat-t5-small (conversational)']
        self.model = None
        self.selected_model_type = None

    def get_available_model_types(self):
        return self.available_model_types

    def _create_chatbot(self, model_type, device='cpu'):
        if model_type == self.available_model_types[0]:
            return SmolLM2_1_7B_ChatBot(device)
        elif model_type == self.available_model_types[1]:
            return CUSsupport_chat_t5_small(device)
        raise ValueError(f"Unknown model type: {model_type}")

    def obtain_answer(self, message, history, model_type):
        # Only (re)build a chatbot when the selected model type changes;
        # otherwise reuse the cached instance.
        if self.selected_model_type != model_type:
            self.selected_model_type = model_type
            self.model = self._create_chatbot(self.selected_model_type)
        return self.model.get_answer(message, history)
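
# Sketch (an assumption about the surrounding app, not code from this file):
# obtain_answer's (message, history, model_type) signature matches what
# gr.ChatInterface passes when a dropdown is supplied as an additional input:
#
#   import gradio as gr
#   manager = ChatBotManager()
#   demo = gr.ChatInterface(fn=manager.obtain_answer,
#                           type="messages",
#                           additional_inputs=gr.Dropdown(
#                               choices=manager.get_available_model_types()))
#   demo.launch()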
# Abstract class: subclasses must set self.tokenizer and self.model, or
# override get_answer entirely (as the pipeline-based chatbot below does).
class BaseChatBot:
    def __init__(self, device='cpu'):
        self.device = device

    def _get_input_ids(self, input_text, device='cpu'):
        # Tokenize the prompt and move the resulting tensors to the target device.
        input_ids = self.tokenizer(input_text, return_tensors="pt")
        return input_ids.to(device)

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            input_ids = self._get_input_ids(message, self.device)
            new_ids_aux = self.model.generate(**input_ids,
                                              repetition_penalty=1.5,
                                              max_new_tokens=1500)
            # Note: for a causal LM the decoded text still begins with the prompt.
            new_text_aux = self.tokenizer.decode(new_ids_aux[0], skip_special_tokens=True)
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
# Base (non-conversational) language model: it continues the prompt rather
# than replying to it.
class SmolLM2_1_7B_ChatBot(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        model_type = "HuggingFaceTB/SmolLM2-1.7B"
        self.tokenizer = AutoTokenizer.from_pretrained(model_type, cache_dir='./cache')
        self.model = AutoModelForCausalLM.from_pretrained(model_type, cache_dir='./cache')
        self.model.to(device)
# Conversational T5 model served through a text2text pipeline, so get_answer
# is overridden instead of using the tokenizer/model pair of the base class.
class CUSsupport_chat_t5_small(BaseChatBot):
    def __init__(self, device='cpu'):
        super().__init__(device)
        self.t2t_pipeline = pipeline("text2text-generation",
                                     model="mrSoul7766/CUSsupport-chat-t5-small",
                                     device=torch.device(self.device))

    def get_answer(self, message, history):
        output_message = None
        if len(message.strip()) > 0:
            # The pipeline is prompted with an "answer: <question>" style input.
            new_text_aux = self.t2t_pipeline(f"answer: {message}",
                                             max_length=512)[0]['generated_text']
            output_message = {'role': 'assistant', 'content': new_text_aux}
        return output_message
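
# Minimal command-line sketch for trying the manager directly (the sample
# question below is made up for illustration).
if __name__ == '__main__':
    manager = ChatBotManager()
    model_type = manager.get_available_model_types()[1]  # conversational T5
    answer = manager.obtain_answer('How can I reset my password?', [], model_type)
    print(answer['content'])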