Spaces:
Runtime error
Runtime error
File size: 2,542 Bytes
fc75f91 238ab50 1fb2550 238ab50 fc75f91 1fb2550 238ab50 8d7ed76 238ab50 25b8856 238ab50 952f87e 77a996f 238ab50 22a141b 238ab50 22a141b 238ab50 6f02ff7 fd58d47 83a6f3d 6dcbe54 3a3549f 6e79f18 238ab50 4117229 238ab50 25b8856 238ab50 fc75f91 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import importlib
import subprocess
import sys

import gradio as gr
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from torchtext.data.utils import get_tokenizer

from transformer import Transformer
# Install the spaCy English pipeline wheel that the tokenizer below loads.
model_url = "https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl"
# Use this interpreter's pip (a bare "pip" on PATH may belong to a different
# Python installation) and fail fast if the install does not succeed, instead
# of failing later with a confusing spaCy load error.
subprocess.run([sys.executable, "-m", "pip", "install", model_url], check=True)
# The package did not exist when the interpreter started; make sure the
# import machinery re-scans site-packages before spaCy tries to import it.
importlib.invalidate_caches()

MAX_LEN = 350  # maximum decoded sequence length, including the leading <sos>
tokenizer = get_tokenizer('spacy', language='en_core_web_sm')

# vocab.pth holds a full pickled vocabulary object (it is used through
# get_stoi/get_itos/__call__ below), not just tensors, so it must be loaded
# with weights_only=False -- torch >= 2.6 defaults to weights_only=True and
# would reject it (the kwarg itself needs torch >= 1.13).
# NOTE(review): unpickling can execute arbitrary code; only acceptable because
# the repository is trusted.
vocab = torch.load(
    hf_hub_download(repo_id="nickgardner/chatbot", filename="vocab.pth"),
    weights_only=False,
)
vocab_token_dict = vocab.get_stoi()   # token string -> index
indices_to_tokens = vocab.get_itos()  # index -> token string
pad_token = vocab_token_dict['<pad>']
unknown_token = vocab_token_dict['<unk>']
sos_token = vocab_token_dict['<sos>']
eos_token = vocab_token_dict['<eos>']


def text_pipeline(x):
    """Tokenize ``x`` with spaCy and map each token to its vocabulary index."""
    return vocab(tokenizer(x))


# Model hyper-parameters; must match the checkpoint loaded below.
d_model = 512
heads = 8
N = 6
# The same vocabulary is used for the encoder input and the decoder output.
src_vocab = len(vocab)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(src_vocab, src_vocab, d_model, N, heads).to(device)
model.load_state_dict(torch.load(
    hf_hub_download(repo_id="nickgardner/chatbot",
                    filename="alpaca_train_400_epoch.pt"),
    map_location=device,
))
model.eval()
def respond(input):
    """Generate a reply to ``input`` by sampling from the Transformer decoder.

    The prompt is tokenized with ``text_pipeline`` and encoded once. Output
    tokens are then sampled one at a time from the full softmax distribution
    (not greedily) until ``<eos>`` is produced or ``MAX_LEN - 1`` tokens have
    been generated.

    Args:
        input: Raw prompt text from the Gradio textbox. The name shadows the
            builtin but is kept for interface compatibility.

    Returns:
        The generated tokens joined by single spaces, excluding ``<sos>`` and
        ``<eos>``. A prompt that is empty or tokenizes to nothing yields ``''``.
    """
    token_ids = text_pipeline(input) if input else []
    if not token_ids:
        # An empty source sequence cannot be encoded; there is nothing to answer.
        return ''

    model.eval()
    with torch.no_grad():  # inference only: don't build an autograd graph
        src = torch.tensor(token_ids, dtype=torch.int64, device=device).unsqueeze(0)
        # Hide padding and unknown tokens from encoder attention.
        src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2)
        e_outputs = model.encoder(src, src_mask)

        outputs = torch.zeros(MAX_LEN, dtype=src.dtype, device=device)
        outputs[0] = sos_token
        length = 1  # number of valid positions in `outputs` (just <sos> so far)
        for i in range(1, MAX_LEN):
            # Causal mask: position j may only attend to positions <= j.
            trg_mask = torch.tril(
                torch.ones(1, i, i, dtype=torch.bool, device=device))
            out = model.out(
                model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
            # Distribution over the next token: batch 0, last decoded position.
            probs = torch.nn.functional.softmax(out[0, -1], dim=-1)
            # torch.multinomial tolerates float32 rounding in the probability
            # sum (np.random.choice raises when it is not ~exactly 1) and
            # works on GPU tensors without a host copy.
            next_token = torch.multinomial(probs, num_samples=1).item()
            if next_token == eos_token:
                break
            outputs[i] = next_token
            length = i + 1

    # Slice by `length` so the final token is kept even when no <eos> was produced.
    return ' '.join(indices_to_tokens[ix] for ix in outputs[1:length].tolist())
# Serve `respond` as a plain text-in / text-out Gradio web app.
iface = gr.Interface(
    inputs="text",
    outputs="text",
    fn=respond,
)
iface.launch()