from pathlib import Path

import torch
from tokenizers import Tokenizer

from config import get_config, get_weights_file_path
from train import get_model

# Upper bound on prompt + generated tokens for a single generation call.
_MAX_TOTAL_TOKENS = 2000


def _sample_next_token(model, decoder_input, temperature, top_k):
    """Sample the next token id with temperature-scaled top-k sampling.

    Runs ``model.decode`` on ``decoder_input``, projects the last position
    with ``model.project``, keeps the ``top_k`` highest logits and draws one
    of them from the softmax distribution.

    Args:
        model: Object exposing ``decode`` and ``project``.
        decoder_input: 2-D tensor of token ids (sequence on dim 1).
        temperature: Positive divisor applied to the logits.
        top_k: Number of candidates to sample from; clamped to the size of
            the last logits dimension.

    Returns:
        Tensor holding the sampled token id, shaped so it can be
        concatenated to ``decoder_input`` along dim 1.

    Raises:
        ValueError: If ``temperature`` is not positive or ``top_k`` < 1.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    out = model.decode(decoder_input)
    logits = model.project(out[:, -1]) / temperature  # last position only

    # torch.topk raises when k exceeds the dimension size, so clamp it.
    k = min(top_k, logits.size(-1))
    top_k_logits, top_k_indices = torch.topk(logits, k)
    probs = torch.softmax(top_k_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    # `choice` indexes into the top-k subset; map it back to vocabulary ids.
    return top_k_indices.gather(-1, choice)


def _generate(
    model,
    decoder_input,
    eos_idx,
    window,
    temperature,
    top_k,
    max_total_tokens,
    on_token=None,
):
    """Autoregressively extend ``decoder_input`` and return the final window.

    Generation stops when the sampled token equals ``eos_idx`` or when the
    prompt plus generated tokens reach ``max_total_tokens``. The tensor fed
    to the model is truncated to its last ``window`` tokens, so the total is
    tracked separately from the tensor length (otherwise a ``window`` smaller
    than the cap would make the length check never fail).

    Args:
        model: Object exposing ``decode`` and ``project``.
        decoder_input: 2-D tensor of prompt token ids, already on the
            target device.
        eos_idx: Token id that ends generation (may be ``None``).
        window: Maximum number of trailing tokens kept as model input.
        temperature: Sampling temperature.
        top_k: Number of top-k candidates.
        max_total_tokens: Cap on prompt + generated tokens.
        on_token: Optional callable invoked with each sampled token id.

    Returns:
        The 2-D token tensor (at most ``window`` tokens wide) after
        generation ends.

    Raises:
        ValueError: If the prompt contains no tokens.
    """
    if decoder_input.shape[1] == 0:
        raise ValueError("prompt must contain at least one token")

    total_len = decoder_input.shape[1]
    # Inference only: avoid building an autograd graph on every step.
    with torch.no_grad():
        while total_len < max_total_tokens:
            next_token = _sample_next_token(
                model, decoder_input, temperature, top_k
            )
            token_id = next_token.item()
            if on_token is not None:
                on_token(token_id)

            decoder_input = torch.cat([decoder_input, next_token], dim=1)
            total_len += 1
            if decoder_input.shape[1] > window:
                decoder_input = decoder_input[:, -window:]

            if token_id == eos_idx:
                break
    return decoder_input


def _load_model(config, tokenizer, device, announce=False):
    """Build the model, load its checkpoint and return it in eval mode.

    Args:
        config: Mapping providing at least ``'preload'`` plus whatever
            ``get_model`` / ``get_weights_file_path`` read.
        tokenizer: Tokenizer whose vocabulary size sizes the model.
        device: Device string the model and checkpoint are mapped to.
        announce: When true, print the checkpoint path before loading.

    Returns:
        The model with ``state['model_state_dict']`` loaded.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    model.eval()

    model_path = get_weights_file_path(config, config['preload'])
    if not Path(model_path).exists():
        raise FileNotFoundError("Model File not found : " + str(model_path))
    if announce:
        print("Loading Model from : ", model_path)

    # map_location lets a checkpoint saved on GPU load on a CPU-only host.
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state['model_state_dict'])
    return model


