# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from evaluator.CodeBLEU.parser import DFG_python, DFG_java, DFG_ruby, DFG_go, DFG_php, DFG_javascript, DFG_csharp
from evaluator.CodeBLEU.parser import (remove_comments_and_docstrings,
                                       tree_to_token_index,
                                       index_to_code_token,
                                       tree_to_variable_index)
from tree_sitter import Language, Parser
import os

root_dir = os.path.dirname(__file__)

# Maps a language name to the function that extracts its data-flow graph.
dfg_function = {
    'python': DFG_python,
    'java': DFG_java,
    'ruby': DFG_ruby,
    'go': DFG_go,
    'php': DFG_php,
    'javascript': DFG_javascript,
    'c_sharp': DFG_csharp,
}


def calc_dataflow_match(references, candidate, lang):
    """Score a single candidate against its references.

    Thin wrapper that wraps both arguments in a one-element list and
    delegates to :func:`corpus_dataflow_match`.

    Args:
        references: Reference code strings for the candidate.
        candidate: Candidate code string.
        lang: Language name; must be a key of ``dfg_function``.

    Returns:
        Whatever :func:`corpus_dataflow_match` returns for the one-sample
        corpus.
    """
    return corpus_dataflow_match([references], [candidate], lang)


def corpus_dataflow_match(references, candidates, lang):
    """Compute the data-flow match score over a corpus.

    For every (candidate, reference) pair the data-flow items of both are
    extracted and normalized; each reference item found in the candidate's
    items counts as one match and consumes that candidate item, so it cannot
    be matched twice against the same reference.

    Args:
        references: One entry per candidate, each an iterable of reference
            code strings. ``references[i]`` belongs to ``candidates[i]``.
        candidates: Candidate code strings.
        lang: Language name; must be a key of ``dfg_function``.

    Returns:
        ``match_count / total_count``, where ``total_count`` is the number of
        normalized reference data-flow items. Returns ``0`` (after printing a
        warning) when no reference data-flow item was extracted at all.
    """
    LANGUAGE = Language(root_dir + '/parser/my-languages.so', lang)
    ts_parser = Parser()
    ts_parser.set_language(LANGUAGE)
    # get_data_flow expects the pair [tree-sitter parser, DFG extractor].
    parser = [ts_parser, dfg_function[lang]]

    match_count = 0
    total_count = 0

    for i, candidate in enumerate(candidates):
        # Indexing (rather than zip) keeps the IndexError raised when
        # `references` is shorter than `candidates`.
        references_sample = references[i]
        for reference in references_sample:
            # Comment stripping is best-effort: on failure the text is used
            # unchanged.
            # NOTE(review): the language passed here is always 'java' rather
            # than `lang`; kept as-is because changing it could alter scores
            # -- confirm whether this is intended.
            try:
                candidate = remove_comments_and_docstrings(candidate, 'java')
            except Exception:
                pass
            try:
                reference = remove_comments_and_docstrings(reference, 'java')
            except Exception:
                pass

            cand_dfg = get_data_flow(candidate, parser)
            ref_dfg = get_data_flow(reference, parser)

            normalized_cand_dfg = normalize_dataflow(cand_dfg)
            normalized_ref_dfg = normalize_dataflow(ref_dfg)

            total_count += len(normalized_ref_dfg)
            for dataflow in normalized_ref_dfg:
                if dataflow in normalized_cand_dfg:
                    match_count += 1
                    # Consume the matched item so it is not counted again.
                    normalized_cand_dfg.remove(dataflow)

    if total_count == 0:
        print(
            "WARNING: There is no reference data-flows extracted from the whole corpus, "
            "and the data-flow match score degenerates to 0. "
            "Please consider ignoring this score.")
        return 0

    score = match_count / total_count
    return score


