# 10M-LLM / inference.py
# Last commit by abancp: "ready for deploy" (e6bd9b6)
import torch
from tokenizers import Tokenizer
from pathlib import Path
from config import get_config, get_weights_file_path
from train import get_model
def generate_text(
    model, text, tokenizer, max_len, device,
    temperature=0.7, top_k=50, max_new_tokens=2000,
):
    """Autoregressively sample a continuation of ``text``, streaming it to stdout.

    Each step feeds the current context window to ``model.decode``, projects
    the last position with ``model.project`` and picks the next token with
    temperature scaling + top-k sampling. Every chosen token is decoded and
    printed immediately. Generation stops at ``</s>`` or after
    ``max_new_tokens`` tokens, whichever comes first.

    Args:
        model: Object exposing ``decode(ids)`` and ``project(hidden)``.
        text: 1-D tensor (or 2-D batch of one) of prompt token ids.
        tokenizer: Object exposing ``token_to_id`` and ``decode``.
        max_len: Context window size; only the most recent ``max_len`` ids are
            kept and fed back to the model.
        device: Device the prompt is moved to.
        temperature: Softmax temperature; a value <= 0 selects greedily.
        top_k: Number of highest-scoring tokens sampled from, clamped to the
            vocabulary size; ``None`` or <= 0 disables the filter.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The final context window (at most ``max_len`` ids, prompt included)
        with the batch dimension squeezed away.

    Raises:
        ValueError: If ``text`` contains no tokens.
    """
    eos_idx = tokenizer.token_to_id('</s>')

    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)
    if decoder_input.numel() == 0:
        raise ValueError("text must contain at least one token")

    # Inference only: skip building autograd graphs for every step.
    with torch.no_grad():
        # An explicit budget is required: the window is cropped to max_len
        # below, so a bound on decoder_input's length alone may never trip.
        for _ in range(max_new_tokens):
            out = model.decode(decoder_input)
            logits = model.project(out[:, -1])  # logits for the last position

            if temperature <= 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                vocab_size = logits.size(-1)
                k = vocab_size if top_k is None or top_k <= 0 else min(top_k, vocab_size)
                top_k_logits, top_k_indices = torch.topk(logits, k)
                probs = torch.softmax(top_k_logits, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)
                # Map the position inside the top-k slice back to a vocab id.
                next_token = top_k_indices.gather(-1, choice)

            token_id = next_token.item()
            print(tokenizer.decode([token_id]), end="", flush=True)

            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            if decoder_input.shape[1] > max_len:
                decoder_input = decoder_input[:, -max_len:]
            if token_id == eos_idx:
                break
    print()
    return decoder_input.squeeze(0)
def get_tokenizer(config) -> "Tokenizer":
    """Load the tokenizer stored at ``config['tokenizer_file']``.

    Args:
        config: Mapping with a ``'tokenizer_file'`` entry pointing at a
            serialized tokenizer file.

    Returns:
        The tokenizer loaded via ``Tokenizer.from_file``.

    Raises:
        FileNotFoundError: If the configured path is not an existing file.
    """
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.is_file():
        # Pass one formatted message: given two positional args, OSError
        # subclasses treat them as (errno, strerror) and garble the text.
        raise FileNotFoundError(f"Can't find tokenizer file: {tokenizer_path}")
    print("Loading tokenizer from ", tokenizer_path)
    return Tokenizer.from_file(str(tokenizer_path))
def _read_user_input(prompt="You : "):
    """Print ``prompt`` and return one line from stdin; ``"exit"`` on EOF."""
    print(prompt, end="")
    try:
        return input()
    except EOFError:
        # Closed stdin (Ctrl-D / exhausted pipe) ends the session cleanly.
        print()
        return "exit"


