import re, json
import os, random
import torch, logging
from copy import deepcopy as cp
from torch.utils.data import Dataset
from tokenizers import ByteLevelBPETokenizer
from transformers import T5Tokenizer, RobertaTokenizer
import nltk

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


class MyTokenizer(object):
    """
    Wrapper for ByteLevelBPETokenizer.
    """
    def __init__(self, vocab=None, merges=None, **kwargs):
        self.tokenizer = ByteLevelBPETokenizer(vocab, merges, **kwargs)
        self.update_id2token()

    @staticmethod
    def from_pretrained(path):
        vocabp = os.path.join(path, "vocab.json")
        mergesp = os.path.join(path, "merges.txt")
        mytoken = MyTokenizer(vocabp, mergesp)
        return mytoken

    def update_id2token(self):
        vocab = self.tokenizer.get_vocab()
        self.id2token = {vocab[token]: token for token in vocab}

    def add_special_tokens(self, dic):
        for values in dic.values():
            self.tokenizer.add_special_tokens(values)
        self.update_id2token()

    def convert_ids_to_tokens(self, ids):
        vocab = self.id2token
        return [vocab[i] for i in ids]

    def decode(self, ids, **kwargs):  # TODO: to be updated with proper detokenization
        tokens = self.convert_ids_to_tokens(ids)
        return " ".join(tokens)

    def encode(self, text, **kwargs):
        text = text.encode("ascii", errors="ignore").decode("ascii")
        return self.tokenizer.encode(text).ids

    def get_vocab(self):
        return self.tokenizer.get_vocab()

    def __len__(self):
        return len(self.tokenizer.get_vocab())
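

# Usage sketch (illustrative only): load a trained BPE tokenizer from a directory
# containing "vocab.json" and "merges.txt"; the directory path and sample text
# below are hypothetical.
#
#     tok = MyTokenizer.from_pretrained("path/to/bpe_tokenizer")
#     ids = tok.encode("public int add(int a, int b) { return a + b; }")
#     print(tok.decode(ids))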


class RefineFeatures(object):
    def __init__(self, example_id, source_ids, target_ids):
        self.example_id = example_id
        self.source_ids = source_ids
        self.target_ids = target_ids


class RefineDataset(Dataset):
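    """
    Code-refinement dataset. Each JSONL example provides the old code ("old"),
    the revised code ("new"), and a review comment ("comment"). The source is
    the old code followed by the message special id (tokenizer.msg_id) and the
    encoded comment; the target is the new code.
    """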
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.tokenizer = tokenizer
        self.args = args
        logger.info("Reading examples from {}".format(file_path))
        with open(file_path) as f:
            examples = [json.loads(line) for line in f]
        for i in range(len(examples)):
            if "id" not in examples[i]:
                examples[i]["id"] = i
        if samplenum > 0:
            examples = examples[:samplenum]
        logger.info(f"Tokenize examples: {file_path}")
        self.feats = pool.map(self.tokenize,
                              [(example, tokenizer, args) for example in examples])

    def tokenize(self, item):
        example, tokenizer, args = item
        oldlines = example["old"].split("\n")
        newlines = example["new"].split("\n")
        oldlines = [line[1:].strip() for line in oldlines]
        newlines = [line[1:].strip() for line in newlines]
        oldlines = "\n".join(oldlines)
        newlines = "\n".join(newlines)
        oldlines = "<add>" + oldlines.replace("\n", "<add>")
        newlines = "<add>" + newlines.replace("\n", "<add>")
        comment = example["comment"]
        srcids = self.encode_remove(tokenizer, oldlines, args)
        srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
        tgtids = self.encode_remove(tokenizer, newlines, args)
        srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
        return RefineFeatures(example["id"], srcids, tgtids)

    @staticmethod
    def process_pred_gold(pred, gold):
        gold = gold.split("\n")
        gold = [line[1:].strip() for line in gold]
        gold = " ".join(gold)
        pred = " ".join(pred.split())
        pred = pred.replace("<add> ", "")
        return pred, gold

    def pad_assert(self, source_ids, target_ids, args, tokenizer):
        source_ids = source_ids[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        target_ids = target_ids[:args.max_target_length - 2]
        target_ids = [tokenizer.bos_id] + target_ids + [tokenizer.eos_id]
        pad_len = args.max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_id] * pad_len
        assert len(source_ids) == args.max_source_length, "Not equal length."
        assert len(target_ids) == args.max_target_length, "Not equal length."
        return source_ids, target_ids

    def encode_remove(self, tokenizer, text, args):
        text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
        if type(tokenizer) == T5Tokenizer:
            return text[:-1]
        elif type(tokenizer) == RobertaTokenizer:
            return text[1:-1]
        elif type(tokenizer) == MyTokenizer:
            return text
        else:
            raise NotImplementedError

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, i):
        return self.feats[i]


class SimpleRefineDataset(RefineDataset):
    def tokenize(self, item):
        example, tokenizer, args = item
        oldlines = example["old"].split("\n")
        newlines = example["new"].split("\n")
        oldlines = [line[1:].strip() for line in oldlines]
        newlines = [line[1:].strip() for line in newlines]
        oldlines = " ".join(oldlines)
        newlines = " ".join(newlines)
        comment = example["comment"]
        srcids = self.encode_remove(tokenizer, oldlines, args)
        srcids += [tokenizer.msg_id] + self.encode_remove(tokenizer, comment, args)
        tgtids = self.encode_remove(tokenizer, newlines, args)
        srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
        return RefineFeatures(example["id"], srcids, tgtids)

    @staticmethod
    def process_pred_gold(pred, gold):
        gold = gold.split("\n")
        gold = [line[1:].strip() for line in gold]
        gold = " ".join(gold)
        pred = " ".join(pred.split())
        return pred, gold


class Seq2SeqDataset(RefineDataset):
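    """
    Plain old-to-new sequence-to-sequence variant: the source is the
    whitespace-normalized "old" code and the target is the "new" code, with no
    review comment in the input.
    """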
    def tokenize(self, item):
        example, tokenizer, args = item
        inputs, outputs = example["old"], example["new"]
        inputs = " ".join(inputs.split())
        outputs = " ".join(outputs.split())
        srcids = self.encode_remove(tokenizer, inputs, args)
        tgtids = self.encode_remove(tokenizer, outputs, args)
        srcids, tgtids = self.pad_assert(srcids, tgtids, args, tokenizer)
        return RefineFeatures(example["id"], srcids, tgtids)

    @staticmethod
    def process_pred_gold(pred, gold):
        gold = " ".join(gold.split())
        pred = " ".join(pred.split())
        return pred, gold


class TextDataset(Dataset):
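    """
    Pre-training dataset over review diff hunks. Each example is expanded into
    one or more ReviewFeatures of four types (see ReviewFeatures.type):
    examples with a review message become "genmsg" (comment generation) or
    "daemsg" (comment denoising) features; the rest become a single "label"
    (diff-line tag prediction) or "line" (masked-line denoising) feature.
    """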
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.cnt = 0
        self.tokenizer = tokenizer
        self.args = args
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            examples = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            examples = read_review_examples(file_path, samplenum, tokenizer)
            logger.info(f"Tokenize examples: {file_path}")
            examples = pool.map(self.tokenize,
                                [(example, tokenizer, args) for example in examples])
            torch.save(examples, savep)
        logger.info("Convert examples to features...")
        self.set_start_end_ids(examples)
        self.featss = pool.map(self.convert_examples_to_features,
                               [(example, tokenizer, args) for example in examples])
        self.feats = [feat for feats in self.featss for feat in feats]  # flatten the per-example feature lists

    def __len__(self):
        return len(self.feats)

    def __getitem__(self, i):
        return self.feats[i]

    def reset_len(self, data_len):
        assert len(self.feats) >= data_len
        self.feats = self.feats[:data_len]

    def set_start_end_ids(self, examples):
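        """
        Record, for each example, the first and last line indices whose label is
        not -100, i.e. the span of actual diff lines as opposed to the context
        lines padded in from before and after the hunk.
        """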
        for example in examples:
            labels = example.labels
            start_id = 0
            end_id = len(labels) - 1
            for i, label in enumerate(labels):
                if label != -100:  # find the first label
                    start_id = i
                    break
            for i in range(len(labels) - 1, -1, -1):
                label = labels[i]
                if label != -100:
                    end_id = i
                    break
            example.start_id = start_id
            example.end_id = end_id

    def tokenize(self, item):
        example, tokenizer, args = item
        example.input = self.encode_remove(tokenizer, example.input, args)
        e0id = tokenizer.special_dict["<e0>"]
        # split the token ids back into lines at the <e0> separator id
        inputs = " ".join(str(id) for id in example.input)
        lines = inputs.split(" " + str(e0id) + " ")
        lines = [
            [int(v) for v in line.split(" ") if len(v) > 0] for line in lines
        ]
        lens = [len(line) for line in lines]
        # if 0 in lens:
        #     logger.info("Warning: empty line in an example.")
        curlen = len(lens) + sum(lens)
        left, right = 0, len(lines)
        # shrink the [left, right) window of lines until the input fits the length budget
        while curlen > args.max_source_length - 2:
            if left % 2 == 0:
                curlen -= 1 + len(lines[left])
                left += 1
            else:
                right -= 1
                curlen -= 1 + len(lines[right])
        lines = lines[left:right]
        labels = example.labels[left:right]
        assert len(lines) + sum(map(len, lines)) <= args.max_source_length - 2, "Too long inputs in TextDataset.tokenize."
        if len(lines) != len(labels):
            logger.info("Not equal length in TextDataset.tokenize.")
            lines = lines[:len(labels)]
            labels = labels[:len(lines)]
        example.lines = lines
        example.labels = labels
        example.msg = self.encode_remove(tokenizer, example.msg, args)
        return example

    def convert_examples_to_features(self, item):
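        """
        Turn one tokenized example into a list of ReviewFeatures. Examples with
        a review message are upsampled into three features, each randomly either
        a "genmsg" or a "daemsg" feature; examples without a message yield a
        single "label" (encoder) or "line" (decoder) feature, chosen at random.
        """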
        example, _, _ = item
        if len(example.msg) > 0:
            exs = []
            for _ in range(3):  # up sampling
                if random.random() < 0.5:
                    exs.append(self.genmsg_example(item))
                else:
                    exs.append(self.daemsg_example(item))
            return exs
        if random.random() < 0.5:
            return [self.encoder_example(item)]
        return [self.decoder_example(item)]

    def encoder_example(self, item):
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels
        target_ids = [tokenizer.pad_id] * args.max_target_length
        source_ids, input_labels = [], []
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
                input_labels.append(-100)
            if label != -100:  # only insert special tokens at diffs, not context
                source_ids.append(tokenizer.mask_id)
                input_labels.append(label)
            source_ids.extend(line)
            input_labels.extend([-100] * len(line))
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
                input_labels.append(-100)
        assert len(input_labels) == len(source_ids), "Not equal length."
        assert len(input_labels) <= args.max_source_length, f"Too long inputs: {len(input_labels)}."
        source_ids = source_ids[:args.max_source_length - 2]
        input_labels = input_labels[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        input_labels = [-100] + input_labels + [-100]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        input_labels += [-100] * pad_len
        new_input_labels = []
        map_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for label in input_labels:
            if label == -100:
                new_input_labels.append(-100)
            else:
                new_input_labels.append(map_dict[label])
        input_labels = new_input_labels
        assert len(source_ids) == args.max_source_length, "Not equal length."
        assert len(input_labels) == args.max_source_length, "Not equal length."
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="label")

    def decoder_example(self, item):
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels
        input_labels = [-100] * args.max_source_length
        source_ids, target_ids = [], []
        SPECIAL_ID = 0
        mask_idxs = random.choices(range(len(lines)), k=int(len(lines) * args.mask_rate))
        id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
            if label in id_dict:
                source_ids.append(id_dict[label])
            if i in mask_idxs:
                source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
                target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
                target_ids.extend(line)
                if SPECIAL_ID < 99:  # only 0-99 ids in vocab
                    SPECIAL_ID += 1
            else:
                source_ids.extend(line)
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="line")

    def genmsg_example(self, item):
        example, tokenizer, args = item
        lines = example.lines
        labels = example.labels
        input_labels = [-100] * args.max_source_length
        source_ids, target_ids = [], []
        id_dict = {0: tokenizer.del_id, 1: tokenizer.add_id, 2: tokenizer.keep_id}
        for i, (line, label) in enumerate(zip(lines, labels)):
            if i == example.start_id:
                source_ids.append(tokenizer.start_id)
            if label != -100:
                source_ids.append(id_dict[label])
            source_ids.extend(line)
            if i == example.end_id:
                source_ids.append(tokenizer.end_id)
        target_ids.append(tokenizer.msg_id)
        target_ids.extend(example.msg)
        assert len(source_ids) <= args.max_source_length, f"Too long inputs: {len(source_ids)}."
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="genmsg")

    def daemsg_example(self, item):
        example, tokenizer, args = item
        input_labels = [-100] * args.max_source_length
        msg_ids = cp(example.msg)
        masks = [random.random() < 0.20 for _ in range(len(msg_ids))]
        if sum(masks) == 0:
            idx = random.choice(range(len(msg_ids)))
            masks[idx] = True
        source_ids, target_ids = [], []
        i = 0
        SPECIAL_ID = 0
        while i < len(masks):
            j = i
            while j < len(masks) and not masks[j]:
                source_ids.append(msg_ids[j])
                j += 1
            if j == len(masks):
                break
            source_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
            target_ids.append(tokenizer.special_dict[f"<e{SPECIAL_ID}>"])
            while j < len(masks) and masks[j]:
                target_ids.append(msg_ids[j])
                j += 1
            if SPECIAL_ID < 99:  # only 0-99 ids in vocab
                SPECIAL_ID += 1
            i = j
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        return ReviewFeatures(example.idx, source_ids, input_labels, target_ids, type="daemsg")

    def pad_assert(self, source_ids, target_ids, args, tokenizer):
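        """
        Truncate and pad to fixed lengths: the source is wrapped in bos/eos and
        padded to max_source_length, while the target only gets a trailing eos
        (no bos) before being padded to max_target_length. This differs from
        RefineDataset.pad_assert, which also prepends bos to the target.
        """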
        source_ids = source_ids[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        target_ids = target_ids[:args.max_target_length - 1]
        target_ids = target_ids + [tokenizer.eos_id]
        pad_len = args.max_target_length - len(target_ids)
        target_ids += [tokenizer.pad_id] * pad_len
        assert len(source_ids) == args.max_source_length, "Not equal length."
        assert len(target_ids) == args.max_target_length, "Not equal length."
        return source_ids, target_ids

    def encode_remove(self, tokenizer, text, args):
        text = tokenizer.encode(text, max_length=args.max_source_length, truncation=True)
        if type(tokenizer) == T5Tokenizer:
            return text[:-1]
        elif type(tokenizer) == RobertaTokenizer:
            return text[1:-1]
        elif type(tokenizer) == MyTokenizer:
            return text
        else:
            raise NotImplementedError
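

# Construction sketch for TextDataset (illustrative; the argument values and file
# name are assumptions, and the tokenizer is expected to have been augmented
# elsewhere with the special-token id attributes used above, e.g. bos_id, eos_id,
# pad_id, msg_id, start_id, end_id, mask_id, add_id, del_id, keep_id and
# special_dict):
#
#     import argparse, multiprocessing
#     args = argparse.Namespace(max_source_length=512, max_target_length=128, mask_rate=0.15)
#     pool = multiprocessing.Pool(8)
#     dataset = TextDataset(tokenizer, pool, args, "train_chunk_0.jsonl")
#
# The loader caches tokenized examples next to the JSONL file with a
# tokenizer-dependent ".exps"-style suffix.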


class CommentGenDataset(TextDataset):
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.tokenizer = tokenizer
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            examples = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            examples = read_review_examples(file_path, samplenum, tokenizer)
            # for i in range(len(examples)):
            #     examples[i].msg = " ".join(nltk.word_tokenize(examples[i].msg))
            logger.info(f"Tokenize examples: {file_path}")
            examples = pool.map(self.tokenize,
                                [(example, tokenizer, args) for example in examples])
            torch.save(examples, savep)
        logger.info("Convert examples to features...")
        self.set_start_end_ids(examples)
        self.feats = pool.map(self.convert_examples_to_features,
                              [(example, tokenizer, args) for example in examples])
        self.feats = [feat for feat in self.feats if feat is not None]

    def convert_examples_to_features(self, item):
        example, tokenizer, args = item
        if len(example.msg) == 0:
            return None
        return self.genmsg_example(item)


class CommentClsDataset(TextDataset):
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.tokenizer = tokenizer
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".exps")
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            examples = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            examples = read_review_examples(file_path, samplenum, tokenizer)
            logger.info(f"Tokenize examples: {file_path}")
            examples = pool.map(self.tokenize,
                                [(example, tokenizer, args) for example in examples])
            torch.save(examples, savep)
        logger.info("Convert examples to features...")
        self.set_start_end_ids(examples)
        self.feats = pool.map(self.convert_examples_to_features,
                              [(example, tokenizer, args) for example in examples])

    def convert_examples_to_features(self, item):
        example, tokenizer, args = item
        tmpfeature = self.genmsg_example(item)
        return ClsFeatures(tmpfeature.example_id, tmpfeature.source_ids, example.y)


class SimpleClsDataset(TextDataset):
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.tokenizer = tokenizer
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".simpexps")
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            self.feats = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            examples = read_review_examples(file_path, samplenum, tokenizer)
            logger.info(f"Tokenize examples: {file_path}")
            self.feats = pool.map(self.convert_examples_to_features,
                                  [(example, tokenizer, args) for example in examples])
            torch.save(self.feats, savep)

    def convert_examples_to_features(self, item):
        example, tokenizer, args = item
        example.input_lines = example.input.split("<e0>")
        labels_l = len(example.labels)
        example.input_lines = example.input_lines[:labels_l]
        for i in range(len(example.input_lines)):
            if example.labels[i] == 1:
                example.input_lines[i] = "+ " + example.input_lines[i]
            elif example.labels[i] == 0:
                example.input_lines[i] = "- " + example.input_lines[i]
        example.input = " ".join(example.input_lines)
        input_ids = self.encode_remove(tokenizer, example.input, args)
        exceed_l = len(input_ids) - args.max_source_length + 2
        if exceed_l > 0:
            halfexl = (exceed_l + 1) // 2
            input_ids = input_ids[halfexl:-halfexl]
        source_ids = input_ids[:args.max_source_length - 2]
        source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
        pad_len = args.max_source_length - len(source_ids)
        source_ids += [tokenizer.pad_id] * pad_len
        example_id = example.idx
        y = example.y
        return ClsFeatures(example_id, source_ids, y)


class SimpleGenDataset(TextDataset):
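    """
    Review-comment generation from a flattened diff. Each raw record needs a
    "patch" (unified diff hunk) and a "msg" (the review comment); the hunk lines
    are re-encoded as <add>/<del>/<keep>-prefixed text, and the target is msg_id
    followed by the encoded comment.
    """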
    def __init__(self, tokenizer, pool, args, file_path, samplenum=-1):
        self.tokenizer = tokenizer
        if isinstance(tokenizer, MyTokenizer):
            tokenizer_type = "mytok"
        elif isinstance(tokenizer, T5Tokenizer):
            tokenizer_type = ""
        elif isinstance(tokenizer, RobertaTokenizer):
            tokenizer_type = "rb"
        else:
            tokenizer_type = "unk"
        savep = file_path.replace(".jsonl", tokenizer_type + ".simpgenexps")
        if os.path.exists(savep):
            logger.info("Loading examples from {}".format(savep))
            self.feats = torch.load(savep)
        else:
            logger.info("Reading examples from {}".format(file_path))
            data = read_jsonl(file_path)
            # data = [dic for dic in data if len(dic["patch"].split("\n")) <= 20]
            for i in range(len(data)):
                data[i]["idx"] = i
            logger.info(f"Tokenize examples: {file_path}")
            # self.feats = pool.map(self.convert_examples_to_features,
            #                       [(dic, tokenizer, args) for dic in data])
            self.feats = [self.convert_examples_to_features((dic, tokenizer, args)) for dic in data]
            torch.save(self.feats, savep)

    def convert_examples_to_features(self, item):
        dic, tokenizer, args = item
        diff = dic["patch"]
        difflines = diff.split("\n")[1:]  # drop the leading "@@ ... @@" hunk header
        difflines = [line for line in difflines if len(line.strip()) > 0]
        map_dic = {"-": 0, "+": 1, " ": 2}
        labels = [map_dic.get(line[0], 2) for line in difflines]
        difflines = [line[1:].strip() for line in difflines]
        inputstr = ""
        for label, line in zip(labels, difflines):
            if label == 1:
                inputstr += "<add>" + line
            elif label == 0:
                inputstr += "<del>" + line
            else:
                inputstr += "<keep>" + line
        source_ids = self.encode_remove(tokenizer, inputstr, args)
        target_ids = [tokenizer.msg_id]
        msg = self.encode_remove(tokenizer, dic["msg"], args)
        target_ids.extend(msg)
        source_ids, target_ids = self.pad_assert(source_ids, target_ids, args, tokenizer)
        input_labels = [-100] * len(source_ids)
        return ReviewFeatures(dic["idx"], source_ids, input_labels, target_ids, type="genmsg")


class InputFeatures(object):
    """A single training/test feature set for an example."""

    def __init__(self, example_id, source_ids, target_ids, url=None):
        self.example_id = example_id
        self.source_ids = source_ids
        self.target_ids = target_ids
        self.url = url


class ReviewFeatures(object):
    def __init__(self, example_id, source_ids, source_labels, target_ids, type):
        self.example_id = example_id
        self.source_ids = source_ids
        self.source_labels = source_labels
        self.target_ids = target_ids
        assert type in ("label", "line", "genmsg", "daemsg")
        self.type = type


class ClsFeatures(object):
    def __init__(self, example_id, source_ids, y):
        self.example_id = example_id
        self.source_ids = source_ids
        self.y = y


class ReviewExample(object):
    """A single training/test example."""

    def __init__(self, idx, oldf, diff, msg, cmtid, max_len, y):
        self.idx = idx  # idx is not used yet
        self.oldf = oldf
        self.diff = diff
        self.msg = msg
        self.cmtid = cmtid
        self.max_len = max_len
        self.y = y
        self.prevlines = []
        self.afterlines = []
        self.lines = []
        self.labels = []
        self.avail = False
        self.input = ""
        self.align_and_clean()
        self.postprocess()

    def postprocess(self):
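        """
        Shrink the window of hunk lines until the whitespace-token count plus
        one tag per line fits within max_len, then fill any remaining budget
        with unlabeled (-100) context lines taken from just before and after
        the hunk. Finally join the surviving lines into self.input with "<e0>"
        as the line separator.
        """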
        if not self.avail:
            return
        # Note: this local `lines` is not self.lines; it is a whitespace-split
        # copy used only for rough length estimation.
        lines = [source_str.split() for source_str in self.lines]
        inputl = len(lines)  # one tag per line
        inputl += sum(map(len, lines))
        left, right = 0, len(lines)
        while inputl > self.max_len:
            if left % 2 == 0:
                inputl -= len(lines[left]) + 1
                left += 1
            else:
                right -= 1
                inputl -= len(lines[right]) + 1
        lines = lines[left:right]
        self.lines = self.lines[left:right]
        self.labels = self.labels[left:right]
        prevlines = self.prevlines
        afterlines = self.afterlines
        prev_after_len = max(len(prevlines), len(afterlines))
        i = 0
        while inputl < self.max_len and i < prev_after_len:
            if i < len(prevlines):
                newl = inputl + len(prevlines[-1 - i].split()) + 1
                if newl > self.max_len:
                    break
                self.lines.insert(0, prevlines[-1 - i])
                self.labels.insert(0, -100)
                inputl = newl  # tag
            if i < len(afterlines):
                newl = inputl + len(afterlines[i].split()) + 1
                if newl > self.max_len:
                    break
                self.lines.append(afterlines[i])
                self.labels.append(-100)
                inputl = newl  # tag
            i += 1
        assert inputl <= self.max_len, "Too long inputs."
        assert len(self.lines) == len(self.labels), "Not equal length."
        self.input = "<e0>".join(self.lines)
        self.prevlines, self.lines, self.afterlines = [], [], []

    def remove_space_clean(self, line):
        """
        Strip leading and trailing spaces, tabs, and carriage returns.
        """
        rep = " \t\r"
        totallen = len(line)
        i = 0
        while i < totallen and line[i] in rep:
            i += 1
        j = totallen - 1
        while j >= 0 and line[j] in rep:
            j -= 1
        line = line[i : j + 1]
        return line

    def align_and_clean(self):
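        """
        Parse the diff hunk and align it with the old file. The hunk header
        (e.g. "@@ -10,7 +10,8 @@") gives the start line and length of the old
        side, which splits the old file into prevlines / hunk / afterlines.
        Hunk lines are labeled 0 (deleted, "-"), 1 (added, "+") or 2 (kept),
        then stripped, and empty lines are dropped; examples whose header does
        not match or whose hunk becomes empty are marked unavailable.
        """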
        oldflines = self.oldf.split("\n")
        difflines = self.diff.split("\n")
        first_line = difflines[0]
        difflines = difflines[1:]
        difflines = [line for line in difflines if line != r"\ No newline at end of file"]
        regex = r"@@ -(\d+),(\d+) \+(\d+),(\d+) @@"
        matchres = re.match(regex, first_line)
        if matchres:
            startline, rangelen, _, _ = matchres.groups()  # only the old-side range is needed
            self.avail = True
        else:
            self.avail = False
            return
        startline, rangelen = int(startline) - 1, int(rangelen)
        endline = startline + rangelen
        self.prevlines = oldflines[:startline]
        self.afterlines = oldflines[endline:]
        for line in difflines:
            if line.startswith("-"):
                self.lines.append(line[1:])
                self.labels.append(0)
            elif line.startswith("+"):
                self.lines.append(line[1:])
                self.labels.append(1)
            else:
                self.lines.append(line)
                self.labels.append(2)
        self.prevlines = [self.remove_space_clean(line) for line in self.prevlines]
        self.afterlines = [self.remove_space_clean(line) for line in self.afterlines]
        self.lines = [self.remove_space_clean(line) for line in self.lines]
        self.msg = self.remove_space_clean(self.msg)
        self.prevlines = [line for line in self.prevlines if len(line) > 0]
        self.afterlines = [line for line in self.afterlines if len(line) > 0]
        assert len(self.lines) == len(self.labels), "Not equal length in align."
        topack = list(
            zip(
                *[
                    (line, label)
                    for line, label in zip(self.lines, self.labels)
                    if len(line) > 0
                ]
            )
        )
        if topack == []:
            self.avail = False
            return
        else:
            self.lines, self.labels = topack
        # tuple -> list, convenient for later operations
        self.lines = list(self.lines)
        self.labels = list(self.labels)


def read_review_examples(filename, data_num=-1, tokenizer=None):
    """Read examples from filename."""
    examples = []
    idx = 0
    with open(filename) as f:
        for line in f:
            try:
                js = json.loads(line.strip())
            except json.JSONDecodeError:
                logger.warning("Error during reading json data.")
                continue
            maxl = 200
            if "y" not in js:
                js["y"] = 0
            if "msg" in js and len(js["msg"]) > 0:
                js["y"] = 1
            example = ReviewExample(
                idx=idx,
                oldf=js["oldf"],
                diff=js["patch"],
                msg=js["msg"] if "msg" in js else "",
                cmtid=js["cmtid"] if "cmtid" in js else "",
                max_len=maxl,
                y=js["y"],
            )
            if example.avail:
                examples.append(example)
                idx += 1
                if idx == data_num:
                    break
            else:
                # print(f"Passing {idx} because of invalid diff.")
                idx += 1
                if idx == data_num:
                    break
    return examples


def read_jsonl(path):
    data = []
    with open(path) as f:
        for line in f:
            try:
                js = json.loads(line.strip())
            except json.JSONDecodeError:
                logger.warning("Error during reading json data.")
                continue
            data.append(js)
    return data
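

# Minimal smoke test (hypothetical; the original module defines no CLI entry
# point). It expects a local JSONL file whose records contain at least "oldf"
# and "patch", which is what ReviewExample needs; the default file name below
# is an assumption.
if __name__ == "__main__":
    import sys

    path = sys.argv[1] if len(sys.argv) > 1 else "valid.jsonl"
    exs = read_review_examples(path, data_num=5)
    print(f"Parsed {len(exs)} valid review examples from {path}")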