def get_data_flow(code, parser):
    """Extract the merged data-flow items of a piece of code.

    Args:
        code: Source code string.
        parser: Two-element sequence ``[tree_sitter_parser, dfg_extractor]``
            as built by :func:`corpus_dataflow_match`.

    Returns:
        A list of data-flow items, one per distinct item position
        (``item[1]``). Each item is a 5-tuple whose fields are used elsewhere
        in this module as (name, position, relationship, parent names,
        parent positions). An empty list is returned when parsing or
        extraction fails.
    """
    try:
        tree = parser[0].parse(bytes(code, 'utf8'))
        root_node = tree.root_node
        tokens_index = tree_to_token_index(root_node)
        # Use separate names so the `code` argument is never rebound.
        code_lines = code.split('\n')
        code_tokens = [index_to_code_token(x, code_lines) for x in tokens_index]
        index_to_code = {}
        for idx, (index, token) in enumerate(zip(tokens_index, code_tokens)):
            index_to_code[index] = (idx, token)
        try:
            DFG, _ = parser[1](root_node, index_to_code, {})
        except Exception:
            # Extraction is best-effort; fall back to "no data flow".
            DFG = []
        DFG = sorted(DFG, key=lambda x: x[1])
        # Keep only items that have parents or are a parent of another item.
        indexs = set()
        for d in DFG:
            if len(d[-1]) != 0:
                indexs.add(d[1])
            for x in d[-1]:
                indexs.add(x)
        dfg = [d for d in DFG if d[1] in indexs]
    except Exception:
        dfg = []

    # Merge items sharing a position: the later item's name and relationship
    # win, parent names and parent positions are unioned.
    merged = {}
    for d in dfg:
        if d[1] not in merged:
            merged[d[1]] = d
        else:
            prev = merged[d[1]]
            merged[d[1]] = (d[0], d[1], d[2],
                            list(set(prev[3] + d[3])),
                            list(set(prev[4] + d[4])))
    return list(merged.values())


def normalize_dataflow_item(dataflow_item):
    """Normalize the variable names of a single data-flow item.

    Names are replaced by ``'var_<i>'`` placeholders numbered in first-seen
    order, parents first and then the item's own variable -- the same scheme
    :func:`normalize_dataflow` uses. This makes the numbering independent of
    string hash ordering.

    Args:
        dataflow_item: Sequence whose element 0 is the variable name,
            element 2 the relationship and element 3 the list of parent
            variable names.

    Returns:
        ``(normalized name, relationship, [normalized parent names])``.
    """
    var_name = dataflow_item[0]
    relationship = dataflow_item[2]
    par_vars_name_list = dataflow_item[3]

    # dict.fromkeys de-duplicates while preserving first-seen order.
    unique_names = dict.fromkeys(par_vars_name_list + [var_name])
    norm_names = {name: 'var_' + str(i) for i, name in enumerate(unique_names)}

    norm_par_vars_name_list = [norm_names[x] for x in par_vars_name_list]
    return (norm_names[var_name], relationship, norm_par_vars_name_list)


def normalize_dataflow(dataflow):
    """Normalize variable names across a list of data-flow items.

    A single name table is shared by all items: each distinct name receives
    a ``'var_<i>'`` placeholder in the order it is first encountered, where
    an item's parent names are visited before the item's own name.

    Args:
        dataflow: Iterable of items whose element 0 is the variable name,
            element 2 the relationship and element 3 the list of parent
            variable names.

    Returns:
        A list of ``(normalized name, relationship, [normalized parent
        names])`` tuples, one per input item, in input order.
    """
    var_dict = {}
    i = 0
    normalized_dataflow = []
    for item in dataflow:
        var_name = item[0]
        relationship = item[2]
        par_vars_name_list = item[3]
        for name in par_vars_name_list:
            if name not in var_dict:
                var_dict[name] = 'var_' + str(i)
                i += 1
        if var_name not in var_dict:
            var_dict[var_name] = 'var_' + str(i)
            i += 1
        normalized_dataflow.append((var_dict[var_name], relationship,
                                    [var_dict[x] for x in par_vars_name_list]))
    return normalized_dataflow