import torch
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load DialoGPT (a GPT-2 checkpoint fine-tuned for multi-turn dialogue) and its tokenizer
model_name = "microsoft/DialoGPT-medium"
chat_tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def conversation(user_input, chat_history=None):
    # Gradio passes the "state" input on every call; it is empty on the first turn
    if chat_history is None:
        chat_history = []

    # Encode the new user message, appending the end-of-text token so the model
    # can tell where this turn ends
    user_input_ids = chat_tokenizer(user_input + chat_tokenizer.eos_token, return_tensors="pt").input_ids

    # Prepend the running conversation history (if any) to the new message
    if chat_history:
        chatbot_input_ids = torch.cat([torch.LongTensor(chat_history), user_input_ids], dim=-1)
    else:
        chatbot_input_ids = user_input_ids

    # Generate a reply; the output sequence contains the whole conversation so far
    chat_history = model.generate(chatbot_input_ids, max_length=1000, pad_token_id=chat_tokenizer.eos_token_id).tolist()

    # Decode the full conversation and split it back into individual turns
    response = chat_tokenizer.decode(chat_history[0]).split("<|endoftext|>")

    # Render the transcript as HTML: even indices are user turns, odd are bot turns
    html = "<div class='mychatbot'>"
    for x, msg in enumerate(response):
        if not msg:
            continue  # the final split leaves a trailing empty string
        if x % 2 != 0:
            msg = "ChatBot: " + msg
            css_class = "bot"  # "class" is a reserved word in Python, so use another name
        else:
            css_class = "user"
        html += "<div class='msg {}'>{}</div>".format(css_class, msg)
    html += "</div>"

    return html, chat_history

css = """
.mychatbot {display:flex;flex-direction:column}
.msg {padding:5px;margin-bottom:5px;border-radius:5px;width:75%}
.msg.user {background-color:lightblue;color:white}
.msg.bot {background-color:orange;color:white;align-self:self-end}
.footer {display:none}
"""

# Wire the chat function into a Gradio interface; the "state" input/output pair
# carries the chat history between calls
in_text = gr.inputs.Textbox(placeholder="Let's start a chat...")
gr.Interface(fn=conversation, theme="default", inputs=[in_text, "state"], outputs=["html", "state"], css=css).launch()
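
# A minimal sketch of exercising conversation() directly, without launching the
# UI (hypothetical usage, assuming the model weights have already downloaded):
#
#   html, history = conversation("Hello, how are you?")
#   html, history = conversation("Tell me a joke.", history)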