def generate_text(
    model,
    text,
    tokenizer,
    max_len,
    device,
    temperature=0.7,
    top_k=50,
    max_total_tokens=_MAX_TOTAL_TOKENS,
):
    """Generate a continuation of ``text`` and stream it to stdout.

    Each sampled token is decoded and printed as soon as it is produced,
    followed by a final newline.

    Args:
        model: Object exposing ``decode`` and ``project``.
        text: Tensor of prompt token ids (1-D, or 2-D with the sequence on
            dim 1).
        tokenizer: Tokenizer used to decode sampled ids and look up the
            end-of-sequence id.
        max_len: Maximum number of trailing tokens fed to the model.
        device: Device the prompt is moved to.
        temperature: Sampling temperature.
        top_k: Number of top-k candidates.
        max_total_tokens: Cap on prompt + generated tokens.

    Returns:
        The final token window with dim 0 squeezed away.

    Raises:
        ValueError: If the prompt is empty or the sampling parameters are
            invalid.
    """
    # NOTE(review): the token string is empty here; it looks like a
    # special-token name may have been lost. If the lookup yields no id,
    # only the length cap ends generation - confirm the intended token.
    eos_idx = tokenizer.token_to_id('')

    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    def _emit(token_id):
        print(tokenizer.decode([token_id]), end="", flush=True)

    decoder_input = _generate(
        model,
        decoder_input,
        eos_idx,
        max_len,
        temperature,
        top_k,
        max_total_tokens,
        on_token=_emit,
    )
    print()
    return decoder_input.squeeze(0)


def get_tokenizer(config) -> Tokenizer:
    """Load the tokenizer stored at ``config['tokenizer_file']``.

    Raises:
        FileNotFoundError: If the tokenizer file does not exist.
    """
    tokenizers_path = Path(config['tokenizer_file'])
    if not tokenizers_path.exists():
        # Single formatted message: passing two positional args to an
        # OSError subclass is interpreted as (errno, strerror).
        raise FileNotFoundError(
            f"Cant find tokenizer file : {tokenizers_path}"
        )
    print("Loading tokenizer from ", tokenizers_path)
    return Tokenizer.from_file(str(tokenizers_path))


def run_model(config):
    """Run an interactive prompt/response loop on stdin/stdout.

    Reads prompts until the user types ``exit``. Each prompt is encoded
    (dropping the last id of the encoding), rejected if it is empty or
    longer than ``config['seq_len']``, and otherwise passed to
    ``generate_text``.

    Raises:
        FileNotFoundError: If the tokenizer or checkpoint file is missing.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device : {device}")

    tokenizer = get_tokenizer(config)
    model = _load_model(config, tokenizer, device, announce=True)

    print("You : ", end="")
    input_text = input()
    while input_text != "exit":
        input_tokens = tokenizer.encode(input_text).ids[:-1]
        if not input_tokens:
            print("empty input : nothing to generate")
        elif len(input_tokens) > config['seq_len']:
            # Fall through to the re-prompt below; skipping it would test
            # the same over-long input forever.
            print(f"exceeding max length of input : {config['seq_len']}")
        else:
            output_tokens = generate_text(
                model,
                torch.tensor(input_tokens),
                tokenizer,
                config['seq_len'],
                device,
            )
            print("MODEL : ", output_tokens)

        print("You : ", end="")
        input_text = input()


def generate_response(prompt: str):
    """Generate a continuation of ``prompt`` using the openweb config.

    Loads the config from ``./openweb.config.json``, the tokenizer and the
    checkpoint on every call, then samples with temperature 0.7 and
    top-k 50.

    Returns:
        2-D tensor holding the final token window (at most
        ``config['seq_len']`` tokens wide).

    Raises:
        ValueError: If the encoded prompt is empty or longer than
            ``config['seq_len']``.
        FileNotFoundError: If the tokenizer or checkpoint file is missing.
    """
    config = get_config("./openweb.config.json")
    print(config)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = get_tokenizer(config)
    # NOTE(review): empty token string - see the matching note in
    # generate_text; confirm the intended end-of-sequence token.
    eos_token_id = tokenizer.token_to_id("")
    model = _load_model(config, tokenizer, device)

    input_tokens = tokenizer.encode(prompt).ids[:-1]
    if len(input_tokens) > config['seq_len']:
        # Raise instead of exit(): terminating the interpreter (with a
        # success status) from a library function hides the failure.
        raise ValueError(
            f"exceeding max length of input : {config['seq_len']}"
        )

    decoder_input = torch.tensor(input_tokens).to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    return _generate(
        model,
        decoder_input,
        eos_token_id,
        config['seq_len'],
        temperature=0.7,
        top_k=50,
        max_total_tokens=_MAX_TOTAL_TOKENS,
    )


if __name__ == "__main__":
    config = get_config("openweb.config.json")
    run_model(config)