File size: 2,454 Bytes
fc75f91
238ab50
 
 
1fb2550
 
238ab50
 
fc75f91
1fb2550
 
 
238ab50
 
 
8d7ed76
238ab50
 
25b8856
238ab50
 
 
 
 
 
 
 
 
 
 
 
952f87e
77a996f
238ab50
 
22a141b
238ab50
22a141b
238ab50
 
 
 
 
 
 
 
 
 
c654b20
fd58d47
c9c480e
 
6e79f18
238ab50
4117229
9eed1c1
4117229
9eed1c1
238ab50
25b8856
238ab50
 
fc75f91
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import torch
from torchtext.data.utils import get_tokenizer
import numpy as np
import subprocess

from huggingface_hub import hf_hub_download
from transformer import Transformer

model_url = "https://huggingface.co/spacy/en_core_web_sm/resolve/main/en_core_web_sm-any-py3-none-any.whl"
subprocess.run(["pip", "install", model_url])

MAX_LEN = 350

tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
vocab = torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
                                   filename="vocab.pth"))
vocab_token_dict = vocab.get_stoi()
indices_to_tokens = vocab.get_itos()
pad_token = vocab_token_dict['<pad>']
unknown_token = vocab_token_dict['<unk>']
sos_token = vocab_token_dict['<sos>']
eos_token = vocab_token_dict['<eos>']
text_pipeline = lambda x: vocab(tokenizer(x))

d_model = 512
heads = 8
N = 6
src_vocab = len(vocab)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer(len(vocab), len(vocab), d_model, N, heads).to(device)
model.load_state_dict(torch.load(hf_hub_download(repo_id="nickgardner/chatbot",
                                      filename="alpaca_train_400_epoch.pt"), map_location=device))
model.eval()

def respond(input):
    model.eval()
    src = torch.tensor(text_pipeline(input), dtype=torch.int64).unsqueeze(0).to(device)
    src_mask = ((src != pad_token) & (src != unknown_token)).unsqueeze(-2).to(device)
    e_outputs = model.encoder(src, src_mask)

    outputs = torch.zeros(MAX_LEN).type_as(src.data).to(device)
    outputs[0] = torch.tensor([vocab.get_stoi()['<sos>']])
    for i in range(1, MAX_LEN):
        trg_mask = np.triu(np.ones([1, i, i]), k=1).astype('uint8')
        trg_mask = torch.autograd.Variable(torch.from_numpy(trg_mask) == 0).to(device)

        out = model.out(model.decoder(outputs[:i].unsqueeze(0), e_outputs, src_mask, trg_mask))
        out = torch.nn.functional.softmax(out, dim=-1)[:, -1].squeeze().detach().numpy()
        print(out.shape)
        print(np.sum(out))
        ix = np.random.choice(np.arange(len(out)), 1, p=out)
        # val, ix = out[:, -1].data.topk(1)

        # outputs[i] = ix[0][0]
        outputs[i] = ix[0]
        # if ix[0][0] == vocab_token_dict['<eos>']:
        if ix[0] == vocab_token_dict['<eos>']:
            break
    return ' '.join([indices_to_tokens[ix] for ix in outputs[1:i]])

iface = gr.Interface(fn=respond, inputs="text", outputs="text")
iface.launch()