codeX-1.0 / evaluator /CodeBLEU /dataflow_match.py
dorkai's picture
Upload model from GitHub.
b410583
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from evaluator.CodeBLEU.parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp
from evaluator.CodeBLEU.parser import (remove_comments_and_docstrings,
tree_to_token_index,
index_to_code_token,
tree_to_variable_index)
from tree_sitter import Language, Parser
import os
root_dir = os.path.dirname(__file__)
dfg_function = {
'python': DFG_python,
'java': DFG_java,
'ruby': DFG_ruby,
'go': DFG_go,
'php': DFG_php,
'javascript': DFG_javascript,
'c_sharp': DFG_csharp,
}
def calc_dataflow_match(references, candidate, lang):
return corpus_dataflow_match([references], [candidate], lang)
def corpus_dataflow_match(references, candidates, lang):
LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang)
parser = Parser()
parser.set_language(LANGUAGE)
parser = [parser, dfg_function[lang]]
match_count = 0
total_count = 0
for i in range(len(candidates)):
references_sample = references[i]
candidate = candidates[i]
for reference in references_sample:
try:
candidate = remove_comments_and_docstrings(candidate, 'java')
except:
pass
try:
reference = remove_comments_and_docstrings(reference, 'java')
except:
pass
cand_dfg = get_data_flow(candidate, parser)
ref_dfg = get_data_flow(reference, parser)
normalized_cand_dfg = normalize_dataflow(cand_dfg)
normalized_ref_dfg = normalize_dataflow(ref_dfg)
if len(normalized_ref_dfg) > 0:
total_count += len(normalized_ref_dfg)
for dataflow in normalized_ref_dfg:
if dataflow in normalized_cand_dfg:
match_count += 1
normalized_cand_dfg.remove(dataflow)
if total_count == 0:
print(
"WARNING: There is no reference data-flows extracted from the whole corpus, and the data-flow match score degenerates to 0. Please consider ignoring this score.")
return 0
score = match_count / total_count
return score
def get_data_flow(code, parser):
try:
tree = parser[0].parse(bytes(code, 'utf8'))
root_node = tree.root_node
tokens_index = tree_to_token_index(root_node)
code = code.split('\n')
code_tokens = [index_to_code_token(x, code) for x in tokens_index]
index_to_code = {}
for idx, (index, code) in enumerate(zip(tokens_index, code_tokens)):
index_to_code[index] = (idx, code)
try:
DFG, _ = parser[1](root_node, index_to_code, {})
except:
DFG = []
DFG = sorted(DFG, key=lambda x: x[1])
indexs = set()
for d in DFG:
if len(d[-1]) != 0:
indexs.add(d[1])
for x in d[-1]:
indexs.add(x)
new_DFG = []
for d in DFG:
if d[1] in indexs:
new_DFG.append(d)
codes = code_tokens
dfg = new_DFG
except:
codes = code.split()
dfg = []
# merge nodes
dic = {}
for d in dfg:
if d[1] not in dic:
dic[d[1]] = d
else:
dic[d[1]] = (d[0], d[1], d[2], list(set(dic[d[1]][3] + d[3])), list(set(dic[d[1]][4] + d[4])))
DFG = []
for d in dic:
DFG.append(dic[d])
dfg = DFG
return dfg
def normalize_dataflow_item(dataflow_item):
var_name = dataflow_item[0]
var_pos = dataflow_item[1]
relationship = dataflow_item[2]
par_vars_name_list = dataflow_item[3]
par_vars_pos_list = dataflow_item[4]
var_names = list(set(par_vars_name_list + [var_name]))
norm_names = {}
for i in range(len(var_names)):
norm_names[var_names[i]] = 'var_' + str(i)
norm_var_name = norm_names[var_name]
relationship = dataflow_item[2]
norm_par_vars_name_list = [norm_names[x] for x in par_vars_name_list]
return (norm_var_name, relationship, norm_par_vars_name_list)
def normalize_dataflow(dataflow):
var_dict = {}
i = 0
normalized_dataflow = []
for item in dataflow:
var_name = item[0]
relationship = item[2]
par_vars_name_list = item[3]
for name in par_vars_name_list:
if name not in var_dict:
var_dict[name] = 'var_' + str(i)
i += 1
if var_name not in var_dict:
var_dict[var_name] = 'var_' + str(i)
i += 1
normalized_dataflow.append((var_dict[var_name], relationship, [var_dict[x] for x in par_vars_name_list]))
return normalized_dataflow