import re, json import os, random import torch, logging from copy import deepcopy as cp from torch.utils.data import Dataset from tokenizers import ByteLevelBPETokenizer from transformers import T5Tokenizer, RobertaTokenizer import nltk logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO, ) logger = logging.getLogger(__name__) class MyTokenizer(object): """ Wrapper for ByteLevelBPETokenizer """ def __init__(self, vocab=None, merges=None, **kwargs): self.tokenizer = ByteLevelBPETokenizer(vocab, merges, **kwargs) self.update_id2token() @staticmethod def from_pretrained(path): vocabp = os.path.join(path, "vocab.json") mergesp = os.path.join(path, "merges.txt") mytoken = MyTokenizer(vocabp, mergesp) return mytoken def update_id2token(self): vocab = self.tokenizer.get_vocab() self.id2token = {vocab[token]: token for token in vocab} def add_special_tokens(self, dic): for values in dic.values(): self.tokenizer.add_special_tokens(values) self.update_id2token() def convert_ids_to_tokens(self, ids): vocab = self.id2token return [vocab[i] for i in ids] def decode(self, ids, **kwargs): ##### to be update tokens = self.convert_ids_to_tokens(ids) return " ".join(tokens) def encode(self, text, **kwargs): text = text.encode("ascii", errors="ignore").decode("ascii") return self.tokenizer.encode(text).ids def get_vocab(self): return self.tokenizer.get_vocab() def __len__(self): return len(self.tokenizer.get_vocab()) class RefineFeatures(object): def __init__(self, example_id, source_ids, target_ids): self.example_id = example_id self.source_ids = source_ids self.target_ids = target_ids class RefineDataset(Dataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.tokenizer = tokenizer self.args = args logger.info("Reading examples from {}".format(file_path)) examples = [json.loads(line) for line in open(file_path)] for i in range(len(examples)): if "id" not in examples[i]: examples[i]["id"] = i if samplenum > 0: examples = examples[:samplenum] logger.info(f"Tokenize examples: {file_path}") self.feats = pool.map(self.tokenize, \ [(example, tokenizer, args) for example in examples]) def tokenize(self, item): example, tokenizer, args = item oldlines = example["old"].split("\n") newlines = example["new"].split("\n") oldlines = [line[1:].strip() for line in oldlines] newlines = [line[1:].strip() for line in newlines] oldlines = "\n".join(oldlines) newlines = "\n".join(newlines) oldlines = "" + oldlines.replace("\n", "") newlines = "" + newlines.replace("\n", "") comment = example["comment"] srcids = self.encode_remove(tokenizer, oldlines, args) srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args) tgtids = self.encode_remove(tokenizer, newlines, args) srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer) return RefineFeatures(example["id"], srcids, tgtids) @staticmethod def process_pred_gold(pred, gold): gold = gold.split("\n") gold = [line[1:].strip() for line in gold] gold = " ".join(gold) pred = " ".join(pred.split()) pred = pred.replace(" ", "") return pred, gold def pad_assert(self, source_ids, target_ids, args, tokenizer): source_ids = source_ids[:args.max_source_length - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] pad_len = args.max_source_length - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len target_ids = target_ids[:args.max_target_length - 2] target_ids = [tokenizer.bos_id] + target_ids + [tokenizer.eos_id] pad_len = args.max_target_length - len(target_ids) target_ids += [tokenizer.pad_id] * pad_len assert len(source_ids) == args.max_source_length, "Not equal length." assert len(target_ids) == args.max_target_length, "Not equal length." return source_ids, target_ids def encode_remove(self, tokenizer, text, args): text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True) if type(tokenizer) == T5Tokenizer: return text[:-1] elif type(tokenizer) == RobertaTokenizer: return text[1:-1] elif type(tokenizer) == MyTokenizer: return text else: raise NotImplementedError def __len__(self): return len(self.feats) def __getitem__(self, i): return self.feats[i] class SimpleRefineDataset(RefineDataset): def tokenize(self, item): example, tokenizer, args = item oldlines = example["old"].split("\n") newlines = example["new"].split("\n") oldlines = [line[1:].strip() for line in oldlines] newlines = [line[1:].strip() for line in newlines] oldlines = " ".join(oldlines) newlines = " ".join(newlines) comment = example["comment"] srcids = self.encode_remove(tokenizer, oldlines, args) srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args) tgtids = self.encode_remove(tokenizer, newlines, args) srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer) return RefineFeatures(example["id"], srcids, tgtids) @staticmethod def process_pred_gold(pred, gold): gold = gold.split("\n") gold = [line[1:].strip() for line in gold] gold = " ".join(gold) pred = " ".join(pred.split()) return pred, gold class Seq2SeqDataset(RefineDataset): def tokenize(self, item): example, tokenizer, args = item inputs, outputs = example["old"], example["new"] inputs = " ".join(inputs.split()) outputs = " ".join(outputs.split()) srcids = self.encode_remove(tokenizer, inputs, args) tgtids = self.encode_remove(tokenizer, outputs, args) srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer) return RefineFeatures(example["id"], srcids, tgtids) @staticmethod def process_pred_gold(pred, gold): gold = " ".join(gold.split()) pred = " ".join(pred.split()) return pred, gold class TextDataset(Dataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.cnt = 0 self.tokenizer = tokenizer self.args = args if isinstance(tokenizer, MyTokenizer): tokenizer_type = "mytok" elif isinstance(tokenizer, T5Tokenizer): tokenizer_type = "" elif isinstance(tokenizer, RobertaTokenizer): tokenizer_type = "rb" else: tokenizer_type = "unk" savep = file_path.replace(".jsonl", tokenizer_type + ".exps") # savep = "/home/v-zhuoli1/lzzz/processed/chunk_25.exps" if os.path.exists(savep): logger.info("Loading examples from {}".format(savep)) examples = torch.load(savep) else: logger.info("Reading examples from {}".format(file_path)) examples = read_review_examples(file_path, samplenum, tokenizer) logger.info(f"Tokenize examples: {file_path}") examples = pool.map(self.tokenize, \ [(example, tokenizer, args) for example in examples]) torch.save(examples, savep) logger.info("Convert examples to features...") self.set_start_end_ids(examples) self.featss = pool.map(self.convert_examples_to_features, \ [(example, tokenizer, args) for example in examples]) self.feats = [feat for feats in self.featss for feat in feats] # expand the lists def __len__(self): return len(self.feats) def __getitem__(self, i): return self.feats[i] def reset_len(self, data_len): assert len(self.feats) >= data_len self.feats = self.feats[:data_len] def set_start_end_ids(self, examples): for example in examples: labels = example.labels start_id = 0 end_id = len(labels) - 1 for i, label in enumerate(labels): if label != -100: # find the first label start_id = i break for i in range(len(labels) - 1, -1, -1): label = labels[i] if label != -100: end_id = i break example.start_id = start_id example.end_id = end_id def tokenize(self, item): example, tokenizer, args = item example.input = self.encode_remove(tokenizer, example.input, args) e0id = tokenizer.special_dict[""] inputs = " ".join(str(id) for id in example.input) lines = inputs.split(" " + str(e0id) + " ") lines = [ [int(v) for v in line.split(" ") if len(v) > 0] for line in lines ] lens = [len(line) for line in lines] # if 0 in lens: # logger.info("Warning: empty line in an example.") lens = list(map(len, lines)) curlen = len(lens) + sum(lens) left, right = 0, len(lines) while curlen > args.max_source_length - 2: if left % 2 == 0: curlen -= 1 + len(lines[left]) left += 1 else: right -= 1 curlen -= 1 + len(lines[right]) lines = lines[left:right] labels = example.labels[left:right] assert len(lines) + sum(map(len, lines)) <= args.max_source_length - 2, "Too long inputs in TextDataset.tokenize." if len(lines) != len(labels): logger.info("Not equal length in TextDataset.tokenize.") lines = lines[:len(labels)] labels = labels[:len(lines)] example.lines = lines example.labels = labels example.msg = self.encode_remove(tokenizer, example.msg, args) return example def convert_examples_to_features(self, item): example, _, _ = item if len(example.msg) > 0: exs = [] for _ in range(3): # up sampling if random.random() < 0.5: exs.append(self.genmsg_example(item)) else: exs.append(self.daemsg_example(item)) return exs if random.random() < 0.5: return [self.encoder_example(item)] return [self.decoder_example(item)] def encoder_example(self, item): example, tokenizer, args = item lines = example.lines labels = example.labels target_ids = [tokenizer.pad_id] * args.max_target_length source_ids, input_labels = [], [] for i, (line, label) in enumerate(zip(lines, labels)): if i == example.start_id: source_ids.append(tokenizer.start_id) input_labels.append(-100) if label != -100: # only insert special tokens at diffs, not context source_ids.append(tokenizer.mask_id) input_labels.append(label) source_ids.extend(line) input_labels.extend([-100] * len(line)) if i == example.end_id: source_ids.append(tokenizer.end_id) input_labels.append(-100) assert len(input_labels) == len(source_ids), "Not equal length." assert len(input_labels) <= args.max_source_length, f"Too long inputs: {len(input_labels)}." source_ids = source_ids[:args.max_source_length - 2] input_labels = input_labels[:args.max_source_length - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] input_labels = [-100] + input_labels + [-100] pad_len = args.max_source_length - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len input_labels += [-100] * pad_len new_input_labels = [] map_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id} for label in input_labels: if label == -100: new_input_labels.append(-100) else: new_input_labels.append(map_dict[label]) input_labels = new_input_labels assert len(source_ids) == args.max_source_length, "Not equal length." assert len(input_labels) == args.max_source_length, "Not equal length." return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="label") def decoder_example(self, item): example, tokenizer, args = item lines = example.lines labels = example.labels input_labels = [-100] * args.max_source_length source_ids, target_ids = [], [] SPECIAL_ID = 0 mask_idxs = random.choices(range(len(lines)), k=int(len(lines) * args.mask_rate)) id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id} for i, (line, label) in enumerate(zip(lines, labels)): if i == example.start_id: source_ids.append(tokenizer.start_id) if label in id_dict: source_ids.append(id_dict[label]) if i in mask_idxs: source_ids.append(tokenizer.special_dict[f""]) target_ids.append(tokenizer.special_dict[f""]) target_ids.extend(line) if SPECIAL_ID < 99: # only 0-99 ids in vocab SPECIAL_ID += 1 else: source_ids.extend(line) if i == example.end_id: source_ids.append(tokenizer.end_id) source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer) return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="line") def genmsg_example(self, item): example, tokenizer, args = item lines = example.lines labels = example.labels input_labels = [-100] * args.max_source_length source_ids, target_ids = [], [] id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id} for i, (line, label) in enumerate(zip(lines, labels)): if i == example.start_id: source_ids.append(tokenizer.start_id) if label != -100: source_ids.append(id_dict[label]) source_ids.extend(line) if i == example.end_id: source_ids.append(tokenizer.end_id) target_ids.append(tokenizer.msg_id) target_ids.extend(example.msg) assert len(source_ids) <= args.max_source_length, f"Too long inputs: {len(source_ids)}." source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer) return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="genmsg") def daemsg_example(self, item): example, tokenizer, args = item input_labels = [-100] * args.max_source_length source_ids, target_ids = [], [] msg_ids = cp(example.msg) masks = [random.random() < 0.20 for _ in range(len(msg_ids))] if sum(masks) == 0: idx = random.choice(range(len(msg_ids))) masks[idx] = True source_ids, target_ids = [], [] i = 0 SPECIAL_ID = 0 while i < len(masks): j = i while j < len(masks) and not masks[j]: source_ids.append(msg_ids[j]) j += 1 if j == len(masks): break source_ids.append(tokenizer.special_dict[f""]) target_ids.append(tokenizer.special_dict[f""]) while j < len(masks) and masks[j]: target_ids.append(msg_ids[j]) j += 1 if SPECIAL_ID < 99: # only 0-99 ids in vocab SPECIAL_ID += 1 i = j source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer) return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="daemsg") def pad_assert(self, source_ids, target_ids, args, tokenizer): source_ids = source_ids[:args.max_source_length - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] pad_len = args.max_source_length - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len target_ids = target_ids[:args.max_target_length - 1] target_ids = target_ids + [tokenizer.eos_id] pad_len = args.max_target_length - len(target_ids) target_ids += [tokenizer.pad_id] * pad_len assert len(source_ids) == args.max_source_length, "Not equal length." assert len(target_ids) == args.max_target_length, "Not equal length." return source_ids, target_ids def encode_remove(self, tokenizer, text, args): text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True) if type(tokenizer) == T5Tokenizer: return text[:-1] elif type(tokenizer) == RobertaTokenizer: return text[1:-1] elif type(tokenizer) == MyTokenizer: return text else: raise NotImplementedError class CommentGenDataset(TextDataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.tokenizer = tokenizer if isinstance(tokenizer, MyTokenizer): tokenizer_type = "mytok" elif isinstance(tokenizer, T5Tokenizer): tokenizer_type = "" elif isinstance(tokenizer, RobertaTokenizer): tokenizer_type = "rb" else: tokenizer_type = "unk" savep = file_path.replace(".jsonl", tokenizer_type + ".exps") if os.path.exists(savep): logger.info("Loading examples from {}".format(savep)) examples = torch.load(savep) else: logger.info("Reading examples from {}".format(file_path)) examples = read_review_examples(file_path, samplenum, tokenizer) # for i in range(len(examples)): # examples[i].msg = " ".join(nltk.word_tokenize(examples[i].msg)) logger.info(f"Tokenize examples: {file_path}") examples = pool.map(self.tokenize, \ [(example, tokenizer, args) for example in examples]) torch.save(examples, savep) logger.info("Convert examples to features...") self.set_start_end_ids(examples) self.feats = pool.map(self.convert_examples_to_features, \ [(example, tokenizer, args) for example in examples]) self.feats = [feat for feat in self.feats if feat is not None] def convert_examples_to_features(self, item): example, tokenizer, args = item if len(example.msg) == 0: return None return self.genmsg_example(item) class CommentClsDataset(TextDataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.tokenizer = tokenizer if isinstance(tokenizer, MyTokenizer): tokenizer_type = "mytok" elif isinstance(tokenizer, T5Tokenizer): tokenizer_type = "" elif isinstance(tokenizer, RobertaTokenizer): tokenizer_type = "rb" else: tokenizer_type = "unk" savep = file_path.replace(".jsonl", tokenizer_type + ".exps") if os.path.exists(savep): logger.info("Loading examples from {}".format(savep)) examples = torch.load(savep) else: logger.info("Reading examples from {}".format(file_path)) examples = read_review_examples(file_path, samplenum, tokenizer) logger.info(f"Tokenize examples: {file_path}") examples = pool.map(self.tokenize, \ [(example, tokenizer, args) for example in examples]) torch.save(examples, savep) logger.info("Convert examples to features...") self.set_start_end_ids(examples) self.feats = pool.map(self.convert_examples_to_features, \ [(example, tokenizer, args) for example in examples]) def convert_examples_to_features(self, item): example, tokenizer, args = item tmpfeature = self.genmsg_example(item) return ClsFeatures(tmpfeature.example_id, tmpfeature.source_ids, example.y) class SimpleClsDataset(TextDataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.tokenizer = tokenizer if isinstance(tokenizer, MyTokenizer): tokenizer_type = "mytok" elif isinstance(tokenizer, T5Tokenizer): tokenizer_type = "" elif isinstance(tokenizer, RobertaTokenizer): tokenizer_type = "rb" else: tokenizer_type = "unk" savep = file_path.replace(".jsonl", tokenizer_type + ".simpexps") if os.path.exists(savep): logger.info("Loading examples from {}".format(savep)) self.feats = torch.load(savep) else: logger.info("Reading examples from {}".format(file_path)) examples = read_review_examples(file_path, samplenum, tokenizer) logger.info(f"Tokenize examples: {file_path}") self.feats = pool.map(self.convert_examples_to_features, \ [(example, tokenizer, args) for example in examples]) torch.save(self.feats, savep) def convert_examples_to_features(self, item): example, tokenizer, args = item example.input_lines = example.input.split("") labels_l = len(example.labels) example.input_lines = example.input_lines[:labels_l] for i in range(len(example.input_lines)): if example.labels[i] == 1: example.input_lines[i] = "+ " + example.input_lines[i] elif example.labels[i] == 0: example.input_lines[i] = "- " + example.input_lines[i] example.input = " ".join(example.input_lines) input_ids = self.encode_remove(tokenizer, example.input, args) exceed_l = len(input_ids) - args.max_source_length + 2 if exceed_l > 0: halfexl = (exceed_l + 1) // 2 input_ids = input_ids[halfexl:-halfexl] source_ids = input_ids[:args.max_source_length - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] pad_len = args.max_source_length - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len example_id = example.idx y = example.y return ClsFeatures(example_id, source_ids, y) class SimpleGenDataset(TextDataset): def __init__(self, tokenizer, pool, args, file_path, samplenum=-1): self.tokenizer = tokenizer if isinstance(tokenizer, MyTokenizer): tokenizer_type = "mytok" elif isinstance(tokenizer, T5Tokenizer): tokenizer_type = "" elif isinstance(tokenizer, RobertaTokenizer): tokenizer_type = "rb" else: tokenizer_type = "unk" savep = file_path.replace(".jsonl", tokenizer_type + ".simpgenexps") if os.path.exists(savep): logger.info("Loading examples from {}".format(savep)) self.feats = torch.load(savep) else: logger.info("Reading examples from {}".format(file_path)) data = read_jsonl(file_path) # data = [dic for dic in data if len(dic["patch"].split("\n")) <= 20] for i in range(len(data)): data[i]["idx"] = i logger.info(f"Tokenize examples: {file_path}") # self.feats = pool.map(self.convert_examples_to_features, \ # [(dic, tokenizer, args) for dic in data]) self.feats = [self.convert_examples_to_features((dic, tokenizer, args)) for dic in data] torch.save(self.feats, savep) def convert_examples_to_features(self, item): dic, tokenizer, args = item diff, msg = dic["patch"], dic["msg"] difflines = diff.split("\n")[1:] # remove start @@ difflines = [line for line in difflines if len(line.strip()) > 0] map_dic = {"-": 0, "+": 1, " ": 2} def f(s): if s in map_dic: return map_dic[s] else: return 2 labels = [f(line[0]) for line in difflines] difflines = [line[1:].strip() for line in difflines] inputstr = "" for label, line in zip(labels, difflines): if label == 1: inputstr += "" + line elif label == 0: inputstr += "" + line else: inputstr += "" + line source_ids = self.encode_remove(tokenizer, inputstr, args) target_ids = [] target_ids.append(tokenizer.msg_id) msg = self.encode_remove(tokenizer, dic["msg"], args) target_ids.extend(msg) source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer) input_labels = [-100] * len(source_ids) return ReviewFeatures(dic["idx"], source_ids, input_labels, target_ids, type="genmsg") class InputFeatures(object): """A single training/test features for a example.""" def __init__(self, example_id, source_ids, target_ids, url=None): self.example_id = example_id self.source_ids = source_ids self.target_ids = target_ids self.url = url class ReviewFeatures(object): def __init__(self, example_id, source_ids, source_labels, target_ids, type): self.example_id = example_id self.source_ids = source_ids self.source_labels = source_labels self.target_ids = target_ids assert type in ("label", "line", "genmsg", "daemsg") self.type = type class ClsFeatures(object): def __init__(self, example_id, source_ids, y): self.example_id = example_id self.source_ids = source_ids self.y = y class ReviewExample(object): """A single training/test example.""" def __init__( self, idx, oldf, diff, msg, cmtid, max_len, y ): self.idx = idx # idx is useless yet self.oldf = oldf self.diff = diff self.msg = msg self.cmtid = cmtid self.max_len = max_len self.y = y self.prevlines = [] self.afterlines = [] self.lines = [] self.labels = [] self.avail = False self.input = "" self.align_and_clean() self.postprocess() def postprocess(self): if not self.avail: return # Warning: lines is not self.lines # lines for rough length estimation lines = [source_str.split() for source_str in self.lines] inputl = len(lines) # line tag inputl += sum(map(len, lines)) left, right = 0, len(lines) while inputl > self.max_len: if left % 2 == 0: inputl -= len(lines[left]) + 1 left += 1 else: right -= 1 inputl -= len(lines[right]) + 1 lines = lines[left:right] self.lines = self.lines[left:right] self.labels = self.labels[left:right] prevlines = self.prevlines afterlines = self.afterlines prev_after_len = max(len(prevlines), len(afterlines)) i = 0 while inputl < self.max_len and i < prev_after_len: if i < len(prevlines): newl = inputl + len(prevlines[-1-i].split()) + 1 if newl > self.max_len: break self.lines.insert(0, prevlines[-1-i]) self.labels.insert(0, -100) inputl = newl # tag if i < len(afterlines): newl = inputl + len(afterlines[i].split()) + 1 if newl > self.max_len: break self.lines.append(afterlines[i]) self.labels.append(-100) inputl = newl # tag i += 1 assert inputl <= self.max_len, "Too long inputs." assert len(self.lines) == len(self.labels), "Not equal length." self.input = "".join(self.lines) self.prevlines, self.lines, self.afterlines = [], [], [] def remove_space_clean(self, line): """ Remove start and end empty chars. """ rep = " \t\r" totallen = len(line) i = 0 while i < totallen and line[i] in rep: i += 1 j = totallen - 1 while j >= 0 and line[j] in rep: j -= 1 line = line[i : j + 1] return line def align_and_clean(self): oldflines = self.oldf.split("\n") difflines = self.diff.split("\n") first_line = difflines[0] difflines = difflines[1:] difflines = [line for line in difflines if line != r"\ No newline at end of file"] regex = r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@" matchres = re.match(regex, first_line) if matchres: startline, rangelen, startpos, endpos = matchres.groups() self.avail = True else: self.avail = False return startline, rangelen = int(startline) - 1, int(rangelen) endline = startline + rangelen self.prevlines = oldflines[:startline] self.afterlines = oldflines[endline:] for line in difflines: if line.startswith("-"): self.lines.append(line[1:]) self.labels.append(0) elif line.startswith("+"): self.lines.append(line[1:]) self.labels.append(1) else: self.lines.append(line) self.labels.append(2) self.prevlines = [self.remove_space_clean(line) for line in self.prevlines] self.afterlines = [self.remove_space_clean(line) for line in self.afterlines] self.lines = [self.remove_space_clean(line) for line in self.lines] self.msg = self.remove_space_clean(self.msg) self.prevlines = [line for line in self.prevlines if len(line) > 0] self.afterlines = [line for line in self.afterlines if len(line) > 0] # print("\n".join(self.prevlines)) # print("\n\n\n\n") # print("\n".join(self.lines)) # print("\n\n\n\n") # print("\n".join(self.afterlines)) # print("\n\n\n\n") assert len(self.lines) == len(self.labels), "Not equal length in align." topack = list( zip( *[ (line, label) for line, label in zip(self.lines, self.labels) if len(line) > 0 ] ) ) if topack == []: self.avail = False return else: self.lines, self.labels = topack # tuple->list, convenient for later operation self.lines = list(self.lines) self.labels = list(self.labels) def read_review_examples(filename, data_num=-1, tokenizer=None): """Read examples from filename.""" examples = [] idx = 0 with open(filename) as f: for line in f: try: js = json.loads(line.strip()) except: print("Error during reading json data.") continue maxl = 200 if "y" not in js: js["y"] = 0 if "msg" in js and len(js["msg"]) > 0: js["y"] = 1 example = ReviewExample( idx=idx, oldf=js["oldf"], diff=js["patch"], msg=js["msg"] if "msg" in js else "", cmtid=js["cmtid"] if "cmtid" in js else "", max_len=maxl, y=js["y"] ) if example.avail: examples.append(example) idx += 1 if idx == data_num: break else: # print(f"Passing {idx} because of invalid diff.") idx += 1 if idx == data_num: break return examples def read_jsonl(path): data = [] with open(path) as f: for line in f: try: js = json.loads(line.strip()) except: print("Error during reading json data.") continue data.append(js) return data