JasonTPhillipsJr committed
Commit 0a51000 · verified
1 Parent(s): 17e89f9

Delete models/spabert/experiments

Files changed (30)
  1. models/spabert/experiments/__init__.py +0 -0
  2. models/spabert/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
  3. models/spabert/experiments/entity_matching/__init__.py +0 -0
  4. models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc +0 -0
  5. models/spabert/experiments/entity_matching/data_processing/__init__.py +0 -0
  6. models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc +0 -0
  7. models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc +0 -0
  8. models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc +0 -0
  9. models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc +0 -0
  10. models/spabert/experiments/entity_matching/data_processing/get_namelist.py +0 -95
  11. models/spabert/experiments/entity_matching/data_processing/request_wrapper.py +0 -186
  12. models/spabert/experiments/entity_matching/data_processing/run_linking_query.py +0 -143
  13. models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py +0 -123
  14. models/spabert/experiments/entity_matching/data_processing/run_query_sample.py +0 -22
  15. models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py +0 -31
  16. models/spabert/experiments/entity_matching/data_processing/samples.sparql +0 -22
  17. models/spabert/experiments/entity_matching/data_processing/select_ambi.py +0 -18
  18. models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json +0 -0
  19. models/spabert/experiments/entity_matching/src/evaluation-mrr.py +0 -260
  20. models/spabert/experiments/entity_matching/src/linking_ablation.py +0 -228
  21. models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py +0 -329
  22. models/spabert/experiments/semantic_typing/__init__.py +0 -0
  23. models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py +0 -97
  24. models/spabert/experiments/semantic_typing/src/__init__.py +0 -0
  25. models/spabert/experiments/semantic_typing/src/run_baseline_test.py +0 -82
  26. models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py +0 -209
  27. models/spabert/experiments/semantic_typing/src/test_cls_baseline.py +0 -189
  28. models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py +0 -214
  29. models/spabert/experiments/semantic_typing/src/train_cls_baseline.py +0 -227
  30. models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py +0 -276
models/spabert/experiments/__init__.py DELETED
File without changes
models/spabert/experiments/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (155 Bytes)
 
models/spabert/experiments/entity_matching/__init__.py DELETED
File without changes
models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (171 Bytes)
 
models/spabert/experiments/entity_matching/data_processing/__init__.py DELETED
File without changes
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (187 Bytes)
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc DELETED
Binary file (204 Bytes)
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc DELETED
Binary file (5.97 kB)
 
models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc DELETED
Binary file (7.83 kB)
 
models/spabert/experiments/entity_matching/data_processing/get_namelist.py DELETED
@@ -1,95 +0,0 @@
1
- import json
2
- import os
3
-
4
- def get_name_list_osm(ref_paths):
5
- name_list = []
6
-
7
- for json_path in ref_paths:
8
- with open(json_path, 'r') as f:
9
- data = f.readlines()
10
- for line in data:
11
- record = json.loads(line)
12
- name = record['name']
13
- name_list.append(name)
14
-
15
- name_list = sorted(name_list)
16
- return name_list
17
-
18
- # deprecated
19
- def get_name_list_usgs_od(ref_paths):
20
- name_list = []
21
-
22
- for json_path in ref_paths:
23
- with open(json_path, 'r') as f:
24
- annot_dict = json.load(f)
25
- for key, place in annot_dict.items():
26
- place_name = ''
27
- for idx in range(1, len(place)+1):
28
- try:
29
- place_name += place[str(idx)]['text_label']
30
- place_name += ' ' # separate words with spaces
31
-
32
- except Exception as e:
33
- print(place)
34
- place_name = place_name[:-1] # remove last space
35
-
36
- name_list.append(place_name)
37
-
38
- name_list = sorted(name_list)
39
- return name_list
40
-
41
- def get_name_list_usgs_od_per_map(ref_paths):
42
- all_name_list_dict = dict()
43
-
44
- for json_path in ref_paths:
45
- map_name = os.path.basename(json_path).split('.json')[0]
46
-
47
- with open(json_path, 'r') as f:
48
- annot_dict = json.load(f)
49
-
50
- map_name_list = []
51
- for key, place in annot_dict.items():
52
- place_name = ''
53
- for idx in range(1, len(place)+1):
54
- try:
55
- place_name += place[str(idx)]['text_label']
56
- place_name += ' ' # separate words with spaces
57
-
58
- except Exception as e:
59
- print(place)
60
- place_name = place_name[:-1] # remove last space
61
-
62
- map_name_list.append(place_name)
63
- all_name_list_dict[map_name] = sorted(map_name_list)
64
-
65
- return all_name_list_dict
66
-
67
-
68
- def get_name_list_gb1900(ref_path):
69
- name_list = []
70
-
71
- with open(ref_path, 'r',encoding='utf-16') as f:
72
- data = f.readlines()
73
-
74
-
75
- for line in data[1:]: # skip the header
76
- try:
77
- line = line.split(',')
78
- text = line[1]
79
- lat = float(line[-3])
80
- lng = float(line[-2])
81
- semantic_type = line[-1]
82
-
83
- name_list.append(text)
84
- except:
85
- print(line)
86
-
87
- name_list = sorted(name_list)
88
-
89
- return name_list
90
-
91
-
92
- if __name__ == '__main__':
93
- #name_list = get_name_list_usgs_od(['labGISReport-master/output/USGS-15-CA-brawley-e1957-s1957-p1961.json',
94
- #'labGISReport-master/output/USGS-15-CA-capesanmartin-e1921-s1917.json'])
95
- name_list = get_name_list_gb1900('data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv')
models/spabert/experiments/entity_matching/data_processing/request_wrapper.py DELETED
@@ -1,186 +0,0 @@
1
- import requests
2
- import pdb
3
- import time
4
-
5
- # for linkedgeodata: #http://linkedgeodata.org/sparql
6
-
7
- class RequestWrapper:
8
- def __init__(self, baseuri = "https://query.wikidata.org/sparql"):
9
-
10
- self.baseuri = baseuri
11
-
12
- def response_handler(self, response, query):
13
- if response.status_code == requests.codes.ok:
14
- ret_json = response.json()['results']['bindings']
15
- elif response.status_code == 500:
16
- ret_json = []
17
- #print(q_id)
18
- print('Internal Error happened. Set ret_json to be empty list')
19
-
20
- elif response.status_code == 429:
21
-
22
- print(response.status_code)
23
- print(response.text)
24
- retry_seconds = int(response.text.split('Too Many Requests - Please retry in ')[1].split(' seconds')[0])
25
- print('rerun in %d seconds' %retry_seconds)
26
- time.sleep(retry_seconds + 1)
27
-
28
- response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
29
- ret_json = response.json()['results']['bindings']
30
- #print(ret_json)
31
- print('resumed and succeeded')
32
-
33
- else:
34
- print(response.status_code, response.text)
35
- exit(-1)
36
-
37
- return ret_json
38
-
39
- '''Search for wikidata entities given the name string'''
40
- def wikidata_query (self, name_str):
41
-
42
- query = """
43
- PREFIX wd: <http://www.wikidata.org/entity/>
44
- PREFIX wds: <http://www.wikidata.org/entity/statement/>
45
- PREFIX wdv: <http://www.wikidata.org/value/>
46
- PREFIX wdt: <http://www.wikidata.org/prop/direct/>
47
- PREFIX wikibase: <http://wikiba.se/ontology#>
48
- PREFIX p: <http://www.wikidata.org/prop/>
49
- PREFIX ps: <http://www.wikidata.org/prop/statement/>
50
- PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
51
- PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
52
- PREFIX bd: <http://www.bigdata.com/rdf#>
53
-
54
- SELECT ?item ?coordinates ?itemDescription WHERE {
55
- ?item rdfs:label \"%s\"@en;
56
- wdt:P625 ?coordinates .
57
- SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
58
- }
59
- """%(name_str)
60
-
61
- response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
62
-
63
-
64
- ret_json = self.response_handler(response, query)
65
-
66
- return ret_json
67
-
68
-
69
- '''Search for wikidata entities given the name string'''
70
- def wikidata_query_withinstate (self, name_str, state_id = 'Q99'):
71
-
72
-
73
- query = """
74
- PREFIX wd: <http://www.wikidata.org/entity/>
75
- PREFIX wds: <http://www.wikidata.org/entity/statement/>
76
- PREFIX wdv: <http://www.wikidata.org/value/>
77
- PREFIX wdt: <http://www.wikidata.org/prop/direct/>
78
- PREFIX wikibase: <http://wikiba.se/ontology#>
79
- PREFIX p: <http://www.wikidata.org/prop/>
80
- PREFIX ps: <http://www.wikidata.org/prop/statement/>
81
- PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
82
- PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
83
- PREFIX bd: <http://www.bigdata.com/rdf#>
84
-
85
- SELECT ?item ?coordinates ?itemDescription WHERE {
86
- ?item rdfs:label \"%s\"@en;
87
- wdt:P625 ?coordinates ;
88
- wdt:P131+ wd:%s;
89
- SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
90
- }
91
- """%(name_str, state_id)
92
-
93
- #print(query)
94
-
95
- response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
96
-
97
- ret_json = self.response_handler(response, query)
98
-
99
- return ret_json
100
-
101
-
102
- '''Search for nearby wikidata entities given the entity id'''
103
- def wikidata_nearby_query (self, q_id):
104
-
105
- query = """
106
- PREFIX wd: <http://www.wikidata.org/entity/>
107
- PREFIX wds: <http://www.wikidata.org/entity/statement/>
108
- PREFIX wdv: <http://www.wikidata.org/value/>
109
- PREFIX wdt: <http://www.wikidata.org/prop/direct/>
110
- PREFIX wikibase: <http://wikiba.se/ontology#>
111
- PREFIX p: <http://www.wikidata.org/prop/>
112
- PREFIX ps: <http://www.wikidata.org/prop/statement/>
113
- PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
114
- PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
115
- PREFIX bd: <http://www.bigdata.com/rdf#>
116
-
117
- SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
118
- WHERE
119
- {
120
- wd:%s wdt:P625 ?loc .
121
- SERVICE wikibase:around {
122
- ?place wdt:P625 ?location .
123
- bd:serviceParam wikibase:center ?loc .
124
- bd:serviceParam wikibase:radius "5" .
125
- }
126
- OPTIONAL { ?place wdt:P31 ?instance }
127
- SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
128
- BIND(geof:distance(?loc, ?location) as ?dist)
129
- } ORDER BY ?dist
130
- LIMIT 200
131
- """%(q_id)
132
- # initially 2km
133
-
134
- #pdb.set_trace()
135
-
136
- response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
137
-
138
-
139
- ret_json = self.response_handler(response, query)
140
-
141
-
142
-
143
- return ret_json
144
-
145
-
146
-
147
-
148
- def linkedgeodata_query (self, name_str):
149
-
150
- query = """
151
-
152
- Prefix lgdo: <http://linkedgeodata.org/ontology/>
153
- Prefix geom: <http://geovocab.org/geometry#>
154
- Prefix ogc: <http://www.opengis.net/ont/geosparql#>
155
- Prefix owl: <http://www.w3.org/2002/07/owl#>
156
- Prefix wgs84_pos: <http://www.w3.org/2003/01/geo/wgs84_pos#>
157
- Prefix owl: <http://www.w3.org/2002/07/owl#>
158
- Prefix gn: <http://www.geonames.org/ontology#>
159
-
160
- Select ?s, ?lat, ?long {
161
- {?s rdfs:label \"%s\";
162
- wgs84_pos:lat ?lat ;
163
- wgs84_pos:long ?long;
164
- }
165
- }
166
- """%(name_str)
167
-
168
-
169
-
170
- response = requests.get(self.baseuri, params = {'format':'json', 'query':query})
171
-
172
-
173
- ret_json = self.response_handler(response, query)
174
-
175
-
176
- return ret_json
177
-
178
-
179
-
180
- if __name__ == '__main__':
181
- request_wrapper_wikidata = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
182
- #print(request_wrapper_wikidata.wikidata_nearby_query('Q370771'))
183
- #print(request_wrapper_wikidata.wikidata_query_withinstate('San Bernardino'))
184
-
185
- # not working now
186
- print(request_wrapper_wikidata.linkedgeodata_query('San Bernardino'))
models/spabert/experiments/entity_matching/data_processing/run_linking_query.py DELETED
@@ -1,143 +0,0 @@
1
- #from query_wrapper import QueryWrapper
2
- from request_wrapper import RequestWrapper
3
- from get_namelist import *
4
- import glob
5
- import os
6
- import json
7
- import time
8
-
9
- DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
10
- KB_OPTIONS = ['wikidata', 'linkedgeodata']
11
-
12
- DATASET = 'USGS'
13
- KB = 'wikidata'
14
- OVERWRITE = True
15
- WITHIN_CA = True
16
-
17
- assert DATASET in DATASET_OPTIONS
18
- assert KB in KB_OPTIONS
19
-
20
-
21
- def process_one_namelist(sparql_wrapper, namelist, out_path):
22
-
23
- if OVERWRITE:
24
- # flush the file if it's been written
25
- with open(out_path, 'w') as f:
26
- f.write('')
27
-
28
-
29
- for name in namelist:
30
- name = name.replace('"', '')
31
- name = name.strip("'")
32
- if len(name) == 0:
33
- continue
34
- print(name)
35
- mydict = dict()
36
-
37
- if KB == 'wikidata':
38
- if WITHIN_CA:
39
- mydict[name] = sparql_wrapper.wikidata_query_withinstate(name)
40
- else:
41
- mydict[name] = sparql_wrapper.wikidata_query(name)
42
-
43
- elif KB == 'linkedgeodata':
44
- mydict[name] = sparql_wrapper.linkedgeodata_query(name)
45
- else:
46
- raise NotImplementedError
47
-
48
- line = json.dumps(mydict)
49
-
50
- with open(out_path, 'a') as f:
51
- f.write(line)
52
- f.write('\n')
53
- time.sleep(1)
54
-
55
-
56
- def process_namelist_dict(sparql_wrapper, namelist_dict, out_dir):
57
- i = 0
58
- for map_name, namelist in namelist_dict.items():
59
- # if i <=5:
60
- # i += 1
61
- # continue
62
-
63
- print('processing %s' %map_name)
64
-
65
- if WITHIN_CA:
66
- out_path = os.path.join(out_dir, KB + '_' + map_name + '.json')
67
- else:
68
- out_path = os.path.join(out_dir, KB + '_ca_' + map_name + '.json')
69
-
70
- process_one_namelist(sparql_wrapper, namelist, out_path)
71
- i+=1
72
-
73
-
74
- if KB == 'linkedgeodata':
75
- sparql_wrapper = RequestWrapper(baseuri = 'http://linkedgeodata.org/sparql')
76
- elif KB == 'wikidata':
77
- sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
78
- else:
79
- raise NotImplementedError
80
-
81
-
82
-
83
- if DATASET == 'OSM':
84
- osm_dir = '../surface_form/data_sample_london/data_osm/'
85
- osm_paths = glob.glob(os.path.join(osm_dir, 'embedding*.json'))
86
-
87
- out_path = 'outputs/'+KB+'_linking.json'
88
- namelist = get_name_list_osm(osm_paths)
89
-
90
- print('# files', len(osm_paths))
91
-
92
- process_one_namelist(sparql_wrapper, namelist, out_path)
93
-
94
-
95
- elif DATASET == 'OS':
96
- histmap_dir = 'data/labGISReport-master/output/'
97
- file_paths = glob.glob(os.path.join(histmap_dir, '10*.json'))
98
-
99
- out_path = 'outputs/'+KB+'_os_linking_descript.json'
100
- namelist = get_name_list_usgs_od(file_paths)
101
-
102
- print('# files',len(file_paths))
103
-
104
-
105
- process_one_namelist(sparql_wrapper, namelist, out_path)
106
-
107
- elif DATASET == 'USGS':
108
- histmap_dir = 'data/labGISReport-master/output/'
109
- file_paths = glob.glob(os.path.join(histmap_dir, 'USGS*.json'))
110
-
111
- if WITHIN_CA:
112
- out_dir = 'outputs/' + KB +'_ca'
113
- else:
114
- out_dir = 'outputs/' + KB
115
- namelist_dict = get_name_list_usgs_od_per_map(file_paths)
116
-
117
- if not os.path.isdir(out_dir):
118
- os.makedirs(out_dir)
119
-
120
- print('# files',len(file_paths))
121
-
122
- process_namelist_dict(sparql_wrapper, namelist_dict, out_dir)
123
-
124
- elif DATASET == 'GB1900':
125
-
126
- file_path = 'data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv'
127
- out_path = 'outputs/'+KB+'_gb1900_linking_descript.json'
128
- namelist = get_name_list_gb1900(file_path)
129
-
130
-
131
- process_one_namelist(sparql_wrapper, namelist, out_path)
132
-
133
- else:
134
- raise NotImplementedError
135
-
136
-
137
-
138
-
139
- #namelist = namelist[730:] #for GB1900
140
-
141
-
142
-
143
- print('done')
models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py DELETED
@@ -1,123 +0,0 @@
1
- from query_wrapper import QueryWrapper
2
- from request_wrapper import RequestWrapper
3
- import glob
4
- import os
5
- import json
6
- import time
7
- import random
8
-
9
- DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
10
- KB_OPTIONS = ['wikidata', 'linkedgeodata']
11
-
12
- dataset = 'USGS'
13
- kb = 'wikidata'
14
- overwrite = False
15
-
16
- assert dataset in DATASET_OPTIONS
17
- assert kb in KB_OPTIONS
18
-
19
- if dataset == 'OSM':
20
- raise NotImplementedError
21
-
22
- elif dataset == 'OS':
23
- raise NotImplementedError
24
-
25
- elif dataset == 'USGS':
26
-
27
- candidate_file_paths = glob.glob('outputs/alignment_dir/wikidata_USGS*.json')
28
- candidate_file_paths = sorted(candidate_file_paths)
29
-
30
- out_dir = 'outputs/wikidata_neighbors/'
31
-
32
- if not os.path.isdir(out_dir):
33
- os.makedirs(out_dir)
34
-
35
- elif dataset == 'GB1900':
36
-
37
- raise NotImplementedError
38
-
39
- else:
40
- raise NotImplementedError
41
-
42
-
43
- if kb == 'linkedgeodata':
44
- sparql_wrapper = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
45
- elif kb == 'wikidata':
46
- #sparql_wrapper = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
47
- sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
48
- else:
49
- raise NotImplementedError
50
-
51
- start_map = 6 # 6
52
- start_line = 4 # 4
53
-
54
- for candidate_file_path in candidate_file_paths[start_map:]:
55
- map_name = os.path.basename(candidate_file_path).split('wikidata_')[1]
56
- out_path = os.path.join(out_dir, 'wikidata_' + map_name)
57
-
58
- with open(candidate_file_path, 'r') as f:
59
- cand_data = f.readlines()
60
-
61
- with open(out_path, 'a') as out_f:
62
- for line in cand_data[start_line:]:
63
- line_dict = json.loads(line)
64
- ret_line_dict = dict()
65
- for key, value in line_dict.items(): # actually just one pair
66
-
67
- print(key)
68
-
69
- place_name = key
70
- for cand_entity in value:
71
- time.sleep(2)
72
- q_id = cand_entity['item']['value'].split('/')[-1]
73
- response = sparql_wrapper.wikidata_nearby_query(str(q_id))
74
-
75
- if place_name in ret_line_dict:
76
- ret_line_dict[place_name].append(response)
77
- else:
78
- ret_line_dict[place_name] = [response]
79
-
80
- #time.sleep(random.random()*6)
81
-
82
-
83
-
84
- out_f.write(json.dumps(ret_line_dict))
85
- out_f.write('\n')
86
-
87
- print('finished with ', candidate_file_path)
88
- break
89
-
90
- print('done')
91
-
92
- '''
93
- {"Martin": [{"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27001"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(18.921388888 49.063611111)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in northern Slovakia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q281028"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.734166666 43.175)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in and county seat of Bennett County, South Dakota, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q761390"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(1.3394 51.179)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "hamlet in Kent"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2177502"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-100.115 47.826666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in North Dakota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2454021"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-85.64168 42.53698)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Michigan, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2481111"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.21833 32.09917)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Red River Parish, Louisiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2635473"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.75944 37.56778)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Kentucky, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2679547"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-1.9041 50.9759)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Hampshire, England, United Kingdom"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2780056"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-88.851666666 36.341944444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Tennessee, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q3261150"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.18556 34.48639)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "town in Franklin and Stephens Counties, Georgia, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6002227"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.709 41.2581)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "census-designated place in 
Nebraska, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774807"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.1906 29.2936)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Marion County, Florida, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774809"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.336666666 41.5575)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Ohio, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774810"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(116.037 -32.071)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "suburb of Perth, Western Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q9029707"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-81.476388888 33.069166666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in South Carolina, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q11770660"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-0.325391 53.1245)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village and civil parish in Lincolnshire, UK"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14692833"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.5444 47.4894)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Itasca County, Minnesota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14714180"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.0889 39.2242)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Grant County, West Virginia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q18496647"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(11.95941 46.34585)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Italy"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q20949553"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(22.851944444 41.954444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "mountain in Republic of Macedonia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q24065096"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "<http://www.wikidata.org/entity/Q111> Point(290.75 -21.34)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "crater on Mars"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26300074"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", 
"value": "Point(-0.97879914 51.74729077)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Thame, South Oxfordshire, Oxfordshire, OX9"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27988822"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-87.67361111 38.12444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Vanderburgh County, Indiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27995389"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-121.317 47.28)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Washington, United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q28345614"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-76.8 48.5)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Senneterre, Quebec, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q30626037"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.90972222 39.80638889)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q61038281"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-91.13 49.25)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Meteorological Service of Canada's station for Martin (MSC ID: 6035000), Ontario, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526691"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(152.7011111 -29.881666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Fitzroy County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526695"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(148.1011111 -33.231666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Ashburnham County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96149222"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(14.725573779 48.76768332)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96158116"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(15.638358708 49.930136157)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q103777024"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-3.125028822 58.694748091)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Shipwreck off the Scottish Coast, imported from Canmore Nov 2020"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q107077206"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", 
"type": "literal", "value": "Point(11.688165 46.582982)"}}]}
94
-
95
-
96
- if overwrite:
97
- # flush the file if it's been written
98
- with open(out_path, 'w') as f:
99
- f.write('')
100
-
101
- for name in namelist:
102
- name = name.replace('"', '')
103
- name = name.strip("'")
104
- if len(name) == 0:
105
- continue
106
- print(name)
107
- mydict = dict()
108
- mydict[name] = sparql_wrapper.wikidata_query(name)
109
- line = json.dumps(mydict)
110
- #print(line)
111
- with open(out_path, 'a') as f:
112
- f.write(line)
113
- f.write('\n')
114
- time.sleep(1)
115
-
116
- print('done')
117
-
118
-
119
- {"info": {"name": "10TH ST", "geometry": [4193.118085062303, -831.274950414831]},
120
- "neighbor_info":
121
- {"name_list": ["BM 107", "PALM AVE", "WT", "Hidalgo Sch", "PO", "MAIN ST", "BRYANT CANAL", "BM 123", "Oakley Sch", "BRAWLEY", "Witter Sch", "BM 104", "Pistol Range", "Reid Sch", "STANLEY", "MUNICIPAL AIRPORT", "WESTERN AVE", "CANAL", "Riverview Cem", "BEST CANAL"],
122
- "geometry_list": [[4180.493095652702, -836.0635465095995], [4240.450935702045, -855.345637906981], [4136.084840542623, -917.7895986922882], [4150.386997979736, -948.7258091165079], [4056.955267048625, -847.1018277439381], [4008.112642182582, -849.089249977583], [4124.177447575567, -1004.0706369942257], [4145.382175508665, -626.1608201557082], [4398.137868976953, -764.1087236140554], [4221.1546492913285, -1062.5745271963772], [4015.203890157584, -985.0178210457995], [3989.2345421184878, -948.9340389243871], [4385.585449075614, -660.4590917125413], [3936.505159635338, -803.6822663422273], [3960.1233867112846, -686.7988766730389], [4409.714306709143, -600.6633389979504], [3871.2873706574037, -832.0785684368772], [4304.899727301024, -524.472390102557], [3955.640201659347, -578.5544271698675], [4075.8524354668034, -1183.5837385075774]]}}
123
- '''
models/spabert/experiments/entity_matching/data_processing/run_query_sample.py DELETED
@@ -1,22 +0,0 @@
1
- from query_wrapper import QueryWrapper
2
- from get_namelist import *
3
- import glob
4
- import os
5
- import json
6
- import time
7
-
8
-
9
- #sparql_wrapper_linkedgeo = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
10
-
11
- #print(sparql_wrapper_linkedgeo.linkedgeodata_query('Los Angeles'))
12
-
13
-
14
- sparql_wrapper_wikidata = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
15
-
16
- #print(sparql_wrapper_wikidata.wikidata_query('Los Angeles'))
17
-
18
- #time.sleep(3)
19
-
20
- #print(sparql_wrapper_wikidata.wikidata_nearby_query('Q370771'))
21
- print(sparql_wrapper_wikidata.wikidata_nearby_query('Q97625145'))
22
-
models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py DELETED
@@ -1,31 +0,0 @@
1
- import pandas as pd
2
- import json
3
- from request_wrapper import RequestWrapper
4
- import time
5
- import pdb
6
-
7
- start_idx = 17335
8
- wikidata_sample30k_path = 'wikidata_sample30k/wikidata_30k.json'
9
- out_path = 'wikidata_sample30k/wikidata_30k_neighbor.json'
10
-
11
- #with open(out_path, 'w') as out_f:
12
- # pass
13
-
14
- sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
15
-
16
- df= pd.read_json(wikidata_sample30k_path)
17
- df = df[start_idx:]
18
-
19
- print('length of df:', len(df))
20
-
21
- for index, record in df.iterrows():
22
- print(index)
23
- uri = record.results['item']['value']
24
- q_id = uri.split('/')[-1]
25
- response = sparql_wrapper.wikidata_nearby_query(str(q_id))
26
- time.sleep(1)
27
- with open(out_path, 'a') as out_f:
28
- out_f.write(json.dumps(response))
29
- out_f.write('\n')
30
-
31
-
models/spabert/experiments/entity_matching/data_processing/samples.sparql DELETED
@@ -1,22 +0,0 @@
1
- '''
2
- Entities near xxx within 1km
3
- SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
4
- WHERE
5
- {
6
- wd:Q9188 wdt:P625 ?loc .
7
- SERVICE wikibase:around {
8
- ?place wdt:P625 ?location .
9
- bd:serviceParam wikibase:center ?loc .
10
- bd:serviceParam wikibase:radius "1" .
11
- }
12
- OPTIONAL { ?place wdt:P31 ?instance }
13
- SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
14
- BIND(geof:distance(?loc, ?location) as ?dist)
15
- } ORDER BY ?dist
16
- '''
17
-
18
-
19
- SELECT distinct ?item ?itemLabel WHERE {
20
- ?item wdt:P625 ?geo .
21
- SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
22
- }
models/spabert/experiments/entity_matching/data_processing/select_ambi.py DELETED
@@ -1,18 +0,0 @@
1
- import json
2
-
3
- json_file = 'entity_linking/outputs/wikidata_usgs_linking_descript.json'
4
-
5
- with open(json_file, 'r') as f:
6
- data = f.readlines()
7
-
8
- num_ambi = 0
9
- for line in data:
10
- line_dict = json.loads(line)
11
- for key,value in line_dict.items():
12
- len_value = len(value)
13
- if len_value < 2:
14
- continue
15
- else:
16
- num_ambi += 1
17
- print(key)
18
- print(num_ambi)
models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json DELETED
The diff for this file is too large to render. See raw diff
 
models/spabert/experiments/entity_matching/src/evaluation-mrr.py DELETED
@@ -1,260 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
-
5
- import sys
6
- import os
7
- import glob
8
- import json
9
- import numpy as np
10
- import pandas as pd
11
- import pdb
12
-
13
-
14
- prediction_dir = sys.argv[1]
15
-
16
- print(prediction_dir)
17
-
18
- gt_dir = '../data_processing/outputs/alignment_gt_dir/'
19
- prediction_path_list = sorted(os.listdir(prediction_dir))
20
-
21
- DISPLAY = False
22
- DETAIL = False
23
-
24
- if DISPLAY:
25
- from IPython.display import display
26
-
27
- def recall_at_k_all_map(all_rank_list, k = 1):
28
-
29
- rank_list = [item for sublist in all_rank_list for item in sublist]
30
- total_query = len(rank_list)
31
- prec = np.sum(np.array(rank_list)<=k)
32
- prec = 1.0 * prec / total_query
33
-
34
- return prec
35
-
36
- def recall_at_k_permap(all_rank_list, k = 1):
37
-
38
- prec_list = []
39
- for rank_list in all_rank_list:
40
- total_query = len(rank_list)
41
- prec = np.sum(np.array(rank_list)<=k)
42
- prec = 1.0 * prec / total_query
43
- prec_list.append(prec)
44
-
45
- return prec_list
46
-
47
-
48
-
49
-
50
- def reciprocal_rank(all_rank_list):
51
-
52
- recip_list = [1./rank for rank in all_rank_list]
53
- mean_recip = np.mean(recip_list)
54
-
55
- return mean_recip, recip_list
56
-
57
-
58
-
59
-
60
- count_hist_list = []
61
-
62
- all_rank_list = []
63
-
64
- all_recip_list = []
65
-
66
- permap_recip_list = []
67
-
68
- for map_path in prediction_path_list:
69
-
70
- pred_path = os.path.join(prediction_dir, map_path)
71
- gt_path = os.path.join(gt_dir, map_path.split('.json')[0] + '.csv')
72
-
73
- if DETAIL:
74
- print(pred_path)
75
-
76
-
77
- with open(gt_path, 'r') as f:
78
- gt_data = f.readlines()
79
-
80
- gt_dict = dict()
81
- for line in gt_data:
82
- line = line.split(',')
83
- pivot_name = line[0]
84
- gt_uri = line[1]
85
- gt_dict[pivot_name] = gt_uri
86
-
87
- rank_list = []
88
- pivot_name_list = []
89
- with open(pred_path, 'r') as f:
90
- pred_data = f.readlines()
91
- for line in pred_data:
92
- pred_dict = json.loads(line)
93
- #print(pred_dict.keys())
94
- pivot_name = pred_dict['pivot_name']
95
- sorted_match_uri = pred_dict['sorted_match_uri']
96
- #sorted_match_des = pred_dict['sorted_match_des']
97
- sorted_sim_matrix = pred_dict['sorted_sim_matrix']
98
-
99
-
100
- total = len(sorted_match_uri)
101
- if total == 1:
102
- continue
103
-
104
- if pivot_name in gt_dict:
105
-
106
- gt_uri = gt_dict[pivot_name]
107
-
108
- try:
109
- assert gt_uri in sorted_match_uri
110
- except Exception as e:
111
- #print(e)
112
- continue
113
-
114
- pivot_name_list.append(pivot_name)
115
- count_hist_list.append(total)
116
- rank = sorted_match_uri.index(gt_uri) +1
117
-
118
- rank_list.append(rank)
119
- #print(rank,'/',total)
120
-
121
- all_rank_list.append(rank_list)
122
-
123
- mean_recip, recip_list = reciprocal_rank(rank_list)
124
-
125
- all_recip_list.extend(recip_list)
126
- permap_recip_list.append(recip_list)
127
-
128
- d = {'pivot': pivot_name_list + ['AVG'], 'rank':rank_list + [' '] ,'recip rank': recip_list + [str(mean_recip)]}
129
- if DETAIL:
130
- print(pivot_name_list, rank_list, recip_list)
131
-
132
- if DISPLAY:
133
- df = pd.DataFrame(data=d)
134
-
135
- display(df)
136
-
137
-
138
-
139
- print('all mrr, micro', np.mean(all_recip_list))
140
-
141
-
142
- if DETAIL:
143
-
144
- len(rank_list)
145
-
146
-
147
-
148
- print(recall_at_k_all_map(all_rank_list, k = 1))
149
- print(recall_at_k_all_map(all_rank_list, k = 2))
150
- print(recall_at_k_all_map(all_rank_list, k = 5))
151
- print(recall_at_k_all_map(all_rank_list, k = 10))
152
-
153
-
154
- print(prediction_path_list)
155
-
156
-
157
- prec_list_1 = recall_at_k_permap(all_rank_list, k = 1)
158
- prec_list_2 = recall_at_k_permap(all_rank_list, k = 2)
159
- prec_list_5 = recall_at_k_permap(all_rank_list, k = 5)
160
- prec_list_10 = recall_at_k_permap(all_rank_list, k = 10)
161
-
162
- if DETAIL:
163
-
164
- print(np.mean(prec_list_1))
165
- print(prec_list_1)
166
- print('\n')
167
-
168
- print(np.mean(prec_list_2))
169
- print(prec_list_2)
170
- print('\n')
171
-
172
- print(np.mean(prec_list_5))
173
- print(prec_list_5)
174
- print('\n')
175
-
176
- print(np.mean(prec_list_10))
177
- print(prec_list_10)
178
- print('\n')
179
-
180
-
181
-
182
-
183
-
184
-
185
-
186
-
187
- import pandas as pd
188
-
189
-
190
-
191
- map_name_list = [name.split('.json')[0].split('USGS-')[1] for name in prediction_path_list]
192
- d = {'map_name': map_name_list,'recall@1': prec_list_1, 'recall@2': prec_list_2, 'recall@5': prec_list_5, 'recall@10': prec_list_10 }
193
- df = pd.DataFrame(data=d)
194
-
195
-
196
- if DETAIL:
197
- print(df)
198
-
199
-
200
-
201
-
202
-
203
- category = ['15-CA','30-CA','60-CA']
204
- col_1 = [np.mean(prec_list_1[0:4]), np.mean(prec_list_1[4:9]), np.mean(prec_list_1[9:])]
205
- col_2 = [np.mean(prec_list_2[0:4]), np.mean(prec_list_2[4:9]), np.mean(prec_list_2[9:])]
206
- col_3 = [np.mean(prec_list_5[0:4]), np.mean(prec_list_5[4:9]), np.mean(prec_list_5[9:])]
207
- col_4 = [np.mean(prec_list_10[0:4]), np.mean(prec_list_10[4:9]), np.mean(prec_list_10[9:])]
208
-
209
-
210
-
211
- mrr_15 = permap_recip_list[0] + permap_recip_list[1] + permap_recip_list[2] + permap_recip_list[3]
212
- mrr_30 = permap_recip_list[4] + permap_recip_list[5] + permap_recip_list[6] + permap_recip_list[7] + permap_recip_list[8]
213
- mrr_60 = permap_recip_list[9] + permap_recip_list[10] + permap_recip_list[11] + permap_recip_list[12] + permap_recip_list[13]
214
-
215
-
216
-
217
- column_5 = [np.mean(mrr_15), np.mean(mrr_30), np.mean(mrr_60)]
218
-
219
-
220
- d = {'map set': category, 'mrr': column_5, 'prec@1': col_1, 'prec@2': col_2, 'prec@5': col_3, 'prec@10': col_4 }
221
- df = pd.DataFrame(data=d)
222
-
223
- print(df)
224
-
225
-
226
-
227
-
228
- print('all mrr, micro', np.mean(all_recip_list))
229
-
230
- print('\n')
231
-
232
-
233
-
234
- print(recall_at_k_all_map(all_rank_list, k = 1))
235
- print(recall_at_k_all_map(all_rank_list, k = 2))
236
- print(recall_at_k_all_map(all_rank_list, k = 5))
237
- print(recall_at_k_all_map(all_rank_list, k = 10))
238
-
239
-
240
-
241
-
242
- if DISPLAY:
243
-
244
- import seaborn
245
-
246
- p = seaborn.histplot(data = count_hist_list, color = 'blue', alpha=0.2)
247
- p.set_xlabel("Number of Candidates")
248
- p.set_title("Candidate Distribution in USGS")
249
-
250
-
251
-
252
-
253
- len(count_hist_list)
254
-
255
-
256
-
257
-
258
-
259
-
260
-
models/spabert/experiments/entity_matching/src/linking_ablation.py DELETED
@@ -1,228 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import sys
5
- import os
6
- import numpy as np
7
- import pdb
8
- import json
9
- import scipy.spatial as sp
10
- import argparse
11
-
12
-
13
- import torch
14
- from torch.utils.data import DataLoader
15
-
16
- from transformers import AdamW
17
- from transformers import BertTokenizer
18
- from tqdm import tqdm # for our progress bar
19
-
20
- sys.path.append('../../../')
21
- from datasets.usgs_os_sample_loader import USGS_MapDataset
22
- from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
23
- from models.spatial_bert_model import SpatialBertModel
24
- from models.spatial_bert_model import SpatialBertConfig
25
- from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
26
- from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
27
- from utils.baseline_utils import get_baseline_model
28
-
29
- from transformers import BertModel
30
-
31
- sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
32
- from dataset_loader import SpatialDataset
33
- from osm_sample_loader import PbfMapDataset
34
-
35
-
36
-
37
- MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
38
- 'spanbert-base','spanbert-large','luke-base','luke-large',
39
- 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
40
-
41
-
42
- CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map
43
-
44
- def recall_at_k(rank_list, k = 1):
45
-
46
- total_query = len(rank_list)
47
- recall = np.sum(np.array(rank_list)<=k)
48
- recall = 1.0 * recall / total_query
49
-
50
- return recall
51
-
52
- def reciprocal_rank(all_rank_list):
53
-
54
- recip_list = [1./rank for rank in all_rank_list]
55
- mean_recip = np.mean(recip_list)
56
-
57
- return mean_recip, recip_list
58
-
59
- def link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list):
60
-
61
- source_emb_list = [source_dict['emb'] for source_dict in source_embedding_ogc_list]
62
- source_ogc_list = [source_dict['ogc_fid'] for source_dict in source_embedding_ogc_list]
63
-
64
- target_emb_list = [target_dict['emb'] for target_dict in target_embedding_ogc_list]
65
- target_ogc_list = [target_dict['ogc_fid'] for target_dict in target_embedding_ogc_list]
66
-
67
- rank_list = []
68
- for source_emb, source_ogc in zip(source_emb_list, source_ogc_list):
69
- sim_matrix = 1 - sp.distance.cdist(np.array(target_emb_list), np.array([source_emb]), 'cosine')
70
- closest_match_ogc = sort_ref_closest_match(sim_matrix, target_ogc_list)
71
-
72
- closest_match_ogc = [a[0] for a in closest_match_ogc]
73
- rank = closest_match_ogc.index(source_ogc) +1
74
- rank_list.append(rank)
75
-
76
-
77
- mean_recip, recip_list = reciprocal_rank(rank_list)
78
- r1 = recall_at_k(rank_list, k = 1)
79
- r5 = recall_at_k(rank_list, k = 5)
80
- r10 = recall_at_k(rank_list, k = 10)
81
-
82
- return mean_recip , r1, r5, r10
83
-
84
- def get_embedding_and_ogc(dataset, model_name, model):
85
- dict_list = []
86
-
87
- for source in dataset:
88
- if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
89
- source_emb = get_spatialbert_embedding(source, model)
90
- else:
91
- source_emb = get_bert_embedding(source, model)
92
-
93
- source_dict = {}
94
- source_dict['emb'] = source_emb
95
- source_dict['ogc_fid'] = source['ogc_fid']
96
- #wikidata_dict['wikidata_des_list'] = [wikidata_cand['description']]
97
-
98
- dict_list.append(source_dict)
99
-
100
- return dict_list
101
-
102
-
103
- def entity_linking_func(args):
104
-
105
- model_name = args.model_name
106
- candset_mode = args.candset_mode
107
-
108
- distance_norm_factor = args.distance_norm_factor
109
- spatial_dist_fill= args.spatial_dist_fill
110
- sep_between_neighbors = args.sep_between_neighbors
111
-
112
- spatial_bert_weight_dir = args.spatial_bert_weight_dir
113
- spatial_bert_weight_name = args.spatial_bert_weight_name
114
-
115
- if_no_spatial_distance = args.no_spatial_distance
116
- random_remove_neighbor = args.random_remove_neighbor
117
-
118
-
119
- assert model_name in MODEL_OPTIONS
120
- assert candset_mode in CANDSET_MODES
121
-
122
-
123
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
124
-
125
-
126
-
127
- if model_name == 'spatial_bert-base':
128
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
129
-
130
- config = SpatialBertConfig()
131
- model = SpatialBertModel(config)
132
-
133
- model.to(device)
134
- model.eval()
135
-
136
- # load pretrained weights
137
- weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
138
- model = load_spatial_bert_pretrained_weights(model, weight_path)
139
-
140
- elif model_name == 'spatial_bert-large':
141
- tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
142
-
143
- config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
144
- model = SpatialBertModel(config)
145
-
146
- model.to(device)
147
- model.eval()
148
-
149
- # load pretrained weights
150
- weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
151
- model = load_spatial_bert_pretrained_weights(model, weight_path)
152
-
153
- else:
154
- model, tokenizer = get_baseline_model(model_name)
155
- model.to(device)
156
- model.eval()
157
-
158
- source_file_path = '../data/osm-point-minnesota-full.json'
159
- source_dataset = PbfMapDataset(data_file_path = source_file_path,
160
- tokenizer = tokenizer,
161
- max_token_len = 512,
162
- distance_norm_factor = distance_norm_factor,
163
- spatial_dist_fill = spatial_dist_fill,
164
- with_type = False,
165
- sep_between_neighbors = sep_between_neighbors,
166
- mode = None,
167
- random_remove_neighbor = random_remove_neighbor,
168
- )
169
-
170
- target_dataset = PbfMapDataset(data_file_path = source_file_path,
171
- tokenizer = tokenizer,
172
- max_token_len = 512,
173
- distance_norm_factor = distance_norm_factor,
174
- spatial_dist_fill = spatial_dist_fill,
175
- with_type = False,
176
- sep_between_neighbors = sep_between_neighbors,
177
- mode = None,
178
- random_remove_neighbor = 0., # keep all
179
- )
180
-
181
- # process candidates for each phrase
182
-
183
-
184
- source_embedding_ogc_list = get_embedding_and_ogc(source_dataset, model_name, model)
185
- target_embedding_ogc_list = get_embedding_and_ogc(target_dataset, model_name, model)
186
-
187
-
188
- mean_recip , r1, r5, r10 = link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list)
189
- print('\n')
190
- print(random_remove_neighbor, mean_recip , r1, r5, r10)
191
-
192
-
193
-
194
- def main():
195
- parser = argparse.ArgumentParser()
196
- parser.add_argument('--model_name', type=str, default='spatial_bert-base')
197
- parser.add_argument('--candset_mode', type=str, default='all_map')
198
-
199
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
200
- parser.add_argument('--spatial_dist_fill', type=float, default = 20)
201
-
202
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
203
- parser.add_argument('--no_spatial_distance', default=False, action='store_true')
204
-
205
- parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
206
- parser.add_argument('--spatial_bert_weight_name', type = str, default = None)
207
-
208
- parser.add_argument('--random_remove_neighbor', type = float, default = 0.)
209
-
210
-
211
- args = parser.parse_args()
212
- # print('\n')
213
- # print(args)
214
- # print('\n')
215
-
216
- entity_linking_func(args)
217
-
218
- # CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-base' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr5e-05_sep_bert-base_nofreeze_london_california_bsize12/ep0_iter06000_0.2936/' --spatial_bert_weight_name='keeppos_ep0_iter02000_0.4879.pth' --random_remove_neighbor=0.1
219
-
220
-
221
- # CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-large' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr1e-06_sep_bert-large_nofreeze_london_california_bsize12/ep2_iter02000_0.3921/' --spatial_bert_weight_name='keeppos_ep8_iter03568_0.2661_val0.2284.pth' --random_remove_neighbor=0.1
222
-
223
-
224
- if __name__ == '__main__':
225
-
226
- main()
227
-
228
-
models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py DELETED
@@ -1,329 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- import sys
5
- import os
6
- import numpy as np
7
- import pdb
8
- import json
9
- import scipy.spatial as sp
10
- import argparse
11
-
12
-
13
- import torch
14
- from torch.utils.data import DataLoader
15
- #from transformers.models.bert.modeling_bert import BertForMaskedLM
16
-
17
- from transformers import AdamW
18
- from transformers import BertTokenizer
19
- from tqdm import tqdm # for our progress bar
20
-
21
- sys.path.append('../../../')
22
- from datasets.usgs_os_sample_loader import USGS_MapDataset
23
- from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
24
- from models.spatial_bert_model import SpatialBertModel
25
- from models.spatial_bert_model import SpatialBertConfig
26
- #from models.spatial_bert_model import SpatialBertForMaskedLM
27
- from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
28
- from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
29
- from utils.baseline_utils import get_baseline_model
30
-
31
- from transformers import BertModel
32
-
33
-
34
- MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
35
- 'spanbert-base','spanbert-large','luke-base','luke-large',
36
- 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
37
-
38
- MAP_TYPES = ['usgs']
39
- CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map
40
-
41
-
42
- def disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode = 'all_map', if_use_distance = True, select_indices = None):
43
-
44
- if select_indices is None:
45
- select_indices = range(0, len(usgs_dataset))
46
-
47
-
48
- assert(candset_mode in ['all_map','per_map'])
49
-
50
- wikidata_emb_list = [wikidata_dict['wikidata_emb_list'] for wikidata_dict in wikidata_dict_list]
51
- wikidata_uri_list = [wikidata_dict['wikidata_uri_list'] for wikidata_dict in wikidata_dict_list]
52
- #wikidata_des_list = [wikidata_dict['wikidata_des_list'] for wikidata_dict in wikidata_dict_list]
53
-
54
- if candset_mode == 'all_map':
55
- wikidata_emb_list = [item for sublist in wikidata_emb_list for item in sublist] # flatten
56
- wikidata_uri_list = [item for sublist in wikidata_uri_list for item in sublist] # flatten
57
- #wikidata_des_list = [item for sublist in wikidata_des_list for item in sublist] # flatten
58
-
59
-
60
- ret_list = []
61
- for i in select_indices:
62
-
63
- if candset_mode == 'per_map':
64
- usgs_entity = usgs_dataset[i]
65
- wikidata_emb_list = wikidata_emb_list[i]
66
- wikidata_uri_list = wikidata_uri_list[i]
67
- #wikidata_des_list = wikidata_des_list[i]
68
-
69
- elif candset_mode == 'all_map':
70
- usgs_entity = usgs_dataset[i]
71
- else:
72
- raise NotImplementedError
73
-
74
- if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
75
- usgs_emb = get_spatialbert_embedding(usgs_entity, model, use_distance = if_use_distance)
76
- else:
77
- usgs_emb = get_bert_embedding(usgs_entity, model)
78
-
79
-
80
- sim_matrix = 1 - sp.distance.cdist(np.array(wikidata_emb_list), np.array([usgs_emb]), 'cosine')
81
-
82
- closest_match_uri = sort_ref_closest_match(sim_matrix, wikidata_uri_list)
83
- #closest_match_des = sort_ref_closest_match(sim_matrix, wikidata_des_list)
84
-
85
-
86
- sorted_sim_matrix = np.sort(sim_matrix, axis = 0)[::-1] # descending order
87
-
88
- ret_dict = dict()
89
- ret_dict['pivot_name'] = usgs_entity['pivot_name']
90
- ret_dict['sorted_match_uri'] = [a[0] for a in closest_match_uri]
91
- #ret_dict['sorted_match_des'] = [a[0] for a in closest_match_des]
92
- ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix]
93
-
94
- ret_list.append(ret_dict)
95
-
96
- return ret_list
97
-
98
- def entity_linking_func(args):
99
-
100
- model_name = args.model_name
101
- map_type = args.map_type
102
- candset_mode = args.candset_mode
103
-
104
- usgs_distance_norm_factor = args.usgs_distance_norm_factor
105
- spatial_dist_fill= args.spatial_dist_fill
106
- sep_between_neighbors = args.sep_between_neighbors
107
-
108
- spatial_bert_weight_dir = args.spatial_bert_weight_dir
109
- spatial_bert_weight_name = args.spatial_bert_weight_name
110
-
111
- if_no_spatial_distance = args.no_spatial_distance
112
-
113
-
114
- assert model_name in MODEL_OPTIONS
115
- assert map_type in MAP_TYPES
116
- assert candset_mode in CANDSET_MODES
117
-
118
-
119
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
120
-
121
- if args.out_dir is None:
122
-
123
- if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
124
-
125
- if sep_between_neighbors:
126
- spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor ) + '_distfill' + str(spatial_dist_fill) + '_sep'
127
- else:
128
- spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor ) + '_distfill' + str(spatial_dist_fill) + '_nosep'
129
-
130
-
131
- checkpoint_ep = spatial_bert_weight_name.split('_')[3]
132
- checkpoint_iter = spatial_bert_weight_name.split('_')[4]
133
- loss_val = spatial_bert_weight_name.split('_')[5][:-4]
134
-
135
- if if_no_spatial_distance:
136
- linking_prediction_dir = 'linking_prediction_dir/abalation_no_distance/'
137
- else:
138
- linking_prediction_dir = 'linking_prediction_dir'
139
-
140
- if model_name == 'spatial_bert-base':
141
- out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val
142
- elif model_name == 'spatial_bert-large':
143
-
144
- freeze_str = spatial_bert_weight_dir.split('/')[-2].split('_')[1] # either 'freeze' or 'nofreeze'
145
- out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val + '-' + freeze_str
146
-
147
-
148
-
149
- else:
150
- out_dir = '/data2/zekun/baseline_linking_prediction_dir/' + map_type + '-' + model_name
151
-
152
- else:
153
- out_dir = args.out_dir
154
-
155
- print('out_dir', out_dir)
156
-
157
- if not os.path.isdir(out_dir):
158
- os.makedirs(out_dir)
159
-
160
-
161
- if model_name == 'spatial_bert-base':
162
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
163
-
164
- config = SpatialBertConfig()
165
- model = SpatialBertModel(config)
166
-
167
- model.to(device)
168
- model.eval()
169
-
170
- # load pretrained weights
171
- weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
172
- model = load_spatial_bert_pretrained_weights(model, weight_path)
173
-
174
- elif model_name == 'spatial_bert-large':
175
- tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
176
-
177
- config = SpatialBertConfig(hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24)
178
- model = SpatialBertModel(config)
179
-
180
- model.to(device)
181
- model.eval()
182
-
183
- # load pretrained weights
184
- weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
185
- model = load_spatial_bert_pretrained_weights(model, weight_path)
186
-
187
- else:
188
- model, tokenizer = get_baseline_model(model_name)
189
- model.to(device)
190
- model.eval()
191
-
192
-
193
- if map_type == 'usgs':
194
- map_name_list = ['USGS-15-CA-brawley-e1957-s1957-p1961',
195
- 'USGS-15-CA-paloalto-e1899-s1895-rp1911',
196
- 'USGS-15-CA-capesanmartin-e1921-s1917',
197
- 'USGS-15-CA-sanfrancisco-e1899-s1892-rp1911',
198
- 'USGS-30-CA-dardanelles-e1898-s1891-rp1912',
199
- 'USGS-30-CA-holtville-e1907-s1905-rp1946',
200
- 'USGS-30-CA-indiospecial-e1904-s1901-rp1910',
201
- 'USGS-30-CA-lompoc-e1943-s1903-ap1941-rv1941',
202
- 'USGS-30-CA-sanpedro-e1943-rv1944',
203
- 'USGS-60-CA-alturas-e1892-rp1904',
204
- 'USGS-60-CA-amboy-e1942',
205
- 'USGS-60-CA-amboy-e1943-rv1943',
206
- 'USGS-60-CA-modoclavabed-e1886-s1884',
207
- 'USGS-60-CA-saltonsea-e1943-ap1940-rv1942']
208
-
209
- print('processing wikidata...')
210
-
211
- wikidata_dict_list = []
212
-
213
- wikidata_random30k = Wikidata_Random_Dataset(
214
- data_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor_reformat.json',
215
- #neighbor_file_path = '../data_processing/wikidata_sample30k/wikidata_30k_neighbor.json',
216
- tokenizer = tokenizer,
217
- max_token_len = 512,
218
- distance_norm_factor = 0.0001,
219
- spatial_dist_fill=100,
220
- sep_between_neighbors = sep_between_neighbors,
221
- )
222
-
223
- # process candidates for each phrase
224
- for wikidata_cand in wikidata_random30k:
225
- if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
226
- wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
227
- else:
228
- wikidata_emb = get_bert_embedding(wikidata_cand, model)
229
-
230
- wikidata_dict = {}
231
- wikidata_dict['wikidata_emb_list'] = [wikidata_emb]
232
- wikidata_dict['wikidata_uri_list'] = [wikidata_cand['uri']]
233
-
234
- wikidata_dict_list.append(wikidata_dict)
235
-
236
-
237
-
238
- for map_name in map_name_list:
239
-
240
- print(map_name)
241
-
242
- wikidata_dict_per_map = {}
243
- wikidata_dict_per_map['wikidata_emb_list'] = []
244
- wikidata_dict_per_map['wikidata_uri_list'] = []
245
-
246
- wikidata_dataset_permap = Wikidata_Geocoord_Dataset(
247
- data_file_path = '../data_processing/outputs/wikidata_reformat/wikidata_' + map_name + '.json',
248
- tokenizer = tokenizer,
249
- max_token_len = 512,
250
- distance_norm_factor = 0.0001,
251
- spatial_dist_fill=100,
252
- sep_between_neighbors = sep_between_neighbors)
253
-
254
-
255
-
256
- for i in range(0, len(wikidata_dataset_permap)):
257
- # get all candidates for phrases within the map
258
- wikidata_candidates = wikidata_dataset_permap[i] # dataset for each map, list of [cand for each phrase]
259
-
260
-
261
- # process candidates for each phrase
262
- for wikidata_cand in wikidata_candidates:
263
- if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
264
- wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
265
- else:
266
- wikidata_emb = get_bert_embedding(wikidata_cand, model)
267
-
268
- wikidata_dict_per_map['wikidata_emb_list'].append(wikidata_emb)
269
- wikidata_dict_per_map['wikidata_uri_list'].append(wikidata_cand['uri'])
270
-
271
-
272
- wikidata_dict_list.append(wikidata_dict_per_map)
273
-
274
-
275
-
276
- for map_name in map_name_list:
277
-
278
- print(map_name)
279
-
280
-
281
- usgs_dataset = USGS_MapDataset(
282
- data_file_path = '../data_processing/outputs/alignment_dir/map_' + map_name + '.json',
283
- tokenizer = tokenizer,
284
- distance_norm_factor = usgs_distance_norm_factor,
285
- spatial_dist_fill = spatial_dist_fill,
286
- sep_between_neighbors = sep_between_neighbors)
287
-
288
-
289
- ret_list = disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode= candset_mode, if_use_distance = not if_no_spatial_distance, select_indices = None)
290
-
291
- write_to_csv(out_dir, map_name, ret_list)
292
-
293
- print('Done')
294
-
295
-
296
- def main():
297
- parser = argparse.ArgumentParser()
298
- parser.add_argument('--model_name', type=str, default='spatial_bert-base')
299
- parser.add_argument('--out_dir', type=str, default=None)
300
- parser.add_argument('--map_type', type=str, default='usgs')
301
- parser.add_argument('--candset_mode', type=str, default='all_map')
302
-
303
- parser.add_argument('--usgs_distance_norm_factor', type=float, default = 1)
304
- parser.add_argument('--spatial_dist_fill', type=float, default = 100)
305
-
306
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
307
- parser.add_argument('--no_spatial_distance', default=False, action='store_true')
308
-
309
- parser.add_argument('--spatial_bert_weight_dir', type = str, default = None)
310
- parser.add_argument('--spatial_bert_weight_name', type = str, default = None)
311
-
312
- args = parser.parse_args()
313
- print('\n')
314
- print(args)
315
- print('\n')
316
-
317
- # out_dir not None, and out_dir does not exist, then create out_dir
318
- if args.out_dir is not None and not os.path.isdir(args.out_dir):
319
- os.makedirs(args.out_dir)
320
-
321
- entity_linking_func(args)
322
-
323
-
324
-
325
- if __name__ == '__main__':
326
-
327
- main()
328
-
329
-
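For reference, the core of disambiguify above reduces to ranking a pool of candidate embeddings against one query embedding by cosine similarity. A minimal self-contained sketch of that ranking step, using only numpy and scipy (the toy embeddings and URIs below are invented for illustration, not taken from the data):

import numpy as np
from scipy import spatial as sp

def rank_candidates(query_emb, cand_emb_list, cand_uri_list):
    # cosine similarity between the query and every candidate, shape (n_candidates, 1)
    sim = 1 - sp.distance.cdist(np.array(cand_emb_list), np.array([query_emb]), 'cosine')
    order = np.argsort(-sim[:, 0])  # indices sorted by descending similarity
    return [cand_uri_list[i] for i in order], sim[order, 0].tolist()

# toy usage: 'Q1' should come out first because it is closest to the query
uris, scores = rank_candidates([1.0, 0.0], [[0.9, 0.1], [0.0, 1.0]], ['Q1', 'Q2'])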
models/spabert/experiments/semantic_typing/__init__.py DELETED
File without changes
models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py DELETED
@@ -1,97 +0,0 @@
1
- import os
2
- import json
3
- import math
4
- import glob
5
- import re
6
- import pdb
7
-
8
- '''
9
- NO LONGER NEEDED
10
-
11
- Process the California, London, and Minnesota OSM data and prepare pseudo-sentences with spatial context
12
-
13
- Load the raw output files generated by SQL
14
- Unify the JSON by restructuring the dictionary
15
- Save the output into two files: one for training + validation, the other for testing
16
- '''
17
-
18
- region_list = ['california','london','minnesota']
19
-
20
- input_json_dir = '../data/sql_output/sub_files/'
21
- output_json_dir = '../data/sql_output/'
22
-
23
- for region_name in region_list:
24
- file_list = glob.glob(os.path.join(input_json_dir, 'spatialbert-osm-point-' + region_name + '*.json'))
25
- file_list = sorted(file_list)
26
- print('found %d files for region %s' % (len(file_list), region_name))
27
-
28
-
29
- num_test_files = int(math.ceil(len(file_list) * 0.2))
30
- num_train_val_files = len(file_list) - num_test_files
31
-
32
- print('%d files for train-val' % num_train_val_files)
33
- print('%d files for test' % num_test_files)
34
-
35
- train_val_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_train_val.json')
36
- test_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_test.json')
37
-
38
- # refresh the file
39
- with open(train_val_output_path, 'w') as f:
40
- pass
41
- with open(test_output_path, 'w') as f:
42
- pass
43
-
44
- for idx in range(len(file_list)):
45
-
46
- if idx < num_train_val_files:
47
- output_path = train_val_output_path
48
- else:
49
- output_path = test_output_path
50
-
51
- file_path = file_list[idx]
52
-
53
- print(file_path)
54
-
55
- with open(file_path, 'r') as f:
56
- data = f.readlines()
57
-
58
- line = data[0]
59
-
60
- line = re.sub(r'\n', '', line)
61
- line = re.sub(r'\\n', '', line)
62
- line = re.sub(r'\\+', '', line)
63
- line = re.sub(r'\+', '', line)
64
-
65
- line_dict_list = json.loads(line)
66
-
67
-
68
- for line_dict in line_dict_list:
69
-
70
- line_dict = line_dict['json_build_object']
71
-
72
- if not line_dict['name'][0].isalpha(): # discard the record if the first character is not a letter
73
- continue
74
-
75
- neighbor_name_list = line_dict['neighbor_info'][0]['name_list']
76
- neighbor_geom_list = line_dict['neighbor_info'][0]['geometry_list']
77
-
78
- assert(len(neighbor_name_list) == len(neighbor_geom_list)) # names and geometries must align
79
-
80
- temp_dict = \
81
- {'info':{'name':line_dict['name'],
82
- 'geometry':{'coordinates':line_dict['geometry']},
83
- 'class':line_dict['class']
84
- },
85
- 'neighbor_info':{'name_list': neighbor_name_list,
86
- 'geometry_list': neighbor_geom_list
87
- }
88
- }
89
-
90
- with open(output_path, 'a') as f:
91
- json.dump(temp_dict, f)
92
- f.write('\n')
93
-
94
-
95
-
96
-
97
-
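The essential step in merge_osm_json.py above is reshaping each raw json_build_object record into the unified {'info': ..., 'neighbor_info': ...} layout and appending it as one JSON object per line. A minimal sketch of that reshaping, assuming the raw record carries the same name, geometry, class, and neighbor_info fields used above:

import json

def reshape_record(raw):
    # raw is one element of line_dict_list, still wrapped in 'json_build_object'
    rec = raw['json_build_object']
    neighbors = rec['neighbor_info'][0]
    return {
        'info': {
            'name': rec['name'],
            'geometry': {'coordinates': rec['geometry']},
            'class': rec['class'],
        },
        'neighbor_info': {
            'name_list': neighbors['name_list'],
            'geometry_list': neighbors['geometry_list'],
        },
    }

def append_jsonl(path, record):
    # one JSON object per line, the format the dataset loaders read back
    with open(path, 'a') as f:
        json.dump(record, f)
        f.write('\n')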
models/spabert/experiments/semantic_typing/src/__init__.py DELETED
File without changes
models/spabert/experiments/semantic_typing/src/run_baseline_test.py DELETED
@@ -1,82 +0,0 @@
1
- import os
2
- import pdb
3
- import argparse
4
- import numpy as np
5
- import time
6
-
7
- MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
8
- 'spanbert-base','spanbert-large','luke-base','luke-large',
9
- 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
10
-
11
- def execute_command(command, if_print_command):
12
- t1 = time.time()
13
-
14
- if if_print_command:
15
- print(command)
16
- os.system(command)
17
-
18
- t2 = time.time()
19
- time_usage = t2 - t1
20
- return time_usage
21
-
22
- def run_test(args):
23
- weight_dir = args.weight_dir
24
- backbone_option = args.backbone_option
25
- gpu_id = str(args.gpu_id)
26
- if_print_command = args.print_command
27
- sep_between_neighbors = args.sep_between_neighbors
28
-
29
- assert backbone_option in MODEL_OPTIONS
30
-
31
- if sep_between_neighbors:
32
- sep_str = '_sep'
33
- else:
34
- sep_str = ''
35
-
36
- if 'large' in backbone_option:
37
- checkpoint_dir = os.path.join(weight_dir, 'typing_lr1e-06_%s_nofreeze%s_london_california_bsize12'% (backbone_option, sep_str))
38
- else:
39
- checkpoint_dir = os.path.join(weight_dir, 'typing_lr5e-05_%s_nofreeze%s_london_california_bsize12'% (backbone_option, sep_str))
40
- weight_files = os.listdir(checkpoint_dir)
41
-
42
- val_loss_list = [weight_file.split('_')[-1] for weight_file in weight_files]
43
- min_loss_weight = weight_files[np.argmin(val_loss_list)]
44
-
45
- checkpoint_path = os.path.join(checkpoint_dir, min_loss_weight)
46
-
47
- if sep_between_neighbors:
48
- command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --sep_between_neighbors --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
49
- else:
50
- command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
51
-
52
-
53
- execute_command(command, if_print_command)
54
-
55
-
56
-
57
-
58
- def main():
59
- parser = argparse.ArgumentParser()
60
-
61
- parser.add_argument('--weight_dir', type=str, default='/data2/zekun/spatial_bert_baseline_weights/')
62
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
63
- parser.add_argument('--backbone_option', type=str, default=None)
64
- parser.add_argument('--gpu_id', type=int, default=0) # GPU device id
65
-
66
- parser.add_argument('--print_command', default=False, action='store_true')
67
-
68
-
69
- args = parser.parse_args()
70
- print('\n')
71
- print(args)
72
- print('\n')
73
-
74
- run_test(args)
75
-
76
-
77
- if __name__ == '__main__':
78
-
79
- main()
80
-
81
-
82
-
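run_baseline_test.py above selects the checkpoint with the lowest validation loss by applying np.argmin to the raw filename suffixes, i.e. to strings. A slightly more defensive sketch parses the loss out of the filename as a float first; it assumes the ep{epoch}_iter{iter}_{train_loss}_val{val_loss}.pth naming used by the training scripts later in this commit:

import os
import numpy as np

def pick_min_val_loss_checkpoint(checkpoint_dir):
    weight_files = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
    # '..._val0.1234.pth' -> 0.1234; comparing floats avoids lexicographic surprises
    val_losses = [float(f.rsplit('_val', 1)[-1][:-len('.pth')]) for f in weight_files]
    return os.path.join(checkpoint_dir, weight_files[int(np.argmin(val_losses))])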
models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py DELETED
@@ -1,209 +0,0 @@
1
- import os
2
- import sys
3
- from transformers import RobertaTokenizer, BertTokenizer
4
- from tqdm import tqdm # for our progress bar
5
- from transformers import AdamW
6
-
7
- import torch
8
- from torch.utils.data import DataLoader
9
- import torch.nn.functional as F
10
-
11
- sys.path.append('../../../')
12
- from models.spatial_bert_model import SpatialBertModel
13
- from models.spatial_bert_model import SpatialBertConfig
14
- from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
15
- from datasets.osm_sample_loader import PbfMapDataset
16
- from datasets.const import *
17
- from transformers.models.bert.modeling_bert import BertForMaskedLM
18
-
19
- from sklearn.metrics import label_ranking_average_precision_score
20
- from sklearn.metrics import precision_recall_fscore_support
21
- import numpy as np
22
- import argparse
23
- from sklearn.preprocessing import LabelEncoder
24
- import pdb
25
-
26
-
27
- DEBUG = False
28
- torch.backends.cudnn.deterministic = True
29
- torch.backends.cudnn.benchmark = False
30
- torch.manual_seed(42)
31
- torch.cuda.manual_seed_all(42)
32
-
33
-
34
- def testing(args):
35
-
36
- max_token_len = args.max_token_len
37
- batch_size = args.batch_size
38
- num_workers = args.num_workers
39
- distance_norm_factor = args.distance_norm_factor
40
- spatial_dist_fill=args.spatial_dist_fill
41
- with_type = args.with_type
42
- sep_between_neighbors = args.sep_between_neighbors
43
- checkpoint_path = args.checkpoint_path
44
- if_no_spatial_distance = args.no_spatial_distance
45
-
46
- bert_option = args.bert_option
47
- num_neighbor_limit = args.num_neighbor_limit
48
-
49
-
50
- london_file_path = '../data/sql_output/osm-point-london-typing.json'
51
- california_file_path = '../data/sql_output/osm-point-california-typing.json'
52
-
53
-
54
- if bert_option == 'bert-base':
55
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
56
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(CLASS_9_LIST))
57
- elif bert_option == 'bert-large':
58
- tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
59
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(CLASS_9_LIST))
60
- else:
61
- raise NotImplementedError
62
-
63
-
64
- model = SpatialBertForSemanticTyping(config)
65
-
66
-
67
- label_encoder = LabelEncoder()
68
- label_encoder.fit(CLASS_9_LIST)
69
-
70
-
71
-
72
- london_dataset = PbfMapDataset(data_file_path = london_file_path,
73
- tokenizer = tokenizer,
74
- max_token_len = max_token_len,
75
- distance_norm_factor = distance_norm_factor,
76
- spatial_dist_fill = spatial_dist_fill,
77
- with_type = with_type,
78
- sep_between_neighbors = sep_between_neighbors,
79
- label_encoder = label_encoder,
80
- num_neighbor_limit = num_neighbor_limit,
81
- mode = 'test',)
82
-
83
- california_dataset = PbfMapDataset(data_file_path = california_file_path,
84
- tokenizer = tokenizer,
85
- max_token_len = max_token_len,
86
- distance_norm_factor = distance_norm_factor,
87
- spatial_dist_fill = spatial_dist_fill,
88
- with_type = with_type,
89
- sep_between_neighbors = sep_between_neighbors,
90
- label_encoder = label_encoder,
91
- num_neighbor_limit = num_neighbor_limit,
92
- mode = 'test')
93
-
94
- test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
95
-
96
-
97
-
98
- test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
99
- shuffle=False, pin_memory=True, drop_last=False)
100
-
101
-
102
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
103
- model.to(device)
104
-
105
-
106
- model.load_state_dict(torch.load(checkpoint_path)) # #
107
-
108
- model.eval()
109
-
110
-
111
-
112
- print('start testing...')
113
-
114
-
115
- # setup loop with TQDM and dataloader
116
- loop = tqdm(test_loader, leave=True)
117
-
118
-
119
- mrr_total = 0.
120
- prec_total = 0.
121
- sample_cnt = 0
122
-
123
- gt_list = []
124
- pred_list = []
125
-
126
- for batch in loop:
127
- # initialize calculated gradients (from prev step)
128
-
129
- # pull all tensor batches required for training
130
- input_ids = batch['pseudo_sentence'].to(device)
131
- attention_mask = batch['attention_mask'].to(device)
132
- position_list_x = batch['norm_lng_list'].to(device)
133
- position_list_y = batch['norm_lat_list'].to(device)
134
- sent_position_ids = batch['sent_position_ids'].to(device)
135
-
136
- #labels = batch['pseudo_sentence'].to(device)
137
- labels = batch['pivot_type'].to(device)
138
- pivot_lens = batch['pivot_token_len'].to(device)
139
-
140
- outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
141
- position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
142
-
143
-
144
- onehot_labels = F.one_hot(labels, num_classes=9)
145
-
146
- gt_list.extend(onehot_labels.cpu().detach().numpy())
147
- pred_list.extend(outputs.logits.cpu().detach().numpy())
148
-
149
- #pdb.set_trace()
150
- mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
151
- mrr_total += mrr * input_ids.shape[0]
152
- sample_cnt += input_ids.shape[0]
153
-
154
- precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
155
-
156
- precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
157
-
158
- # print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
159
- # print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
160
- # print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
161
- # print('supports:\n', supports)
162
- print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
163
-
164
-
165
-
166
- def main():
167
-
168
- parser = argparse.ArgumentParser()
169
-
170
- parser.add_argument('--max_token_len', type=int, default=300)
171
- parser.add_argument('--batch_size', type=int, default=12)
172
- parser.add_argument('--num_workers', type=int, default=5)
173
-
174
- parser.add_argument('--num_neighbor_limit', type=int, default = None)
175
-
176
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
177
- parser.add_argument('--spatial_dist_fill', type=float, default = 20)
178
-
179
-
180
- parser.add_argument('--with_type', default=False, action='store_true')
181
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
182
- parser.add_argument('--no_spatial_distance', default=False, action='store_true')
183
-
184
- parser.add_argument('--bert_option', type=str, default='bert-base')
185
- parser.add_argument('--prediction_save_dir', type=str, default=None)
186
-
187
- parser.add_argument('--checkpoint_path', type=str, default=None)
188
-
189
-
190
-
191
- args = parser.parse_args()
192
- print('\n')
193
- print(args)
194
- print('\n')
195
-
196
-
197
- # out_dir not None, and out_dir does not exist, then create out_dir
198
- if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
199
- os.makedirs(args.prediction_save_dir)
200
-
201
- testing(args)
202
-
203
-
204
-
205
- if __name__ == '__main__':
206
-
207
- main()
208
-
209
-
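The evaluation block above boils down to two sklearn calls: label_ranking_average_precision_score over the one-hot labels and raw logits (reported as "mrr" in these scripts), and precision_recall_fscore_support over the argmax of each. A condensed, self-contained sketch of that metric computation (the toy arrays are made up):

import numpy as np
from sklearn.metrics import label_ranking_average_precision_score, precision_recall_fscore_support

def typing_metrics(onehot_gt, logits):
    # onehot_gt and logits both have shape (n_samples, n_classes)
    mrr = label_ranking_average_precision_score(onehot_gt, logits)
    y_true = np.argmax(onehot_gt, axis=1)
    y_pred = np.argmax(logits, axis=1)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
    return mrr, precision, recall, f1

# toy check: 3 samples, 2 classes
gt = np.eye(2)[[0, 1, 1]]
logits = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
print(typing_metrics(gt, logits))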
models/spabert/experiments/semantic_typing/src/test_cls_baseline.py DELETED
@@ -1,189 +0,0 @@
1
- import os
2
- import sys
3
- from tqdm import tqdm # for our progress bar
4
- import numpy as np
5
- import argparse
6
- from sklearn.preprocessing import LabelEncoder
7
- import pdb
8
-
9
-
10
- import torch
11
- from torch.utils.data import DataLoader
12
- from transformers import AdamW
13
- import torch.nn.functional as F
14
-
15
- sys.path.append('../../../')
16
- from datasets.osm_sample_loader import PbfMapDataset
17
- from datasets.const import *
18
- from utils.baseline_utils import get_baseline_model
19
- from models.baseline_typing_model import BaselineForSemanticTyping
20
-
21
- from sklearn.metrics import label_ranking_average_precision_score
22
- from sklearn.metrics import precision_recall_fscore_support
23
-
24
- torch.backends.cudnn.deterministic = True
25
- torch.backends.cudnn.benchmark = False
26
- torch.manual_seed(42)
27
- torch.cuda.manual_seed_all(42)
28
-
29
-
30
-
31
- MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
32
- 'spanbert-base','spanbert-large','luke-base','luke-large',
33
- 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
34
-
35
- def testing(args):
36
-
37
- num_workers = args.num_workers
38
- batch_size = args.batch_size
39
- max_token_len = args.max_token_len
40
-
41
- distance_norm_factor = args.distance_norm_factor
42
- spatial_dist_fill=args.spatial_dist_fill
43
- with_type = args.with_type
44
- sep_between_neighbors = args.sep_between_neighbors
45
- freeze_backbone = args.freeze_backbone
46
-
47
-
48
- backbone_option = args.backbone_option
49
-
50
- checkpoint_path = args.checkpoint_path
51
-
52
- assert(backbone_option in MODEL_OPTIONS)
53
-
54
-
55
- london_file_path = '../data/sql_output/osm-point-london-typing.json'
56
- california_file_path = '../data/sql_output/osm-point-california-typing.json'
57
-
58
-
59
-
60
- label_encoder = LabelEncoder()
61
- label_encoder.fit(CLASS_9_LIST)
62
-
63
-
64
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
65
-
66
- backbone_model, tokenizer = get_baseline_model(backbone_option)
67
- model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))
68
-
69
- model.load_state_dict(torch.load(checkpoint_path) ) #, strict = False # load sentence position embedding weights as well
70
-
71
- model.to(device)
72
- model.train()
73
-
74
-
75
-
76
- london_dataset = PbfMapDataset(data_file_path = london_file_path,
77
- tokenizer = tokenizer,
78
- max_token_len = max_token_len,
79
- distance_norm_factor = distance_norm_factor,
80
- spatial_dist_fill = spatial_dist_fill,
81
- with_type = with_type,
82
- sep_between_neighbors = sep_between_neighbors,
83
- label_encoder = label_encoder,
84
- mode = 'test')
85
-
86
-
87
- california_dataset = PbfMapDataset(data_file_path = california_file_path,
88
- tokenizer = tokenizer,
89
- max_token_len = max_token_len,
90
- distance_norm_factor = distance_norm_factor,
91
- spatial_dist_fill = spatial_dist_fill,
92
- with_type = with_type,
93
- sep_between_neighbors = sep_between_neighbors,
94
- label_encoder = label_encoder,
95
- mode = 'test')
96
-
97
- test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
98
-
99
-
100
- test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
101
- shuffle=False, pin_memory=True, drop_last=False)
102
-
103
-
104
-
105
-
106
-
107
- print('start testing...')
108
-
109
- # setup loop with TQDM and dataloader
110
- loop = tqdm(test_loader, leave=True)
111
-
112
- mrr_total = 0.
113
- prec_total = 0.
114
- sample_cnt = 0
115
-
116
- gt_list = []
117
- pred_list = []
118
-
119
- for batch in loop:
120
- # initialize calculated gradients (from prev step)
121
-
122
- # pull all tensor batches required for training
123
- input_ids = batch['pseudo_sentence'].to(device)
124
- attention_mask = batch['attention_mask'].to(device)
125
- position_ids = batch['sent_position_ids'].to(device)
126
-
127
- #labels = batch['pseudo_sentence'].to(device)
128
- labels = batch['pivot_type'].to(device)
129
- pivot_lens = batch['pivot_token_len'].to(device)
130
-
131
- outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
132
- labels = labels, pivot_len_list = pivot_lens)
133
-
134
-
135
- onehot_labels = F.one_hot(labels, num_classes=9)
136
-
137
- gt_list.extend(onehot_labels.cpu().detach().numpy())
138
- pred_list.extend(outputs.logits.cpu().detach().numpy())
139
-
140
- mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
141
- mrr_total += mrr * input_ids.shape[0]
142
- sample_cnt += input_ids.shape[0]
143
-
144
- precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
145
- precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
146
- print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
147
- print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
148
- print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
149
- print('supports:\n', supports)
150
- print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
151
-
152
- #pdb.set_trace()
153
- #print(mrr_total/sample_cnt)
154
-
155
-
156
-
157
- def main():
158
-
159
- parser = argparse.ArgumentParser()
160
- parser.add_argument('--num_workers', type=int, default=5)
161
- parser.add_argument('--batch_size', type=int, default=12)
162
- parser.add_argument('--max_token_len', type=int, default=300)
163
-
164
-
165
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
166
- parser.add_argument('--spatial_dist_fill', type=float, default = 20)
167
-
168
- parser.add_argument('--with_type', default=False, action='store_true')
169
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
170
- parser.add_argument('--freeze_backbone', default=False, action='store_true')
171
-
172
- parser.add_argument('--backbone_option', type=str, default='bert-base')
173
- parser.add_argument('--checkpoint_path', type=str, default=None)
174
-
175
-
176
- args = parser.parse_args()
177
- print('\n')
178
- print(args)
179
- print('\n')
180
-
181
-
182
- testing(args)
183
-
184
-
185
- if __name__ == '__main__':
186
-
187
- main()
188
-
189
-
models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py DELETED
@@ -1,214 +0,0 @@
1
- import os
2
- import sys
3
- from transformers import RobertaTokenizer, BertTokenizer
4
- from tqdm import tqdm # for our progress bar
5
- from transformers import AdamW
6
-
7
- import torch
8
- from torch.utils.data import DataLoader
9
- import torch.nn.functional as F
10
-
11
- sys.path.append('../../../')
12
- from models.spatial_bert_model import SpatialBertModel
13
- from models.spatial_bert_model import SpatialBertConfig
14
- from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
15
- from datasets.osm_sample_loader import PbfMapDataset
16
- from datasets.const import *
17
- from transformers.models.bert.modeling_bert import BertForMaskedLM
18
-
19
- from sklearn.metrics import label_ranking_average_precision_score
20
- from sklearn.metrics import precision_recall_fscore_support
21
- import numpy as np
22
- import argparse
23
- from sklearn.preprocessing import LabelEncoder
24
- import pdb
25
-
26
-
27
- DEBUG = False
28
-
29
- torch.backends.cudnn.deterministic = True
30
- torch.backends.cudnn.benchmark = False
31
- torch.manual_seed(42)
32
- torch.cuda.manual_seed_all(42)
33
-
34
-
35
- def testing(args):
36
-
37
- max_token_len = args.max_token_len
38
- batch_size = args.batch_size
39
- num_workers = args.num_workers
40
- distance_norm_factor = args.distance_norm_factor
41
- spatial_dist_fill=args.spatial_dist_fill
42
- with_type = args.with_type
43
- sep_between_neighbors = args.sep_between_neighbors
44
- checkpoint_path = args.checkpoint_path
45
- if_no_spatial_distance = args.no_spatial_distance
46
-
47
- bert_option = args.bert_option
48
-
49
-
50
-
51
- if args.num_classes == 9:
52
- london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
53
- california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
54
- TYPE_LIST = CLASS_9_LIST
55
- type_key_str = 'class'
56
- elif args.num_classes == 74:
57
- london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
58
- california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
59
- TYPE_LIST = CLASS_74_LIST
60
- type_key_str = 'fine_class'
61
- else:
62
- raise NotImplementedError
63
-
64
- if bert_option == 'bert-base':
65
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
66
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
67
- elif bert_option == 'bert-large':
68
- tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
69
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24,num_semantic_types=len(TYPE_LIST))
70
- else:
71
- raise NotImplementedError
72
-
73
-
74
- model = SpatialBertForSemanticTyping(config)
75
-
76
-
77
- label_encoder = LabelEncoder()
78
- label_encoder.fit(TYPE_LIST)
79
-
80
- london_dataset = PbfMapDataset(data_file_path = london_file_path,
81
- tokenizer = tokenizer,
82
- max_token_len = max_token_len,
83
- distance_norm_factor = distance_norm_factor,
84
- spatial_dist_fill = spatial_dist_fill,
85
- with_type = with_type,
86
- type_key_str = type_key_str,
87
- sep_between_neighbors = sep_between_neighbors,
88
- label_encoder = label_encoder,
89
- mode = 'test')
90
-
91
- california_dataset = PbfMapDataset(data_file_path = california_file_path,
92
- tokenizer = tokenizer,
93
- max_token_len = max_token_len,
94
- distance_norm_factor = distance_norm_factor,
95
- spatial_dist_fill = spatial_dist_fill,
96
- with_type = with_type,
97
- type_key_str = type_key_str,
98
- sep_between_neighbors = sep_between_neighbors,
99
- label_encoder = label_encoder,
100
- mode = 'test')
101
-
102
- test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])
103
-
104
-
105
-
106
- test_loader = DataLoader(test_dataset, batch_size= batch_size, num_workers=num_workers,
107
- shuffle=False, pin_memory=True, drop_last=False)
108
-
109
-
110
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
111
- model.to(device)
112
-
113
- model.load_state_dict(torch.load(checkpoint_path))
114
-
115
- model.eval()
116
-
117
-
118
-
119
- print('start testing...')
120
-
121
-
122
- # setup loop with TQDM and dataloader
123
- loop = tqdm(test_loader, leave=True)
124
-
125
-
126
- mrr_total = 0.
127
- prec_total = 0.
128
- sample_cnt = 0
129
-
130
- gt_list = []
131
- pred_list = []
132
-
133
- for batch in loop:
134
- # initialize calculated gradients (from prev step)
135
-
136
- # pull all tensor batches required for training
137
- input_ids = batch['pseudo_sentence'].to(device)
138
- attention_mask = batch['attention_mask'].to(device)
139
- position_list_x = batch['norm_lng_list'].to(device)
140
- position_list_y = batch['norm_lat_list'].to(device)
141
- sent_position_ids = batch['sent_position_ids'].to(device)
142
-
143
- #labels = batch['pseudo_sentence'].to(device)
144
- labels = batch['pivot_type'].to(device)
145
- pivot_lens = batch['pivot_token_len'].to(device)
146
-
147
- outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
148
- position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
149
-
150
-
151
- onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST))
152
-
153
- gt_list.extend(onehot_labels.cpu().detach().numpy())
154
- pred_list.extend(outputs.logits.cpu().detach().numpy())
155
-
156
- mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
157
- mrr_total += mrr * input_ids.shape[0]
158
- sample_cnt += input_ids.shape[0]
159
-
160
- precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average=None)
161
- # print('precisions:\n', precisions)
162
- # print('recalls:\n', recalls)
163
- # print('fscores:\n', fscores)
164
- # print('supports:\n', supports)
165
- precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list),axis=1), np.argmax(np.array(pred_list),axis=1), average='micro')
166
- print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
167
- print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
168
- print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
169
- print('supports:\n', supports)
170
- print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))
171
-
172
-
173
-
174
-
175
- def main():
176
-
177
- parser = argparse.ArgumentParser()
178
-
179
- parser.add_argument('--max_token_len', type=int, default=512)
180
- parser.add_argument('--batch_size', type=int, default=12)
181
- parser.add_argument('--num_workers', type=int, default=5)
182
-
183
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
184
- parser.add_argument('--spatial_dist_fill', type=float, default = 100)
185
- parser.add_argument('--num_classes', type=int, default = 9)
186
-
187
- parser.add_argument('--with_type', default=False, action='store_true')
188
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
189
- parser.add_argument('--no_spatial_distance', default=False, action='store_true')
190
-
191
- parser.add_argument('--bert_option', type=str, default='bert-base')
192
- parser.add_argument('--prediction_save_dir', type=str, default=None)
193
-
194
- parser.add_argument('--checkpoint_path', type=str, default=None)
195
-
196
-
197
- args = parser.parse_args()
198
- print('\n')
199
- print(args)
200
- print('\n')
201
-
202
-
203
- # out_dir not None, and out_dir does not exist, then create out_dir
204
- if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
205
- os.makedirs(args.prediction_save_dir)
206
-
207
- testing(args)
208
-
209
-
210
- if __name__ == '__main__':
211
-
212
- main()
213
-
214
-
models/spabert/experiments/semantic_typing/src/train_cls_baseline.py DELETED
@@ -1,227 +0,0 @@
1
- import os
2
- import sys
3
- from tqdm import tqdm # for our progress bar
4
- import numpy as np
5
- import argparse
6
- from sklearn.preprocessing import LabelEncoder
7
- import pdb
8
-
9
-
10
- import torch
11
- from torch.utils.data import DataLoader
12
- from transformers import AdamW
13
-
14
- sys.path.append('../../../')
15
- from datasets.osm_sample_loader import PbfMapDataset
16
- from datasets.const import *
17
- from utils.baseline_utils import get_baseline_model
18
- from models.baseline_typing_model import BaselineForSemanticTyping
19
-
20
-
21
- MODEL_OPTIONS = ['bert-base','bert-large','roberta-base','roberta-large',
22
- 'spanbert-base','spanbert-large','luke-base','luke-large',
23
- 'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
24
-
25
- def training(args):
26
-
27
- num_workers = args.num_workers
28
- batch_size = args.batch_size
29
- epochs = args.epochs
30
- lr = args.lr #1e-7 # 5e-5
31
- save_interval = args.save_interval
32
- max_token_len = args.max_token_len
33
- distance_norm_factor = args.distance_norm_factor
34
- spatial_dist_fill=args.spatial_dist_fill
35
- with_type = args.with_type
36
- sep_between_neighbors = args.sep_between_neighbors
37
- freeze_backbone = args.freeze_backbone
38
-
39
-
40
- backbone_option = args.backbone_option
41
-
42
- assert(backbone_option in MODEL_OPTIONS)
43
-
44
-
45
- london_file_path = '../data/sql_output/osm-point-london-typing.json'
46
- california_file_path = '../data/sql_output/osm-point-california-typing.json'
47
-
48
- if args.model_save_dir is None:
49
- freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
50
- sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
51
- model_save_dir = '/data2/zekun/spatial_bert_baseline_weights/typing_lr' + str("{:.0e}".format(lr)) +'_'+backbone_option+ freeze_pathstr + sep_pathstr + '_london_california_bsize' + str(batch_size)
52
-
53
- if not os.path.isdir(model_save_dir):
54
- os.makedirs(model_save_dir)
55
- else:
56
- model_save_dir = args.model_save_dir
57
-
58
- print('model_save_dir', model_save_dir)
59
- print('\n')
60
-
61
-
62
- label_encoder = LabelEncoder()
63
- label_encoder.fit(CLASS_9_LIST)
64
-
65
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
66
-
67
- backbone_model, tokenizer = get_baseline_model(backbone_option)
68
- model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))
69
-
70
- model.to(device)
71
- model.train()
72
-
73
- london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
74
- tokenizer = tokenizer,
75
- max_token_len = max_token_len,
76
- distance_norm_factor = distance_norm_factor,
77
- spatial_dist_fill = spatial_dist_fill,
78
- with_type = with_type,
79
- sep_between_neighbors = sep_between_neighbors,
80
- label_encoder = label_encoder,
81
- mode = 'train')
82
-
83
- percent_80 = int(len(london_train_val_dataset) * 0.8)
84
- london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])
85
-
86
- california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
87
- tokenizer = tokenizer,
88
- max_token_len = max_token_len,
89
- distance_norm_factor = distance_norm_factor,
90
- spatial_dist_fill = spatial_dist_fill,
91
- with_type = with_type,
92
- sep_between_neighbors = sep_between_neighbors,
93
- label_encoder = label_encoder,
94
- mode = 'train')
95
- percent_80 = int(len(california_train_val_dataset) * 0.8)
96
- california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])
97
-
98
- train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
99
- val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])
100
-
101
-
102
-
103
- train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
104
- shuffle=True, pin_memory=True, drop_last=True)
105
- val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
106
- shuffle=False, pin_memory=True, drop_last=False)
107
-
108
-
109
-
110
-
111
-
112
-
113
-
114
- # initialize optimizer
115
- optim = AdamW(model.parameters(), lr = lr)
116
-
117
- print('start training...')
118
-
119
- for epoch in range(epochs):
120
- # setup loop with TQDM and dataloader
121
- loop = tqdm(train_loader, leave=True)
122
- iter = 0
123
- for batch in loop:
124
- # initialize calculated gradients (from prev step)
125
- optim.zero_grad()
126
- # pull all tensor batches required for training
127
- input_ids = batch['pseudo_sentence'].to(device)
128
- attention_mask = batch['attention_mask'].to(device)
129
- position_ids = batch['sent_position_ids'].to(device)
130
-
131
- #labels = batch['pseudo_sentence'].to(device)
132
- labels = batch['pivot_type'].to(device)
133
- pivot_lens = batch['pivot_token_len'].to(device)
134
-
135
- outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
136
- labels = labels, pivot_len_list = pivot_lens)
137
-
138
-
139
- loss = outputs.loss
140
- loss.backward()
141
- optim.step()
142
-
143
- loop.set_description(f'Epoch {epoch}')
144
- loop.set_postfix({'loss':loss.item()})
145
-
146
-
147
- iter += 1
148
-
149
- if iter % save_interval == 0 or iter == loop.total:
150
- loss_valid = validating(val_loader, model, device)
151
- print('validation loss', loss_valid)
152
-
153
- save_path = os.path.join(model_save_dir, 'ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \
154
- + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' )
155
-
156
- torch.save(model.state_dict(), save_path)
157
- print('saving model checkpoint to', save_path)
158
-
159
-
160
-
161
- def validating(val_loader, model, device):
162
-
163
- with torch.no_grad():
164
-
165
- loss_valid = 0
166
- loop = tqdm(val_loader, leave=True)
167
-
168
- for batch in loop:
169
- input_ids = batch['pseudo_sentence'].to(device)
170
- attention_mask = batch['attention_mask'].to(device)
171
- position_ids = batch['sent_position_ids'].to(device)
172
-
173
- labels = batch['pivot_type'].to(device)
174
- pivot_lens = batch['pivot_token_len'].to(device)
175
-
176
- outputs = model(input_ids, attention_mask = attention_mask, position_ids = position_ids,
177
- labels = labels, pivot_len_list = pivot_lens)
178
-
179
- loss_valid += outputs.loss
180
-
181
- loss_valid /= len(val_loader)
182
-
183
- return loss_valid
184
-
185
-
186
- def main():
187
-
188
- parser = argparse.ArgumentParser()
189
- parser.add_argument('--num_workers', type=int, default=5)
190
- parser.add_argument('--batch_size', type=int, default=12)
191
- parser.add_argument('--epochs', type=int, default=10)
192
- parser.add_argument('--save_interval', type=int, default=2000)
193
- parser.add_argument('--max_token_len', type=int, default=300)
194
-
195
-
196
- parser.add_argument('--lr', type=float, default = 5e-5)
197
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
198
- parser.add_argument('--spatial_dist_fill', type=float, default = 20)
199
-
200
- parser.add_argument('--with_type', default=False, action='store_true')
201
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
202
- parser.add_argument('--freeze_backbone', default=False, action='store_true')
203
-
204
- parser.add_argument('--backbone_option', type=str, default='bert-base')
205
- parser.add_argument('--model_save_dir', type=str, default=None)
206
-
207
-
208
-
209
- args = parser.parse_args()
210
- print('\n')
211
- print(args)
212
- print('\n')
213
-
214
-
215
- # out_dir not None, and out_dir does not exist, then create out_dir
216
- if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
217
- os.makedirs(args.model_save_dir)
218
-
219
- training(args)
220
-
221
-
222
-
223
- if __name__ == '__main__':
224
-
225
- main()
226
-
227
-
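The trainer above assembles its data by splitting each regional dataset 80/20 into train and validation and then concatenating the regional splits; the SpaBERT trainer below does the same. A minimal sketch of that pattern with generic torch datasets (TensorDataset stands in for PbfMapDataset here):

import torch
from torch.utils.data import TensorDataset, ConcatDataset, DataLoader, random_split

def split_80_20(dataset):
    n_train = int(len(dataset) * 0.8)
    return random_split(dataset, [n_train, len(dataset) - n_train])

# stand-ins for the London and California PbfMapDataset instances
london = TensorDataset(torch.randn(100, 4))
california = TensorDataset(torch.randn(80, 4))

london_train, london_val = split_80_20(london)
california_train, california_val = split_80_20(california)

train_dataset = ConcatDataset([london_train, california_train])
val_dataset = ConcatDataset([london_val, california_val])
train_loader = DataLoader(train_dataset, batch_size=12, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=12, shuffle=False, drop_last=False)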
models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py DELETED
@@ -1,276 +0,0 @@
1
- import os
2
- import sys
3
- from transformers import RobertaTokenizer, BertTokenizer
4
- from tqdm import tqdm # for our progress bar
5
- from transformers import AdamW
6
-
7
- import torch
8
- from torch.utils.data import DataLoader
9
-
10
- sys.path.append('../../../')
11
- from models.spatial_bert_model import SpatialBertModel
12
- from models.spatial_bert_model import SpatialBertConfig
13
- from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
14
- from datasets.osm_sample_loader import PbfMapDataset
15
- from datasets.const import *
16
-
17
- from transformers.models.bert.modeling_bert import BertForMaskedLM
18
-
19
- import numpy as np
20
- import argparse
21
- from sklearn.preprocessing import LabelEncoder
22
- import pdb
23
-
24
-
25
- DEBUG = False
26
-
27
-
28
- def training(args):
29
-
30
- num_workers = args.num_workers
31
- batch_size = args.batch_size
32
- epochs = args.epochs
33
- lr = args.lr #1e-7 # 5e-5
34
- save_interval = args.save_interval
35
- max_token_len = args.max_token_len
36
- distance_norm_factor = args.distance_norm_factor
37
- spatial_dist_fill=args.spatial_dist_fill
38
- with_type = args.with_type
39
- sep_between_neighbors = args.sep_between_neighbors
40
- freeze_backbone = args.freeze_backbone
41
- mlm_checkpoint_path = args.mlm_checkpoint_path
42
-
43
- if_no_spatial_distance = args.no_spatial_distance
44
-
45
-
46
- bert_option = args.bert_option
47
-
48
- assert bert_option in ['bert-base','bert-large']
49
-
50
- if args.num_classes == 9:
51
- london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
52
- california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
53
- TYPE_LIST = CLASS_9_LIST
54
- type_key_str = 'class'
55
- elif args.num_classes == 74:
56
- london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
57
- california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
58
- TYPE_LIST = CLASS_74_LIST
59
- type_key_str = 'fine_class'
60
- else:
61
- raise NotImplementedError
62
-
63
-
64
- if args.model_save_dir is None:
65
- checkpoint_basename = os.path.basename(mlm_checkpoint_path)
66
- checkpoint_prefix = checkpoint_basename.replace("mlm_mem_keeppos_","").replace('.pth','') # strip('.pth') would trim characters, not the suffix
67
-
68
- sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
69
- freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
70
- if if_no_spatial_distance:
71
- model_save_dir = '/data2/zekun/spatial_bert_weights_ablation/'
72
- else:
73
- model_save_dir = '/data2/zekun/spatial_bert_weights/'
74
- model_save_dir = os.path.join(model_save_dir, 'typing_lr' + str("{:.0e}".format(lr)) + sep_pathstr +'_'+bert_option+ freeze_pathstr + '_london_california_bsize' + str(batch_size) )
75
- model_save_dir = os.path.join(model_save_dir, checkpoint_prefix)
76
-
77
- if not os.path.isdir(model_save_dir):
78
- os.makedirs(model_save_dir)
79
- else:
80
- model_save_dir = args.model_save_dir
81
-
82
-
83
- print('model_save_dir', model_save_dir)
84
- print('\n')
85
-
86
- if bert_option == 'bert-base':
87
- tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
88
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
89
- elif bert_option == 'bert-large':
90
- tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
91
- config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(TYPE_LIST))
92
- else:
93
- raise NotImplementedError
94
-
95
-
96
-
97
- label_encoder = LabelEncoder()
98
- label_encoder.fit(TYPE_LIST)
99
-
100
-
101
- london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
102
- tokenizer = tokenizer,
103
- max_token_len = max_token_len,
104
- distance_norm_factor = distance_norm_factor,
105
- spatial_dist_fill = spatial_dist_fill,
106
- with_type = with_type,
107
- type_key_str = type_key_str,
108
- sep_between_neighbors = sep_between_neighbors,
109
- label_encoder = label_encoder,
110
- mode = 'train')
111
-
112
- percent_80 = int(len(london_train_val_dataset) * 0.8)
113
- london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])
114
-
115
- california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
116
- tokenizer = tokenizer,
117
- max_token_len = max_token_len,
118
- distance_norm_factor = distance_norm_factor,
119
- spatial_dist_fill = spatial_dist_fill,
120
- with_type = with_type,
121
- type_key_str = type_key_str,
122
- sep_between_neighbors = sep_between_neighbors,
123
- label_encoder = label_encoder,
124
- mode = 'train')
125
- percent_80 = int(len(california_train_val_dataset) * 0.8)
126
- california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])
127
-
128
- train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
129
- val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])
130
-
131
-
132
- if DEBUG:
133
- train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
134
- shuffle=False, pin_memory=True, drop_last=True)
135
- val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
136
- shuffle=False, pin_memory=True, drop_last=False)
137
- else:
138
- train_loader = DataLoader(train_dataset, batch_size= batch_size, num_workers=num_workers,
139
- shuffle=True, pin_memory=True, drop_last=True)
140
- val_loader = DataLoader(val_dataset, batch_size= batch_size, num_workers=num_workers,
141
- shuffle=False, pin_memory=True, drop_last=False)
142
-
143
-
144
- device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
145
-
146
- model = SpatialBertForSemanticTyping(config)
147
- model.to(device)
148
-
149
-
150
- model.load_state_dict(torch.load(mlm_checkpoint_path), strict = False)
151
-
152
- model.train()
153
-
154
-
155
-
156
- # initialize optimizer
157
- optim = AdamW(model.parameters(), lr = lr)
158
-
159
- print('start training...')
160
-
161
- for epoch in range(epochs):
162
- # setup loop with TQDM and dataloader
163
- loop = tqdm(train_loader, leave=True)
164
- iter = 0
165
- for batch in loop:
166
- # initialize calculated gradients (from prev step)
167
- optim.zero_grad()
168
- # pull all tensor batches required for training
169
- input_ids = batch['pseudo_sentence'].to(device)
170
- attention_mask = batch['attention_mask'].to(device)
171
- position_list_x = batch['norm_lng_list'].to(device)
172
- position_list_y = batch['norm_lat_list'].to(device)
173
- sent_position_ids = batch['sent_position_ids'].to(device)
174
-
175
- #labels = batch['pseudo_sentence'].to(device)
176
- labels = batch['pivot_type'].to(device)
177
- pivot_lens = batch['pivot_token_len'].to(device)
178
-
179
- outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
180
- position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
181
-
182
-
183
- loss = outputs.loss
184
- loss.backward()
185
- optim.step()
186
-
187
- loop.set_description(f'Epoch {epoch}')
188
- loop.set_postfix({'loss':loss.item()})
189
-
190
- if DEBUG:
191
- print('ep'+str(epoch)+'_' + '_iter'+ str(iter).zfill(5), loss.item() )
192
-
193
- iter += 1
194
-
195
- if iter % save_interval == 0 or iter == loop.total:
196
- loss_valid = validating(val_loader, model, device)
197
-
198
- save_path = os.path.join(model_save_dir, 'keeppos_ep'+str(epoch) + '_iter'+ str(iter).zfill(5) \
199
- + '_' +str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) +'.pth' )
200
-
201
- torch.save(model.state_dict(), save_path)
202
- print('validation loss', loss_valid)
203
- print('saving model checkpoint to', save_path)
204
-
205
- def validating(val_loader, model, device):
206
-
207
- with torch.no_grad():
208
-
209
- loss_valid = 0
210
- loop = tqdm(val_loader, leave=True)
211
-
212
- for batch in loop:
213
- input_ids = batch['pseudo_sentence'].to(device)
214
- attention_mask = batch['attention_mask'].to(device)
215
- position_list_x = batch['norm_lng_list'].to(device)
216
- position_list_y = batch['norm_lat_list'].to(device)
217
- sent_position_ids = batch['sent_position_ids'].to(device)
218
-
219
-
220
- labels = batch['pivot_type'].to(device)
221
- pivot_lens = batch['pivot_token_len'].to(device)
222
-
223
- outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
224
- position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
225
-
226
- loss_valid += outputs.loss
227
-
228
- loss_valid /= len(val_loader)
229
-
230
- return loss_valid
231
-
232
-
233
- def main():
234
-
235
- parser = argparse.ArgumentParser()
236
- parser.add_argument('--num_workers', type=int, default=5)
237
- parser.add_argument('--batch_size', type=int, default=12)
238
- parser.add_argument('--epochs', type=int, default=10)
239
- parser.add_argument('--save_interval', type=int, default=2000)
240
- parser.add_argument('--max_token_len', type=int, default=512)
241
-
242
-
243
- parser.add_argument('--lr', type=float, default = 5e-5)
244
- parser.add_argument('--distance_norm_factor', type=float, default = 0.0001)
245
- parser.add_argument('--spatial_dist_fill', type=float, default = 100)
246
- parser.add_argument('--num_classes', type=int, default = 9)
247
-
248
- parser.add_argument('--with_type', default=False, action='store_true')
249
- parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
250
- parser.add_argument('--freeze_backbone', default=False, action='store_true')
251
- parser.add_argument('--no_spatial_distance', default=False, action='store_true')
252
-
253
- parser.add_argument('--bert_option', type=str, default='bert-base')
254
- parser.add_argument('--model_save_dir', type=str, default=None)
255
-
256
- parser.add_argument('--mlm_checkpoint_path', type=str, default=None)
257
-
258
-
259
- args = parser.parse_args()
260
- print('\n')
261
- print(args)
262
- print('\n')
263
-
264
-
265
- # out_dir not None, and out_dir does not exist, then create out_dir
266
- if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
267
- os.makedirs(args.model_save_dir)
268
-
269
- training(args)
270
-
271
-
272
- if __name__ == '__main__':
273
-
274
- main()
275
-
276
-
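The checkpoint filenames written by both trainers encode epoch, iteration, training loss, and validation loss, which is exactly what the test helpers parse back out. A minimal sketch of that save step, assuming a model and the two loss values are already in hand:

import os
import torch

def save_checkpoint(model, model_save_dir, epoch, iteration, train_loss, val_loss, prefix='keeppos_'):
    # e.g. keeppos_ep3_iter02000_0.4213_val0.3987.pth
    fname = '%sep%d_iter%05d_%.4f_val%.4f.pth' % (prefix, epoch, iteration, train_loss, val_loss)
    save_path = os.path.join(model_save_dir, fname)
    torch.save(model.state_dict(), save_path)
    return save_path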