from pathlib import Path

import torch
from tokenizers import Tokenizer

from config import get_config, get_weights_file_path
from train import get_model

def get_tokenizer(config) -> Tokenizer:
    tokenizer_path = Path(config['tokenizer_file'])
    if tokenizer_path.exists():
        print("Loading tokenizer from", tokenizer_path)
        return Tokenizer.from_file(str(tokenizer_path))
    raise FileNotFoundError(f"Can't find tokenizer file: {tokenizer_path}")
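
# The config is expected to provide at least these keys (values here are
# illustrative placeholders; the real ones live in openweb.config.json):
#   tokenizer_file  path to the trained tokenizer JSON, e.g. "tokenizer.json"
#   seq_len         the model's maximum sequence length, e.g. 512
#   preload         which checkpoint tag to load, e.g. "latest"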


config = get_config("./openweb.config.json")
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = get_tokenizer(config)
pad_token_id = tokenizer.token_to_id("<pad>")
eos_token_id = tokenizer.token_to_id("</s>")
user_token_id = tokenizer.token_to_id("<user>")
ai_token_id = tokenizer.token_to_id("<ai>")

model = get_model(config, tokenizer.get_vocab_size()).to(device)
model_path = get_weights_file_path(config, config['preload'])
state = torch.load(model_path, map_location=device)
model.load_state_dict(state['model_state_dict'])
model.eval()  # switch to inference mode after the weights are loaded

@torch.no_grad()  # inference only; no gradients needed
def generate_response(prompt: str):
    print("Prompt :", prompt)

    response = ""
    prompt_tokens = tokenizer.encode(prompt).ids
    # Wrap the prompt in the chat special tokens: <user> prompt <ai>
    input_tokens = [user_token_id] + prompt_tokens + [ai_token_id]
    if len(input_tokens) > config['seq_len']:
        raise ValueError(f"Prompt exceeds the max input length of {config['seq_len']} tokens")
    decoder_input = torch.tensor(input_tokens, dtype=torch.long, device=device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)  # add a batch dimension
    temperature = 0.7
    top_k = 50
    generated = 0
    print("Output  : ", end="")
    max_new_tokens = 1024
    while generated < max_new_tokens:
        # No explicit attention mask is passed; model.decode is assumed to
        # apply causal masking internally.
        out = model.decode(decoder_input)
        logits = model.project(out[:, -1])  # logits for the last position only
        # Temperature scaling followed by top-k sampling
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)  # map back to vocab ids

        token_text = tokenizer.decode([next_token.item()])
        response += token_text
        print(token_text, end="", flush=True)  # stream only the new token
        generated += 1

        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        # Slide the context window if it grows past the model's seq_len
        if decoder_input.shape[1] > config['seq_len']:
            decoder_input = decoder_input[:, -config['seq_len']:]
        if next_token.item() == eos_token_id:
            break
    print()
    return response
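
# Example usage: a minimal interactive driver, included as a sketch. It
# assumes this file is executed directly as a script (the model and
# tokenizer are already loaded at module level above).
if __name__ == "__main__":
    while True:
        user_prompt = input("You: ").strip()
        if not user_prompt or user_prompt.lower() in {"quit", "exit"}:
            break
        generate_response(user_prompt)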