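"""Knowledge retrieval for LAMBDA's knowledge integration.

Registers the available knowledge modules, embeds their 'name + description' keys with a
SentenceTransformer, and retrieves the best match for a user instruction as a
knowledge-injection prompt snippet; in 'core' mode the knowledge's runnable code is also
executed on the provided kernel.
"""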
import torch
import sys
sys.path.append('/Users/stephensun/Desktop/pypro/LAMBDA/knowledge_integration')
from sentence_transformers import SentenceTransformer, util
import numpy as np
# from knw import KNW_INJECTION, knowledge_injection
from prompt_engineering.prompts import PMT_KNW_IN_CORE, PMT_KNW_IN_FULL
from knowledge_integration.nearest_correlation_matrix import nearest_correlation_matrix
from knowledge_integration.nn_network import nn_networks
from knowledge_integration.pami import pattern_mining
from kernel import execute


# Registry mapping 'name + description' keys to knowledge module instances
KNW_INJECTION = {}

def knowledge_register():
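    """Instantiate each built-in knowledge module and register it in KNW_INJECTION under a
    'name + description' key, which is what later gets embedded for retrieval."""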
    ncm = nearest_correlation_matrix()
    ncm_key = ncm.name + ncm.description
    KNW_INJECTION[ncm_key] = ncm
    nnn = nn_networks()
    nnn_key = nnn.name + nnn.description
    KNW_INJECTION[nnn_key] = nnn
    pami = pattern_mining()
    pami_key = pami.name + pami.description
    KNW_INJECTION[pami_key] = pami



# εˆε§‹εŒ–ε₯子塌ε…₯ζ¨‘εž‹
model = SentenceTransformer('all-MiniLM-L6-v2')

def search_knowledge(user_input, knowledge_embeddings, knowledge_keys):
    """Return (best_match_key, knowledge_object) for the registered knowledge most similar to
    the user input, or (False, None) if no cosine similarity exceeds the 0.3 threshold."""
    input_embedding = model.encode(user_input, convert_to_tensor=True)  # embed the user input
    # Cosine similarity between the input and every registered knowledge key
    similarities_list = util.pytorch_cos_sim(input_embedding, knowledge_embeddings)
    if torch.max(similarities_list) > 0.3:
        best_match_idx = np.argmax(similarities_list.cpu())
        best_match_key = knowledge_keys[best_match_idx]
    else:
        best_match_key = False
    return (best_match_key, KNW_INJECTION[best_match_key]) if best_match_key else (False, None)


def format_code_snaps(knw, kernel):
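    """Format the retrieved knowledge as a knowledge-injection prompt: 'full' mode embeds the
    core function only; 'core' mode also runs the knowledge's runnable code on `kernel` and
    embeds both the core and runnable code."""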
    if knw.mode == 'full':
        core_code = knw.get_core_function()
        return PMT_KNW_IN_FULL.format(code=core_code)
    elif knw.mode == 'core':
        core_code = knw.get_core_function()
        runnable_code = knw.get_runnable_function()
        print("Knowledge_integration: core mode, runnable result: ", execute(runnable_code,kernel))
        retri_knw = PMT_KNW_IN_CORE.format(core=core_code, runnable=runnable_code)
        return retri_knw
    else:
        raise ValueError(f"Invalid mode: {knw.mode}, please choose from ['full', 'core'].")
        # test_case = knw.get_test_case()
        # return KNOWLEDGE_INJECTION_PMT_FIXED.format(test_case=test_case)


def retrieval_knowledge(instruction, kernel):
    """Register the knowledge base, embed its keys, and return the formatted prompt snippet for
    the entry best matching the instruction ('full' mode, or runnable code executed on `kernel`
    in 'core' mode). Returns None when nothing relevant is retrieved."""
    knowledge_register()
    knowledge_keys = list(KNW_INJECTION.keys())
    knowledge_embeddings = model.encode(knowledge_keys, convert_to_tensor=True)
    best_key, best_knw_object = search_knowledge(instruction, knowledge_embeddings, knowledge_keys)
    if best_key:
        return format_code_snaps(best_knw_object, kernel)
    else:
        return None

# def execute_runnable(code):
#     res_type, res = my_app.conv.run_code(code)


if __name__ == '__main__':
    # knowledge_register()
    # knowledge_keys = list(KNW_INJECTION.keys())
    # knowledge_embeddings = model.encode(knowledge_keys, convert_to_tensor=True)
    # user_input = "calculate nearest correlation matrix"
    # best_key, best_knw_object = search_knowledge(user_input)
    # print(best_key,best_knw_object)
    # print(f"Best match key: {best_key}")
    # print(format_code_snaps(best_knw_object))
    #print(retrieval_knowledge("calculate nearest correlation matrix", 'full'))
    # NOTE: retrieval_knowledge expects a kernel instance as its second argument (used to
    # execute runnable code in 'core' mode); None is only a placeholder for this smoke test.
    print(retrieval_knowledge("Train a fixed points of nonnegative neural networks. Set parameters: networks: nn_sigmoid, learning rate: 5e-3, epochs: 30, wd: 0, b: 64", None))
    #print(retrieval_knowledge("Use pattern mining to find frequent patterns in the dataset. Set parameters: fileURL: https://u-aizu.ac.jp/~udayrage/datasets/transactionalDatabases/Transactional_T10I4D100K.csv, minSup: 300."))