def run_model(config):
    """Run an interactive prompt/response loop on the terminal.

    Loads the tokenizer and the weights selected by ``config['preload']``,
    then repeatedly reads a prompt, streams a generated continuation via
    ``generate_text`` and prints the resulting token ids. Typing ``exit``
    (or closing stdin) ends the loop.

    Args:
        config: Mapping with at least ``'tokenizer_file'``, ``'preload'`` and
            ``'seq_len'`` entries, plus whatever ``get_model`` and
            ``get_weights_file_path`` read.

    Raises:
        FileNotFoundError: If the tokenizer or the model weights are missing.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device : {device}")
    tokenizer = get_tokenizer(config)
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    if not Path(model_path).exists():
        # f-string works whether model_path is a str or a Path.
        raise FileNotFoundError(f"Model File not found : {model_path}")

    print("Loading Model from : ", model_path)
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    max_len = config['seq_len']
    while True:
        # Read at the top of every iteration so a rejected prompt re-prompts
        # instead of re-checking the same text forever.
        input_text = _read_user_input()
        if input_text == "exit":
            break
        # Drops the last id -- presumably an end-of-sequence token appended
        # by the tokenizer's post-processor (TODO confirm) -- so the model
        # continues the prompt rather than treating it as finished.
        input_tokens = tokenizer.encode(input_text).ids[:-1]
        if not input_tokens:
            # Nothing to condition on (e.g. an empty line); ask again.
            continue
        if len(input_tokens) > max_len:
            print(f"exceeding max length of input : {max_len}")
            continue
        output_tokens = generate_text(
            model, torch.tensor(input_tokens), tokenizer, max_len, device
        )
        print("MODEL : ", output_tokens)
def generate_response(prompt: str, temperature=0.7, top_k=50, max_new_tokens=2000):
    """Generate a continuation of ``prompt`` using ``./openweb.config.json``.

    Loads the config, tokenizer and weights, then samples tokens with
    temperature scaling + top-k filtering until ``</s>`` is produced or
    ``max_new_tokens`` tokens have been generated.

    NOTE(review): everything is reloaded on every call; consider caching the
    model/tokenizer if this serves many requests.

    Args:
        prompt: Raw prompt text.
        temperature: Softmax temperature; a value <= 0 selects greedily.
        top_k: Number of highest-scoring tokens sampled from, clamped to the
            vocabulary size; ``None`` or <= 0 disables the filter.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Tensor of shape ``(1, L)`` on the inference device holding the final
        context window: prompt plus generated ids, truncated to the last
        ``config['seq_len']`` tokens.

    Raises:
        ValueError: If the prompt encodes to no tokens or to more than
            ``config['seq_len']`` tokens.
    """
    config = get_config("./openweb.config.json")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = get_tokenizer(config)
    eos_token_id = tokenizer.token_to_id("</s>")
    max_len = config['seq_len']

    # Validate before loading weights. Raise instead of calling exit() so a
    # caller (e.g. a server) is not terminated by a bad prompt.
    # Drops the last id -- presumably a tokenizer-appended end token (TODO confirm).
    input_tokens = tokenizer.encode(prompt).ids[:-1]
    if not input_tokens:
        raise ValueError("prompt produced no input tokens")
    if len(input_tokens) > max_len:
        raise ValueError(f"exceeding max length of input : {max_len}")

    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model_path = get_weights_file_path(config, config['preload'])
    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    model.eval()

    decoder_input = torch.tensor(input_tokens, dtype=torch.long, device=device).unsqueeze(0)

    with torch.no_grad():
        # Explicit budget: the window is cropped to max_len below, so a bound
        # on decoder_input's length alone may never be reached.
        for _ in range(max_new_tokens):
            out = model.decode(decoder_input)
            logits = model.project(out[:, -1])  # logits for the last position

            if temperature <= 0:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                vocab_size = logits.size(-1)
                k = vocab_size if top_k is None or top_k <= 0 else min(top_k, vocab_size)
                top_k_logits, top_k_indices = torch.topk(logits, k)
                probs = torch.softmax(top_k_logits, dim=-1)
                choice = torch.multinomial(probs, num_samples=1)
                # Map the position inside the top-k slice back to a vocab id.
                next_token = top_k_indices.gather(-1, choice)

            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            if decoder_input.shape[1] > max_len:
                decoder_input = decoder_input[:, -max_len:]
            if next_token.item() == eos_token_id:
                break
    return decoder_input
if __name__ == "__main__":
    # Script entry point: start the interactive prompt loop using the default config.
    run_model(get_config("openweb.config.json"))