import torch import random import bisect import json import re from config import * from transformers import GPT2Model, GPT2LMHeadModel, PreTrainedModel, BitsAndBytesConfig from samplings import top_p_sampling, top_k_sampling, temperature_sampling from tokenizers import Tokenizer quantization_config = BitsAndBytesConfig( load_in_8bit=True, llm_int8_skip_modules=["patch_embedding"] # 跳过可能不兼容的模块 ) class Patchilizer: def __init__(self, stream=PATCH_STREAM): self.stream = stream self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"] self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')' self.bos_token_id = 1 self.eos_token_id = 2 self.special_token_id = 0 def split_bars(self, body_lines): """ Split a body of music into individual bars. """ new_bars = [] try: for line in body_lines: line_bars = re.split(self.regexPattern, line) line_bars = list(filter(None, line_bars)) new_line_bars = [] if len(line_bars) == 1: new_line_bars = line_bars else: if line_bars[0] in self.delimiters: new_line_bars = [line_bars[i] + line_bars[i + 1] for i in range(0, len(line_bars), 2)] else: new_line_bars = [line_bars[0]] + [line_bars[i] + line_bars[i + 1] for i in range(1, len(line_bars), 2)] if 'V' not in new_line_bars[-1]: new_line_bars[-2] += new_line_bars[-1] # 吸收最后一个 小节线+\n 的组合 new_line_bars = new_line_bars[:-1] new_bars += new_line_bars except: pass return new_bars def split_patches(self, abc_text, patch_size=PATCH_SIZE, generate_last=False): if not generate_last and len(abc_text) % patch_size != 0: abc_text += chr(self.eos_token_id) patches = [abc_text[i : i + patch_size] for i in range(0, len(abc_text), patch_size)] return patches def patch2chars(self, patch): """ Convert a patch into a bar. """ bytes = '' for idx in patch: if idx == self.eos_token_id: break if idx < self.eos_token_id: pass bytes += chr(idx) return bytes def patchilize_metadata(self, metadata_lines): metadata_patches = [] for line in metadata_lines: metadata_patches += self.split_patches(line) return metadata_patches def patchilize_tunebody(self, tunebody_lines, encode_mode='train'): tunebody_patches = [] bars = self.split_bars(tunebody_lines) if encode_mode == 'train': for bar in bars: tunebody_patches += self.split_patches(bar) elif encode_mode == 'generate': for bar in bars[:-1]: tunebody_patches += self.split_patches(bar) tunebody_patches += self.split_patches(bars[-1], generate_last=True) return tunebody_patches def encode(self, abc_text, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True, cut=True): lines = abc_text.split('\n') lines = list(filter(None, lines)) lines = [line + '\n' for line in lines] tunebody_index = -1 for i, line in enumerate(lines): if line.startswith('[r:'): tunebody_index = i break metadata_lines = lines[: tunebody_index] tunebody_lines = lines[tunebody_index:] metadata_patches = self.patchilize_metadata(metadata_lines) tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='train') if add_special_patches: bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id) eos_patch = chr(self.bos_token_id) + chr(self.eos_token_id) * (patch_size - 1) metadata_patches = [bos_patch] + metadata_patches tunebody_patches = tunebody_patches + [eos_patch] if self.stream: if len(metadata_patches) + len(tunebody_patches) > patch_length: available_cut_indexes = [0] + [index + 1 for index, patch in enumerate(tunebody_patches) if '\n' in patch] line_index_for_cut_index = list(range(len(available_cut_indexes))) # 每个cut_index对应tunebody的哪一行 end_index = len(metadata_patches) + len(tunebody_patches) - patch_length biggest_index = bisect.bisect_left(available_cut_indexes, end_index) # biggest index 在 end_index 右面一位 available_cut_indexes = available_cut_indexes[:biggest_index + 1] if len(available_cut_indexes) == 1: choices = ['head'] elif len(available_cut_indexes) == 2: choices = ['head', 'tail'] else: choices = ['head', 'tail', 'middle'] choice = random.choice(choices) if choice == 'head': patches = metadata_patches + tunebody_patches[0:] else: if choice == 'tail': cut_index = len(available_cut_indexes) - 1 else: cut_index = random.choice(range(1, len(available_cut_indexes) - 1)) line_index = line_index_for_cut_index[cut_index] stream_tunebody_lines = tunebody_lines[line_index:] stream_tunebody_patches = self.patchilize_tunebody(stream_tunebody_lines, encode_mode='train') if add_special_patches: stream_tunebody_patches = stream_tunebody_patches + [eos_patch] patches = metadata_patches + stream_tunebody_patches else: patches = metadata_patches + tunebody_patches else: patches = metadata_patches + tunebody_patches patches = patches[: patch_length] # encode to ids id_patches = [] for patch in patches: id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch)) id_patches.append(id_patch) return id_patches def encode_generate(self, abc_code, patch_length=PATCH_LENGTH, patch_size=PATCH_SIZE, add_special_patches=True): lines = abc_code.split('\n') lines = list(filter(None, lines)) tunebody_index = None for i, line in enumerate(lines): if line.startswith('[V:') or line.startswith('[r:'): tunebody_index = i break metadata_lines = lines[ : tunebody_index] tunebody_lines = lines[tunebody_index : ] # 备份未省略前的tunebody_lines metadata_lines = [line + '\n' for line in metadata_lines] if self.stream: if not abc_code.endswith('\n'): # 如果生成结果最后一行未完结 tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines) - 1)] + [tunebody_lines[-1]] else: tunebody_lines = [tunebody_lines[i] + '\n' for i in range(len(tunebody_lines))] else: tunebody_lines = [line + '\n' for line in tunebody_lines] metadata_patches = self.patchilize_metadata(metadata_lines) tunebody_patches = self.patchilize_tunebody(tunebody_lines, encode_mode='generate') if add_special_patches: bos_patch = chr(self.bos_token_id) * (patch_size - 1) + chr(self.eos_token_id) metadata_patches = [bos_patch] + metadata_patches patches = metadata_patches + tunebody_patches patches = patches[ : patch_length] # encode to ids id_patches = [] for patch in patches: if len(patch) < PATCH_SIZE and patch[-1] != chr(self.eos_token_id): id_patch = [ord(c) for c in patch] else: id_patch = [ord(c) for c in patch] + [self.special_token_id] * (patch_size - len(patch)) id_patches.append(id_patch) return id_patches def decode(self, patches): """ Decode patches into music. """ return ''.join(self.patch2chars(patch) for patch in patches) class PatchLevelDecoder(PreTrainedModel): """ A Patch-level Decoder model for generating patch features in an auto-regressive manner. It inherits PreTrainedModel from transformers. """ def __init__(self, config): super().__init__(config) self.patch_embedding = torch.nn.Linear(PATCH_SIZE * 128, config.n_embd).to(torch.float16) torch.nn.init.normal_(self.patch_embedding.weight, std=0.02) self.base = GPT2Model(config) def forward(self, patches: torch.Tensor, masks=None) -> torch.Tensor: """ The forward pass of the patch-level decoder model. :param patches: the patches to be encoded :param masks: the masks for the patches :return: the encoded patches """ patches = torch.nn.functional.one_hot(patches, num_classes=128).to(self.dtype) patches = patches.reshape(len(patches), -1, PATCH_SIZE * (128)) patches = self.patch_embedding(patches.to(self.device)) if masks==None: return self.base(inputs_embeds=patches) else: return self.base(inputs_embeds=patches, attention_mask=masks) class CharLevelDecoder(PreTrainedModel): """ A Char-level Decoder model for generating the chars within each patch in an auto-regressive manner based on the encoded patch features. It inherits PreTrainedModel from transformers. """ def __init__(self, config): super().__init__(config) self.special_token_id = 0 self.bos_token_id = 1 self.base = GPT2LMHeadModel(config) def forward(self, encoded_patches: torch.Tensor, target_patches: torch.Tensor): """ The forward pass of the char-level decoder model. :param encoded_patches: the encoded patches :param target_patches: the target patches :return: the output of the model """ target_patches = torch.cat((torch.ones_like(target_patches[:, 0:1]) * self.bos_token_id, target_patches), dim=1) # [patch_len, patch_size + 1] target_masks = target_patches == self.special_token_id # [patch_len, patch_size + 1] labels = target_patches.clone().masked_fill_(target_masks, -100) target_masks = torch.ones_like(labels) target_masks = target_masks.masked_fill_(labels == -100, 0) input_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight) input_embeds = torch.cat((encoded_patches.unsqueeze(1), input_embeds[:, 1:, :]), dim=1) logits = self.base(inputs_embeds=input_embeds, attention_mask=target_masks).logits # [patch_len, patch_size + 1, vocab_size] logits = logits[:, :-1, :] token_logps = torch.gather(logits.log_softmax(-1), dim=-1, index=target_patches[:, 1:].unsqueeze(-1)).squeeze(-1) # [patch_len, patch_size] token_logps = token_logps[target_masks[:, 1:] == 1] all_logps = token_logps.sum() return all_logps def generate(self, encoded_patch: torch.Tensor, # [hidden_size] tokens: torch.Tensor): # [1] """ The generate function for generating a patch based on the encoded patch and already generated tokens. :param encoded_patch: the encoded patch :param tokens: already generated tokens in the patch :return: the probability distribution of next token """ encoded_patch = encoded_patch.reshape(1, 1, -1) # [1, 1, hidden_size] tokens = tokens.reshape(1, -1) # Get input embeddings tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight) # Concatenate the encoded patch with the input embeddings tokens = torch.cat((encoded_patch, tokens[:,1:,:]), dim=1) # Get output from model outputs = self.base(inputs_embeds=tokens) # Get probabilities of next token probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1) return probs class NotaGenLMHeadModel(PreTrainedModel): """ NotaGen is a language model with a hierarchical structure. It includes a patch-level decoder and a char-level decoder. The patch-level decoder is used to generate patch features in an auto-regressive manner. The char-level decoder is used to generate the chars within each patch in an auto-regressive manner. It inherits PreTrainedModel from transformers. """ def __init__(self, encoder_config, decoder_config): super().__init__(encoder_config) self.special_token_id = 0 self.bos_token_id = 1 self.eos_token_id = 2 self.patch_level_decoder = PatchLevelDecoder(encoder_config) self.char_level_decoder = CharLevelDecoder(decoder_config) def forward(self, patches: torch.Tensor, masks: torch.Tensor): """ The forward pass of the bGPT model. :param patches: the patches to be encoded :param masks: the masks for the patches :return: the decoded patches """ patches = patches.reshape(len(patches), -1, PATCH_SIZE) encoded_patches = self.patch_level_decoder(patches, masks)["last_hidden_state"] left_shift_masks = masks * (masks.flip(1).cumsum(1).flip(1) > 1) masks[:, 0] = 0 encoded_patches = encoded_patches[left_shift_masks == 1] patches = patches[masks == 1] return self.char_level_decoder(encoded_patches, patches) def generate(self, patches: torch.Tensor, top_k=0, top_p=1, temperature=1.0): """ The generate function for generating patches based on patches. :param patches: the patches to be encoded :param top_k: the top k for sampling :param top_p: the top p for sampling :param temperature: the temperature for sampling :return: the generated patches """ if patches.shape[-1] % PATCH_SIZE != 0: tokens = patches[:,:,-(patches.shape[-1]%PATCH_SIZE):].squeeze(0, 1) tokens = torch.cat((torch.tensor([self.bos_token_id], device=self.device), tokens), dim=-1) patches = patches[:,:,:-(patches.shape[-1]%PATCH_SIZE)] else: tokens = torch.tensor([self.bos_token_id], device=self.device) patches = patches.reshape(len(patches), -1, PATCH_SIZE) # [bs, seq, patch_size] encoded_patches = self.patch_level_decoder(patches)["last_hidden_state"] # [bs, seq, hidden_size] generated_patch = [] while True: prob = self.char_level_decoder.generate(encoded_patches[0][-1], tokens).cpu().detach().numpy() # [128] prob = top_k_sampling(prob, top_k=top_k, return_probs=True) # [128] prob = top_p_sampling(prob, top_p=top_p, return_probs=True) # [128] token = temperature_sampling(prob, temperature=temperature) # int char = chr(token) generated_patch.append(token) if len(tokens) >= PATCH_SIZE:# or token == self.eos_token_id: break else: tokens = torch.cat((tokens, torch.tensor([token], device=self.device)), dim=0) return generated_patch