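"""Inference script for the trained transformer language model.

Loads the tokenizer and checkpoint referenced by openweb.config.json and either
runs an interactive prompt loop (run_model) or produces a single response for a
given prompt (generate_response), sampling with temperature and top-k.
"""
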
import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model

def generate_text(
    model, text, tokenizer, max_len, device,
    temperature=0.7, top_k=50
):
    """Autoregressively sample from the model, streaming each decoded token to stdout."""
    eos_idx = tokenizer.token_to_id('</s>')
    pad_idx = tokenizer.token_to_id('<pad>')

    # Start with the input tokens as the initial decoder input
    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    # Generate until EOS or a hard cap of 2000 tokens
    while decoder_input.shape[1] < 2000:
        # Causal masking is handled inside model.decode; the explicit mask is kept for reference
        # decoder_mask = (decoder_input != pad_idx).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).to(device)

        # Get model output and take the logits for the last position
        out = model.decode(decoder_input)
        logits = model.project(out[:, -1])

        # Sampling: temperature + top-k
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        # Decode and print the sampled token
        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)

        # Append the sampled token; keep only the last max_len tokens as context
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if decoder_input.shape[1] > max_len:
            decoder_input = decoder_input[:, -max_len:]

        if next_token.item() == eos_idx:
            break

    print()
    return decoder_input.squeeze(0)

def get_tokenizer(config) -> Tokenizer:
    tokenizers_path = Path(config['tokenizer_file'])
    if tokenizers_path.exists():
        print("Loading tokenizer from", tokenizers_path)
        tokenizer = Tokenizer.from_file(str(tokenizers_path))
        return tokenizer
    else:
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizers_path}")

def run_model(config):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    tokenizer = get_tokenizer(config)
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    model.eval()

    if Path(model_path).exists():
        print("Loading model from:", model_path)
        state = torch.load(model_path, map_location=device)
        model.load_state_dict(state['model_state_dict'])

        print("You : ", end="")
        input_text = input()
        pad_token_id = tokenizer.token_to_id("<pad>")

        while input_text != "exit":
            # Drop the trailing special token added by the tokenizer's post-processor
            input_tokens = tokenizer.encode(input_text).ids[:-1]
            if len(input_tokens) > config['seq_len']:
                print(f"Input exceeds max length of {config['seq_len']} tokens")
                # Ask for a new prompt instead of looping on the same oversized input
                print("You : ", end="")
                input_text = input()
                continue
            # if len(input_tokens) < config['seq_len']:
            #     input_tokens += [pad_token_id] * (config['seq_len'] - len(input_tokens))
            input_tokens = torch.tensor(input_tokens)

            output_tokens = generate_text(model, input_tokens, tokenizer, config['seq_len'], device)
            print("MODEL : ", output_tokens)
            output_text = tokenizer.decode(output_tokens.detach().cpu().numpy())
            # print("Model : " + output_text)

            print("You : ", end="")
            input_text = input()
    else:
        raise FileNotFoundError(f"Model file not found: {model_path}")

def generate_response(prompt: str):
    config = get_config("./openweb.config.json")
    print(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = get_tokenizer(config)
    pad_token_id = tokenizer.token_to_id("<pad>")
    eos_token_id = tokenizer.token_to_id("</s>")

    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    model.eval()
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])

    # Drop the trailing special token added by the tokenizer's post-processor
    input_tokens = tokenizer.encode(prompt).ids[:-1]
    if len(input_tokens) > config['seq_len']:
        raise ValueError(f"Prompt exceeds max input length of {config['seq_len']} tokens")

    input_tokens = torch.tensor(input_tokens)
    # Causal masking is handled inside model.decode; the explicit mask is kept for reference
    # input_mask = (input_tokens != pad_token_id).unsqueeze(0).int() & causal_mask(input_tokens.size(0))

    decoder_input = input_tokens.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    temperature = 0.7
    top_k = 50

    # Generate until EOS or a hard cap of 2000 tokens
    while decoder_input.shape[1] < 2000:
        # print(decoder_input)  # debug: current decoder context

        # Get model output and take the logits for the last position
        out = model.decode(decoder_input)
        logits = model.project(out[:, -1])

        # Sampling: temperature + top-k
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        word = tokenizer.decode([next_token.item()])
        # yield word

        # Append the sampled token; keep only the last seq_len tokens as context
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if decoder_input.shape[1] > config['seq_len']:
            decoder_input = decoder_input[:, -config['seq_len']:]

        if next_token.item() == eos_token_id:
            break

    # Return the full decoder context (prompt + generated token ids)
    return decoder_input

if __name__ == "__main__":
    config = get_config("openweb.config.json")
    run_model(config)
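
# A minimal usage sketch (assuming a trained checkpoint and tokenizer exist at the
# paths configured in openweb.config.json); generate_response returns token ids of
# shape [1, T], which can be decoded the same way run_model does:
#
#   tokens = generate_response("Once upon a time")
#   tok = get_tokenizer(get_config("./openweb.config.json"))
#   print(tok.decode(tokens.squeeze(0).detach().cpu().numpy()))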