Delete models/spabert/experiments
- models/spabert/experiments/__init__.py +0 -0
- models/spabert/experiments/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/__init__.py +0 -0
- models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__init__.py +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc +0 -0
- models/spabert/experiments/entity_matching/data_processing/get_namelist.py +0 -95
- models/spabert/experiments/entity_matching/data_processing/request_wrapper.py +0 -186
- models/spabert/experiments/entity_matching/data_processing/run_linking_query.py +0 -143
- models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py +0 -123
- models/spabert/experiments/entity_matching/data_processing/run_query_sample.py +0 -22
- models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py +0 -31
- models/spabert/experiments/entity_matching/data_processing/samples.sparql +0 -22
- models/spabert/experiments/entity_matching/data_processing/select_ambi.py +0 -18
- models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json +0 -0
- models/spabert/experiments/entity_matching/src/evaluation-mrr.py +0 -260
- models/spabert/experiments/entity_matching/src/linking_ablation.py +0 -228
- models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py +0 -329
- models/spabert/experiments/semantic_typing/__init__.py +0 -0
- models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py +0 -97
- models/spabert/experiments/semantic_typing/src/__init__.py +0 -0
- models/spabert/experiments/semantic_typing/src/run_baseline_test.py +0 -82
- models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py +0 -209
- models/spabert/experiments/semantic_typing/src/test_cls_baseline.py +0 -189
- models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py +0 -214
- models/spabert/experiments/semantic_typing/src/train_cls_baseline.py +0 -227
- models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py +0 -276
models/spabert/experiments/__init__.py
DELETED
File without changes

models/spabert/experiments/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (155 Bytes)

models/spabert/experiments/entity_matching/__init__.py
DELETED
File without changes

models/spabert/experiments/entity_matching/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (171 Bytes)

models/spabert/experiments/entity_matching/data_processing/__init__.py
DELETED
File without changes

models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-310.pyc
DELETED
Binary file (187 Bytes)

models/spabert/experiments/entity_matching/data_processing/__pycache__/__init__.cpython-311.pyc
DELETED
Binary file (204 Bytes)

models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-310.pyc
DELETED
Binary file (5.97 kB)

models/spabert/experiments/entity_matching/data_processing/__pycache__/request_wrapper.cpython-311.pyc
DELETED
Binary file (7.83 kB)
models/spabert/experiments/entity_matching/data_processing/get_namelist.py
DELETED
@@ -1,95 +0,0 @@
import json
import os

def get_name_list_osm(ref_paths):
    name_list = []

    for json_path in ref_paths:
        with open(json_path, 'r') as f:
            data = f.readlines()
        for line in data:
            record = json.loads(line)
            name = record['name']
            name_list.append(name)

    name_list = sorted(name_list)
    return name_list

# deprecated
def get_name_list_usgs_od(ref_paths):
    name_list = []

    for json_path in ref_paths:
        with open(json_path, 'r') as f:
            annot_dict = json.load(f)
        for key, place in annot_dict.items():
            place_name = ''
            for idx in range(1, len(place) + 1):
                try:
                    place_name += place[str(idx)]['text_label']
                    place_name += ' '  # separate words with spaces
                except Exception as e:
                    print(place)
            place_name = place_name[:-1]  # remove last space

            name_list.append(place_name)

    name_list = sorted(name_list)
    return name_list

def get_name_list_usgs_od_per_map(ref_paths):
    all_name_list_dict = dict()

    for json_path in ref_paths:
        map_name = os.path.basename(json_path).split('.json')[0]

        with open(json_path, 'r') as f:
            annot_dict = json.load(f)

        map_name_list = []
        for key, place in annot_dict.items():
            place_name = ''
            for idx in range(1, len(place) + 1):
                try:
                    place_name += place[str(idx)]['text_label']
                    place_name += ' '  # separate words with spaces
                except Exception as e:
                    print(place)
            place_name = place_name[:-1]  # remove last space

            map_name_list.append(place_name)
        all_name_list_dict[map_name] = sorted(map_name_list)

    return all_name_list_dict


def get_name_list_gb1900(ref_path):
    name_list = []

    with open(ref_path, 'r', encoding='utf-16') as f:
        data = f.readlines()

    for line in data[1:]:  # skip the header
        try:
            line = line.split(',')
            text = line[1]
            lat = float(line[-3])
            lng = float(line[-2])
            semantic_type = line[-1]

            name_list.append(text)
        except:
            print(line)

    name_list = sorted(name_list)

    return name_list


if __name__ == '__main__':
    #name_list = get_name_list_usgs_od(['labGISReport-master/output/USGS-15-CA-brawley-e1957-s1957-p1961.json',
    #'labGISReport-master/output/USGS-15-CA-capesanmartin-e1921-s1917.json'])
    name_list = get_name_list_gb1900('data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv')
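
For reference, get_name_list_osm expects a JSON-lines file in which every line is one record carrying a 'name' field. A minimal sketch of the input format and call; the file name below is hypothetical:

# embedding_sample.json (hypothetical), one JSON record per line:
#   {"name": "Hyde Park"}
#   {"name": "Abbey Road"}
names = get_name_list_osm(['embedding_sample.json'])
print(names)  # ['Abbey Road', 'Hyde Park'] -- returned in sorted order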
models/spabert/experiments/entity_matching/data_processing/request_wrapper.py
DELETED
@@ -1,186 +0,0 @@
import requests
import pdb
import time

# for linkedgeodata: #http://linkedgeodata.org/sparql

class RequestWrapper:
    def __init__(self, baseuri = "https://query.wikidata.org/sparql"):

        self.baseuri = baseuri

    def response_handler(self, response, query):
        if response.status_code == requests.codes.ok:
            ret_json = response.json()['results']['bindings']
        elif response.status_code == 500:
            ret_json = []
            #print(q_id)
            print('Internal Error happened. Set ret_json to be empty list')

        elif response.status_code == 429:

            print(response.status_code)
            print(response.text)
            retry_seconds = int(response.text.split('Too Many Requests - Please retry in ')[1].split(' seconds')[0])
            print('rerun in %d seconds' % retry_seconds)
            time.sleep(retry_seconds + 1)

            response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})
            ret_json = response.json()['results']['bindings']
            #print(ret_json)
            print('resumed and succeeded')

        else:
            print(response.status_code, response.text)
            exit(-1)

        return ret_json

    '''Search for wikidata entities given the name string'''
    def wikidata_query(self, name_str):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?item ?coordinates ?itemDescription WHERE {
            ?item rdfs:label \"%s\"@en;
                  wdt:P625 ?coordinates .
            SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
        }
        """ % (name_str)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    '''Search for wikidata entities given the name string, restricted to entities located within a state (default Q99)'''
    def wikidata_query_withinstate(self, name_str, state_id = 'Q99'):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?item ?coordinates ?itemDescription WHERE {
            ?item rdfs:label \"%s\"@en;
                  wdt:P625 ?coordinates ;
                  wdt:P131+ wd:%s;
            SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
        }
        """ % (name_str, state_id)

        #print(query)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    '''Search for nearby wikidata entities given the entity id'''
    def wikidata_nearby_query(self, q_id):

        query = """
        PREFIX wd: <http://www.wikidata.org/entity/>
        PREFIX wds: <http://www.wikidata.org/entity/statement/>
        PREFIX wdv: <http://www.wikidata.org/value/>
        PREFIX wdt: <http://www.wikidata.org/prop/direct/>
        PREFIX wikibase: <http://wikiba.se/ontology#>
        PREFIX p: <http://www.wikidata.org/prop/>
        PREFIX ps: <http://www.wikidata.org/prop/statement/>
        PREFIX pq: <http://www.wikidata.org/prop/qualifier/>
        PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#>
        PREFIX bd: <http://www.bigdata.com/rdf#>

        SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
        WHERE
        {
            wd:%s wdt:P625 ?loc .
            SERVICE wikibase:around {
                ?place wdt:P625 ?location .
                bd:serviceParam wikibase:center ?loc .
                bd:serviceParam wikibase:radius "5" .
            }
            OPTIONAL { ?place wdt:P31 ?instance }
            SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
            BIND(geof:distance(?loc, ?location) as ?dist)
        } ORDER BY ?dist
        LIMIT 200
        """ % (q_id)
        # initially 2km

        #pdb.set_trace()

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


    def linkedgeodata_query(self, name_str):

        query = """
        Prefix lgdo: <http://linkedgeodata.org/ontology/>
        Prefix geom: <http://geovocab.org/geometry#>
        Prefix ogc: <http://www.opengis.net/ont/geosparql#>
        Prefix owl: <http://www.w3.org/2002/07/owl#>
        Prefix wgs84_pos: <http://www.w3.org/2003/01/geo/wgs84_pos#>
        Prefix gn: <http://www.geonames.org/ontology#>

        Select ?s, ?lat, ?long {
            {?s rdfs:label \"%s\";
                wgs84_pos:lat ?lat ;
                wgs84_pos:long ?long;
            }
        }
        """ % (name_str)

        response = requests.get(self.baseuri, params = {'format': 'json', 'query': query})

        ret_json = self.response_handler(response, query)

        return ret_json


if __name__ == '__main__':
    request_wrapper_wikidata = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
    #print(request_wrapper_wikidata.wikidata_nearby_query('Q370771'))
    #print(request_wrapper_wikidata.wikidata_query_withinstate('San Bernardino'))

    # not working now
    print(request_wrapper_wikidata.linkedgeodata_query('San Bernardino'))
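
A short usage sketch for the wrapper, assuming the endpoint returns standard SPARQL JSON results with the 'item', 'coordinates', and 'itemDescription' bindings selected above; the place name is purely illustrative:

wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
bindings = wrapper.wikidata_query('Los Angeles')

for b in bindings:
    qid = b['item']['value'].split('/')[-1]    # entity id, e.g. 'Q65'
    wkt = b['coordinates']['value']            # WKT literal, e.g. 'Point(-118.24 34.05)'
    desc = b.get('itemDescription', {}).get('value', '')  # description binding is optional
    print(qid, wkt, desc)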
models/spabert/experiments/entity_matching/data_processing/run_linking_query.py
DELETED
@@ -1,143 +0,0 @@
#from query_wrapper import QueryWrapper
from request_wrapper import RequestWrapper
from get_namelist import *
import glob
import os
import json
import time

DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
KB_OPTIONS = ['wikidata', 'linkedgeodata']

DATASET = 'USGS'
KB = 'wikidata'
OVERWRITE = True
WITHIN_CA = True

assert DATASET in DATASET_OPTIONS
assert KB in KB_OPTIONS


def process_one_namelist(sparql_wrapper, namelist, out_path):

    if OVERWRITE:
        # flush the file if it's been written
        with open(out_path, 'w') as f:
            f.write('')

    for name in namelist:
        name = name.replace('"', '')
        name = name.strip("'")
        if len(name) == 0:
            continue
        print(name)
        mydict = dict()

        if KB == 'wikidata':
            if WITHIN_CA:
                mydict[name] = sparql_wrapper.wikidata_query_withinstate(name)
            else:
                mydict[name] = sparql_wrapper.wikidata_query(name)

        elif KB == 'linkedgeodata':
            mydict[name] = sparql_wrapper.linkedgeodata_query(name)
        else:
            raise NotImplementedError

        line = json.dumps(mydict)

        with open(out_path, 'a') as f:
            f.write(line)
            f.write('\n')
        time.sleep(1)


def process_namelist_dict(sparql_wrapper, namelist_dict, out_dir):
    i = 0
    for map_name, namelist in namelist_dict.items():
        # if i <= 5:
        #     i += 1
        #     continue

        print('processing %s' % map_name)

        if WITHIN_CA:
            out_path = os.path.join(out_dir, KB + '_' + map_name + '.json')
        else:
            out_path = os.path.join(out_dir, KB + '_ca_' + map_name + '.json')

        process_one_namelist(sparql_wrapper, namelist, out_path)
        i += 1


if KB == 'linkedgeodata':
    sparql_wrapper = RequestWrapper(baseuri = 'http://linkedgeodata.org/sparql')
elif KB == 'wikidata':
    sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
else:
    raise NotImplementedError


if DATASET == 'OSM':
    osm_dir = '../surface_form/data_sample_london/data_osm/'
    osm_paths = glob.glob(os.path.join(osm_dir, 'embedding*.json'))

    out_path = 'outputs/' + KB + '_linking.json'
    namelist = get_name_list_osm(osm_paths)

    print('# files', len(osm_paths))

    process_one_namelist(sparql_wrapper, namelist, out_path)


elif DATASET == 'OS':
    histmap_dir = 'data/labGISReport-master/output/'
    file_paths = glob.glob(os.path.join(histmap_dir, '10*.json'))

    out_path = 'outputs/' + KB + '_os_linking_descript.json'
    namelist = get_name_list_usgs_od(file_paths)

    print('# files', len(file_paths))

    process_one_namelist(sparql_wrapper, namelist, out_path)

elif DATASET == 'USGS':
    histmap_dir = 'data/labGISReport-master/output/'
    file_paths = glob.glob(os.path.join(histmap_dir, 'USGS*.json'))

    if WITHIN_CA:
        out_dir = 'outputs/' + KB + '_ca'
    else:
        out_dir = 'outputs/' + KB
    namelist_dict = get_name_list_usgs_od_per_map(file_paths)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    print('# files', len(file_paths))

    process_namelist_dict(sparql_wrapper, namelist_dict, out_dir)

elif DATASET == 'GB1900':

    file_path = 'data/GB1900_gazetteer_abridged_july_2018/gb1900_abridged.csv'
    out_path = 'outputs/' + KB + '_gb1900_linking_descript.json'
    namelist = get_name_list_gb1900(file_path)

    process_one_namelist(sparql_wrapper, namelist, out_path)

else:
    raise NotImplementedError


#namelist = namelist[730:] #for GB1900

print('done')
models/spabert/experiments/entity_matching/data_processing/run_map_neighbor_query.py
DELETED
@@ -1,123 +0,0 @@
from query_wrapper import QueryWrapper
from request_wrapper import RequestWrapper
import glob
import os
import json
import time
import random

DATASET_OPTIONS = ['OSM', 'OS', 'USGS', 'GB1900']
KB_OPTIONS = ['wikidata', 'linkedgeodata']

dataset = 'USGS'
kb = 'wikidata'
overwrite = False

assert dataset in DATASET_OPTIONS
assert kb in KB_OPTIONS

if dataset == 'OSM':
    raise NotImplementedError

elif dataset == 'OS':
    raise NotImplementedError

elif dataset == 'USGS':

    candidate_file_paths = glob.glob('outputs/alignment_dir/wikidata_USGS*.json')
    candidate_file_paths = sorted(candidate_file_paths)

    out_dir = 'outputs/wikidata_neighbors/'

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

elif dataset == 'GB1900':

    raise NotImplementedError

else:
    raise NotImplementedError


if kb == 'linkedgeodata':
    sparql_wrapper = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')
elif kb == 'wikidata':
    #sparql_wrapper = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')
    sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')
else:
    raise NotImplementedError

start_map = 6 # 6
start_line = 4 # 4

for candiate_file_path in candidate_file_paths[start_map:]:
    map_name = os.path.basename(candiate_file_path).split('wikidata_')[1]
    out_path = os.path.join(out_dir, 'wikidata_' + map_name)

    with open(candiate_file_path, 'r') as f:
        cand_data = f.readlines()

    with open(out_path, 'a') as out_f:
        for line in cand_data[start_line:]:
            line_dict = json.loads(line)
            ret_line_dict = dict()
            for key, value in line_dict.items(): # actually just one pair

                print(key)

                place_name = key
                for cand_entity in value:
                    time.sleep(2)
                    q_id = cand_entity['item']['value'].split('/')[-1]
                    response = sparql_wrapper.wikidata_nearby_query(str(q_id))

                    if place_name in ret_line_dict:
                        ret_line_dict[place_name].append(response)
                    else:
                        ret_line_dict[place_name] = [response]

                    #time.sleep(random.random()*6)

            out_f.write(json.dumps(ret_line_dict))
            out_f.write('\n')

    print('finished with ', candiate_file_path)
    break

print('done')

'''
{"Martin": [{"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27001"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(18.921388888 49.063611111)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in northern Slovakia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q281028"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.734166666 43.175)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in and county seat of Bennett County, South Dakota, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q761390"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(1.3394 51.179)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "hamlet in Kent"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2177502"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-100.115 47.826666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in North Dakota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2454021"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-85.64168 42.53698)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Michigan, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2481111"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.21833 32.09917)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Red River Parish, Louisiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2635473"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.75944 37.56778)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Kentucky, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2679547"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-1.9041 50.9759)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village in Hampshire, England, United Kingdom"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q2780056"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-88.851666666 36.341944444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "city in Tennessee, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q3261150"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.18556 34.48639)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "town in Franklin and Stephens Counties, Georgia, USA"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6002227"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-101.709 41.2581)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "census-designated place in Nebraska, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774807"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-82.1906 29.2936)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Marion County, Florida, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774809"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-83.336666666 41.5575)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Ohio, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q6774810"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(116.037 -32.071)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "suburb of Perth, Western Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q9029707"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-81.476388888 33.069166666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in South Carolina, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q11770660"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-0.325391 53.1245)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "village and civil parish in Lincolnshire, UK"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14692833"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-93.5444 47.4894)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Itasca County, Minnesota"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q14714180"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.0889 39.2242)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "unincorporated community in Grant County, West Virginia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q18496647"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(11.95941 46.34585)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Italy"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q20949553"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(22.851944444 41.954444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "mountain in Republic of Macedonia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q24065096"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "<http://www.wikidata.org/entity/Q111> Point(290.75 -21.34)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "crater on Mars"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q26300074"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-0.97879914 51.74729077)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Thame, South Oxfordshire, Oxfordshire, OX9"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27988822"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-87.67361111 38.12444444)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Vanderburgh County, Indiana, United States"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q27995389"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-121.317 47.28)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Washington, United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q28345614"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-76.8 48.5)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in Senneterre, Quebec, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q30626037"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-79.90972222 39.80638889)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "human settlement in United States of America"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q61038281"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-91.13 49.25)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Meteorological Service of Canada's station for Martin (MSC ID: 6035000), Ontario, Canada"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526691"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(152.7011111 -29.881666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Fitzroy County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q63526695"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(148.1011111 -33.231666666)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Parish of Ashburnham County, New South Wales, Australia"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96149222"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(14.725573779 48.76768332)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q96158116"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(15.638358708 49.930136157)"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q103777024"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(-3.125028822 58.694748091)"}, "itemDescription": {"xml:lang": "en", "type": "literal", "value": "Shipwreck off the Scottish Coast, imported from Canmore Nov 2020"}}, {"item": {"type": "uri", "value": "http://www.wikidata.org/entity/Q107077206"}, "coordinates": {"datatype": "http://www.opengis.net/ont/geosparql#wktLiteral", "type": "literal", "value": "Point(11.688165 46.582982)"}}]}


if overwrite:
    # flush the file if it's been written
    with open(out_path, 'w') as f:
        f.write('')

for name in namelist:
    name = name.replace('"', '')
    name = name.strip("'")
    if len(name) == 0:
        continue
    print(name)
    mydict = dict()
    mydict[name] = sparql_wrapper.wikidata_query(name)
    line = json.dumps(mydict)
    #print(line)
    with open(out_path, 'a') as f:
        f.write(line)
        f.write('\n')
    time.sleep(1)

print('done')


{"info": {"name": "10TH ST", "geometry": [4193.118085062303, -831.274950414831]},
"neighbor_info":
{"name_list": ["BM 107", "PALM AVE", "WT", "Hidalgo Sch", "PO", "MAIN ST", "BRYANT CANAL", "BM 123", "Oakley Sch", "BRAWLEY", "Witter Sch", "BM 104", "Pistol Range", "Reid Sch", "STANLEY", "MUNICIPAL AIRPORT", "WESTERN AVE", "CANAL", "Riverview Cem", "BEST CANAL"],
"geometry_list": [[4180.493095652702, -836.0635465095995], [4240.450935702045, -855.345637906981], [4136.084840542623, -917.7895986922882], [4150.386997979736, -948.7258091165079], [4056.955267048625, -847.1018277439381], [4008.112642182582, -849.089249977583], [4124.177447575567, -1004.0706369942257], [4145.382175508665, -626.1608201557082], [4398.137868976953, -764.1087236140554], [4221.1546492913285, -1062.5745271963772], [4015.203890157584, -985.0178210457995], [3989.2345421184878, -948.9340389243871], [4385.585449075614, -660.4590917125413], [3936.505159635338, -803.6822663422273], [3960.1233867112846, -686.7988766730389], [4409.714306709143, -600.6633389979504], [3871.2873706574037, -832.0785684368772], [4304.899727301024, -524.472390102557], [3955.640201659347, -578.5544271698675], [4075.8524354668034, -1183.5837385075774]]}}
'''
models/spabert/experiments/entity_matching/data_processing/run_query_sample.py
DELETED
@@ -1,22 +0,0 @@
from query_wrapper import QueryWrapper
from get_namelist import *
import glob
import os
import json
import time


#sparql_wrapper_linkedgeo = QueryWrapper(baseuri = 'http://linkedgeodata.org/sparql')

#print(sparql_wrapper_linkedgeo.linkedgeodata_query('Los Angeles'))


sparql_wrapper_wikidata = QueryWrapper(baseuri = 'https://query.wikidata.org/sparql')

#print(sparql_wrapper_wikidata.wikidata_query('Los Angeles'))

#time.sleep(3)

#print(sparql_wrapper_wikidata.wikidata_nearby_query('Q370771'))
print(sparql_wrapper_wikidata.wikidata_nearby_query('Q97625145'))
models/spabert/experiments/entity_matching/data_processing/run_wikidata_neighbor_query.py
DELETED
@@ -1,31 +0,0 @@
import pandas as pd
import json
from request_wrapper import RequestWrapper
import time
import pdb

start_idx = 17335
wikidata_sample30k_path = 'wikidata_sample30k/wikidata_30k.json'
out_path = 'wikidata_sample30k/wikidata_30k_neighbor.json'

#with open(out_path, 'w') as out_f:
#    pass

sparql_wrapper = RequestWrapper(baseuri = 'https://query.wikidata.org/sparql')

df = pd.read_json(wikidata_sample30k_path)
df = df[start_idx:]

print('length of df:', len(df))

for index, record in df.iterrows():
    print(index)
    uri = record.results['item']['value']
    q_id = uri.split('/')[-1]
    response = sparql_wrapper.wikidata_nearby_query(str(q_id))
    time.sleep(1)
    with open(out_path, 'a') as out_f:
        out_f.write(json.dumps(response))
        out_f.write('\n')
models/spabert/experiments/entity_matching/data_processing/samples.sparql
DELETED
@@ -1,22 +0,0 @@
'''
Entities near xxx within 1km
SELECT ?place ?placeLabel ?location ?instanceLabel ?placeDescription
WHERE
{
  wd:Q9188 wdt:P625 ?loc .
  SERVICE wikibase:around {
      ?place wdt:P625 ?location .
      bd:serviceParam wikibase:center ?loc .
      bd:serviceParam wikibase:radius "1" .
  }
  OPTIONAL { ?place wdt:P31 ?instance }
  SERVICE wikibase:label { bd:serviceParam wikibase:language "en" }
  BIND(geof:distance(?loc, ?location) as ?dist)
} ORDER BY ?dist
'''


SELECT distinct ?item ?itemLabel WHERE {
  ?item wdt:P625 ?geo .
  SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
}
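
The standalone query at the end of the file can be run against the public Wikidata endpoint with plain requests; a minimal sketch (a LIMIT is added here, which the original query omits, to keep the response small):

import requests

query = """
SELECT distinct ?item ?itemLabel WHERE {
  ?item wdt:P625 ?geo .
  SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
} LIMIT 10
"""

resp = requests.get('https://query.wikidata.org/sparql',
                    params = {'format': 'json', 'query': query})
for b in resp.json()['results']['bindings']:
    print(b['item']['value'], b['itemLabel']['value'])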
models/spabert/experiments/entity_matching/data_processing/select_ambi.py
DELETED
@@ -1,18 +0,0 @@
import json

json_file = 'entity_linking/outputs/wikidata_usgs_linking_descript.json'

with open(json_file, 'r') as f:
    data = f.readlines()

num_ambi = 0
for line in data:
    line_dict = json.loads(line)
    for key, value in line_dict.items():
        len_value = len(value)
        if len_value < 2:
            continue
        else:
            num_ambi += 1
            print(key)
print(num_ambi)
models/spabert/experiments/entity_matching/data_processing/wikidata_sample30k/wikidata_30k.json
DELETED
The diff for this file is too large to render.
See raw diff
models/spabert/experiments/entity_matching/src/evaluation-mrr.py
DELETED
@@ -1,260 +0,0 @@
#!/usr/bin/env python
# coding: utf-8


import sys
import os
import glob
import json
import numpy as np
import pandas as pd
import pdb


prediction_dir = sys.argv[1]

print(prediction_dir)

gt_dir = '../data_processing/outputs/alignment_gt_dir/'
prediction_path_list = sorted(os.listdir(prediction_dir))

DISPLAY = False
DETAIL = False

if DISPLAY:
    from IPython.display import display

def recall_at_k_all_map(all_rank_list, k = 1):

    rank_list = [item for sublist in all_rank_list for item in sublist]
    total_query = len(rank_list)
    prec = np.sum(np.array(rank_list) <= k)
    prec = 1.0 * prec / total_query

    return prec

def recall_at_k_permap(all_rank_list, k = 1):

    prec_list = []
    for rank_list in all_rank_list:
        total_query = len(rank_list)
        prec = np.sum(np.array(rank_list) <= k)
        prec = 1.0 * prec / total_query
        prec_list.append(prec)

    return prec_list


def reciprocal_rank(all_rank_list):

    recip_list = [1. / rank for rank in all_rank_list]
    mean_recip = np.mean(recip_list)

    return mean_recip, recip_list


count_hist_list = []

all_rank_list = []

all_recip_list = []

permap_recip_list = []

for map_path in prediction_path_list:

    pred_path = os.path.join(prediction_dir, map_path)
    gt_path = os.path.join(gt_dir, map_path.split('.json')[0] + '.csv')

    if DETAIL:
        print(pred_path)

    with open(gt_path, 'r') as f:
        gt_data = f.readlines()

    gt_dict = dict()
    for line in gt_data:
        line = line.split(',')
        pivot_name = line[0]
        gt_uri = line[1]
        gt_dict[pivot_name] = gt_uri

    rank_list = []
    pivot_name_list = []
    with open(pred_path, 'r') as f:
        pred_data = f.readlines()
        for line in pred_data:
            pred_dict = json.loads(line)
            #print(pred_dict.keys())
            pivot_name = pred_dict['pivot_name']
            sorted_match_uri = pred_dict['sorted_match_uri']
            #sorted_match_des = pred_dict['sorted_match_des']
            sorted_sim_matrix = pred_dict['sorted_sim_matrix']

            total = len(sorted_match_uri)
            if total == 1:
                continue

            if pivot_name in gt_dict:

                gt_uri = gt_dict[pivot_name]

                try:
                    assert gt_uri in sorted_match_uri
                except Exception as e:
                    #print(e)
                    continue

                pivot_name_list.append(pivot_name)
                count_hist_list.append(total)
                rank = sorted_match_uri.index(gt_uri) + 1

                rank_list.append(rank)
                #print(rank, '/', total)

    all_rank_list.append(rank_list)

    mean_recip, recip_list = reciprocal_rank(rank_list)

    all_recip_list.extend(recip_list)
    permap_recip_list.append(recip_list)

    d = {'pivot': pivot_name_list + ['AVG'], 'rank': rank_list + [' '], 'recip rank': recip_list + [str(mean_recip)]}
    if DETAIL:
        print(pivot_name_list, rank_list, recip_list)

    if DISPLAY:
        df = pd.DataFrame(data=d)
        display(df)


print('all mrr, micro', np.mean(all_recip_list))


if DETAIL:
    print(len(rank_list))


print(recall_at_k_all_map(all_rank_list, k = 1))
print(recall_at_k_all_map(all_rank_list, k = 2))
print(recall_at_k_all_map(all_rank_list, k = 5))
print(recall_at_k_all_map(all_rank_list, k = 10))


print(prediction_path_list)

prec_list_1 = recall_at_k_permap(all_rank_list, k = 1)
prec_list_2 = recall_at_k_permap(all_rank_list, k = 2)
prec_list_5 = recall_at_k_permap(all_rank_list, k = 5)
prec_list_10 = recall_at_k_permap(all_rank_list, k = 10)

if DETAIL:

    print(np.mean(prec_list_1))
    print(prec_list_1)
    print('\n')

    print(np.mean(prec_list_2))
    print(prec_list_2)
    print('\n')

    print(np.mean(prec_list_5))
    print(prec_list_5)
    print('\n')

    print(np.mean(prec_list_10))
    print(prec_list_10)
    print('\n')


map_name_list = [name.split('.json')[0].split('USGS-')[1] for name in prediction_path_list]
d = {'map_name': map_name_list, 'recall@1': prec_list_1, 'recall@2': prec_list_2, 'recall@5': prec_list_5, 'recall@10': prec_list_10}
df = pd.DataFrame(data=d)


if DETAIL:
    print(df)


category = ['15-CA', '30-CA', '60-CA']
col_1 = [np.mean(prec_list_1[0:4]), np.mean(prec_list_1[4:9]), np.mean(prec_list_1[9:])]
col_2 = [np.mean(prec_list_2[0:4]), np.mean(prec_list_2[4:9]), np.mean(prec_list_2[9:])]
col_3 = [np.mean(prec_list_5[0:4]), np.mean(prec_list_5[4:9]), np.mean(prec_list_5[9:])]
col_4 = [np.mean(prec_list_10[0:4]), np.mean(prec_list_10[4:9]), np.mean(prec_list_10[9:])]


mrr_15 = permap_recip_list[0] + permap_recip_list[1] + permap_recip_list[2] + permap_recip_list[3]
mrr_30 = permap_recip_list[4] + permap_recip_list[5] + permap_recip_list[6] + permap_recip_list[7] + permap_recip_list[8]
mrr_60 = permap_recip_list[9] + permap_recip_list[10] + permap_recip_list[11] + permap_recip_list[12] + permap_recip_list[13]


column_5 = [np.mean(mrr_15), np.mean(mrr_30), np.mean(mrr_60)]

d = {'map set': category, 'mrr': column_5, 'prec@1': col_1, 'prec@2': col_2, 'prec@5': col_3, 'prec@10': col_4}
df = pd.DataFrame(data=d)

print(df)


print('all mrr, micro', np.mean(all_recip_list))

print('\n')


print(recall_at_k_all_map(all_rank_list, k = 1))
print(recall_at_k_all_map(all_rank_list, k = 2))
print(recall_at_k_all_map(all_rank_list, k = 5))
print(recall_at_k_all_map(all_rank_list, k = 10))


if DISPLAY:

    import seaborn

    p = seaborn.histplot(data = count_hist_list, color = 'blue', alpha=0.2)
    p.set_xlabel("Number of Candidates")
    p.set_title("Candidate Distribution in USGS")


print(len(count_hist_list))
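
To make the metrics concrete, a toy worked example of reciprocal_rank and recall_at_k_all_map for a single map whose gold entities ranked 1, 3, and 10 among their candidates:

rank_list = [1, 3, 10]

# reciprocal_rank: mean of 1/rank -> (1/1 + 1/3 + 1/10) / 3 ~= 0.478
mean_recip, recip_list = reciprocal_rank(rank_list)

# recall@k: fraction of queries whose gold rank is <= k
print(recall_at_k_all_map([rank_list], k=1))   # 1/3 ~= 0.333
print(recall_at_k_all_map([rank_list], k=5))   # 2/3 ~= 0.667
print(recall_at_k_all_map([rank_list], k=10))  # 3/3 = 1.0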
models/spabert/experiments/entity_matching/src/linking_ablation.py
DELETED
@@ -1,228 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
# coding: utf-8
|
3 |
-
|
4 |
-
import sys
|
5 |
-
import os
|
6 |
-
import numpy as np
|
7 |
-
import pdb
|
8 |
-
import json
|
9 |
-
import scipy.spatial as sp
|
10 |
-
import argparse
|
11 |
-
|
12 |
-
|
13 |
-
import torch
|
14 |
-
from torch.utils.data import DataLoader
|
15 |
-
|
16 |
-
from transformers import AdamW
|
17 |
-
from transformers import BertTokenizer
|
18 |
-
from tqdm import tqdm # for our progress bar
|
19 |
-
|
20 |
-
sys.path.append('../../../')
|
21 |
-
from datasets.usgs_os_sample_loader import USGS_MapDataset
|
22 |
-
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
|
23 |
-
from models.spatial_bert_model import SpatialBertModel
|
24 |
-
from models.spatial_bert_model import SpatialBertConfig
|
25 |
-
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
|
26 |
-
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
|
27 |
-
from utils.baseline_utils import get_baseline_model
|
28 |
-
|
29 |
-
from transformers import BertModel
|
30 |
-
|
31 |
-
sys.path.append('/home/zekun/spatial_bert/spatial_bert/datasets')
|
32 |
-
from dataset_loader import SpatialDataset
|
33 |
-
from osm_sample_loader import PbfMapDataset
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
MODEL_OPTIONS = ['spatial_bert-base','spatial_bert-large', 'bert-base','bert-large','roberta-base','roberta-large',
|
38 |
-
'spanbert-base','spanbert-large','luke-base','luke-large',
|
39 |
-
'simcse-bert-base','simcse-bert-large','simcse-roberta-base','simcse-roberta-large']
|
40 |
-
|
41 |
-
|
42 |
-
CANDSET_MODES = ['all_map'] # candidate set is constructed based on all maps or one map
|
43 |
-
|
44 |
-
def recall_at_k(rank_list, k = 1):
|
45 |
-
|
46 |
-
total_query = len(rank_list)
|
47 |
-
recall = np.sum(np.array(rank_list)<=k)
|
48 |
-
recall = 1.0 * recall / total_query
|
49 |
-
|
50 |
-
return recall
|
51 |
-
|
52 |
-
def reciprocal_rank(all_rank_list):
|
53 |
-
|
54 |
-
recip_list = [1./rank for rank in all_rank_list]
|
55 |
-
mean_recip = np.mean(recip_list)
|
56 |
-
|
57 |
-
return mean_recip, recip_list
|
58 |
-
|
59 |
-
def link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list):
|
60 |
-
|
61 |
-
source_emb_list = [source_dict['emb'] for source_dict in source_embedding_ogc_list]
|
62 |
-
source_ogc_list = [source_dict['ogc_fid'] for source_dict in source_embedding_ogc_list]
|
63 |
-
|
64 |
-
target_emb_list = [target_dict['emb'] for target_dict in target_embedding_ogc_list]
|
65 |
-
target_ogc_list = [target_dict['ogc_fid'] for target_dict in target_embedding_ogc_list]
|
66 |
-
|
67 |
-
rank_list = []
|
68 |
-
for source_emb, source_ogc in zip(source_emb_list, source_ogc_list):
|
69 |
-
sim_matrix = 1 - sp.distance.cdist(np.array(target_emb_list), np.array([source_emb]), 'cosine')
|
70 |
-
closest_match_ogc = sort_ref_closest_match(sim_matrix, target_ogc_list)
|
71 |
-
|
72 |
-
closest_match_ogc = [a[0] for a in closest_match_ogc]
|
73 |
-
rank = closest_match_ogc.index(source_ogc) +1
|
74 |
-
rank_list.append(rank)
|
75 |
-
|
76 |
-
|
77 |
-
mean_recip, recip_list = reciprocal_rank(rank_list)
|
78 |
-
r1 = recall_at_k(rank_list, k = 1)
|
79 |
-
r5 = recall_at_k(rank_list, k = 5)
|
80 |
-
r10 = recall_at_k(rank_list, k = 10)
|
81 |
-
|
82 |
-
return mean_recip , r1, r5, r10
|
83 |
-
|
84 |
-
def get_embedding_and_ogc(dataset, model_name, model):
|
85 |
-
dict_list = []
|
86 |
-
|
87 |
-
for source in dataset:
|
88 |
-
if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
|
89 |
-
source_emb = get_spatialbert_embedding(source, model)
|
90 |
-
else:
|
91 |
-
source_emb = get_bert_embedding(source, model)
|
92 |
-
|
93 |
-
source_dict = {}
|
94 |
-
source_dict['emb'] = source_emb
|
95 |
-
source_dict['ogc_fid'] = source['ogc_fid']
|
96 |
-
#wikidata_dict['wikidata_des_list'] = [wikidata_cand['description']]
|
97 |
-
|
98 |
-
dict_list.append(source_dict)
|
99 |
-
|
100 |
-
return dict_list
|
101 |
-
|
102 |
-
|
103 |
-
def entity_linking_func(args):
|
104 |
-
|
105 |
-
model_name = args.model_name
|
106 |
-
candset_mode = args.candset_mode
|
107 |
-
|
108 |
-
distance_norm_factor = args.distance_norm_factor
|
109 |
-
spatial_dist_fill= args.spatial_dist_fill
|
110 |
-
sep_between_neighbors = args.sep_between_neighbors
|
111 |
-
|
112 |
-
spatial_bert_weight_dir = args.spatial_bert_weight_dir
|
113 |
-
spatial_bert_weight_name = args.spatial_bert_weight_name
|
114 |
-
|
115 |
-
if_no_spatial_distance = args.no_spatial_distance
|
116 |
-
random_remove_neighbor = args.random_remove_neighbor
|
117 |
-
|
118 |
-
|
119 |
-
assert model_name in MODEL_OPTIONS
|
120 |
-
assert candset_mode in CANDSET_MODES
|
121 |
-
|
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if model_name == 'spatial_bert-base':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        config = SpatialBertConfig()
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    elif model_name == 'spatial_bert-large':
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

        config = SpatialBertConfig(hidden_size=1024, intermediate_size=4096, num_attention_heads=16, num_hidden_layers=24)
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    else:
        model, tokenizer = get_baseline_model(model_name)
        model.to(device)
        model.eval()

    source_file_path = '../data/osm-point-minnesota-full.json'
    source_dataset = PbfMapDataset(data_file_path=source_file_path,
                                   tokenizer=tokenizer,
                                   max_token_len=512,
                                   distance_norm_factor=distance_norm_factor,
                                   spatial_dist_fill=spatial_dist_fill,
                                   with_type=False,
                                   sep_between_neighbors=sep_between_neighbors,
                                   mode=None,
                                   random_remove_neighbor=random_remove_neighbor,
                                   )

    # the target set keeps all neighbors, so the perturbed source entities are
    # linked against their unperturbed counterparts
    target_dataset = PbfMapDataset(data_file_path=source_file_path,
                                   tokenizer=tokenizer,
                                   max_token_len=512,
                                   distance_norm_factor=distance_norm_factor,
                                   spatial_dist_fill=spatial_dist_fill,
                                   with_type=False,
                                   sep_between_neighbors=sep_between_neighbors,
                                   mode=None,
                                   random_remove_neighbor=0.,  # keep all neighbors
                                   )

    # compute an embedding (paired with its OGC id) for every entity in each set
    source_embedding_ogc_list = get_embedding_and_ogc(source_dataset, model_name, model)
    target_embedding_ogc_list = get_embedding_and_ogc(target_dataset, model_name, model)

    mean_recip, r1, r5, r10 = link_to_itself(source_embedding_ogc_list, target_embedding_ogc_list)
    print('\n')
    print(random_remove_neighbor, mean_recip, r1, r5, r10)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='spatial_bert-base')
    parser.add_argument('--candset_mode', type=str, default='all_map')

    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)

    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--spatial_bert_weight_dir', type=str, default=None)
    parser.add_argument('--spatial_bert_weight_name', type=str, default=None)

    parser.add_argument('--random_remove_neighbor', type=float, default=0.)

    args = parser.parse_args()

    entity_linking_func(args)

# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-base' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr5e-05_sep_bert-base_nofreeze_london_california_bsize12/ep0_iter06000_0.2936/' --spatial_bert_weight_name='keeppos_ep0_iter02000_0.4879.pth' --random_remove_neighbor=0.1

# CUDA_VISIBLE_DEVICES='1' python3 linking_ablation.py --sep_between_neighbors --model_name='spatial_bert-large' --spatial_bert_weight_dir='/data/zekun/spatial_bert_weights/typing_lr1e-06_sep_bert-large_nofreeze_london_california_bsize12/ep2_iter02000_0.3921/' --spatial_bert_weight_name='keeppos_ep8_iter03568_0.2661_val0.2284.pth' --random_remove_neighbor=0.1


if __name__ == '__main__':
    main()
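The `link_to_itself` helper is defined earlier in linking_ablation.py and is not part of the hunk shown here. As a rough guide to the metrics it reports, here is a minimal sketch, assuming each list entry pairs an embedding with its OSM id (as the name `get_embedding_and_ogc` suggests); the deleted implementation may differ in detail.

import numpy as np
import scipy.spatial as sp

def mrr_and_recall_sketch(source_list, target_list):
    # Minimal sketch: rank every target by cosine similarity to each source
    # entity and score how highly the matching ogc_fid is ranked.
    # Assumes each element is an (embedding, ogc_fid) tuple.
    src_emb = np.array([e for e, _ in source_list])
    tgt_emb = np.array([e for e, _ in target_list])
    tgt_ids = [i for _, i in target_list]

    sim = 1 - sp.distance.cdist(src_emb, tgt_emb, 'cosine')  # higher = closer

    recip_ranks, hits1, hits5, hits10 = [], 0, 0, 0
    for row, (_, src_id) in zip(sim, source_list):
        order = np.argsort(row)[::-1]  # best match first
        rank = [tgt_ids[j] for j in order].index(src_id) + 1
        recip_ranks.append(1.0 / rank)
        hits1 += rank <= 1
        hits5 += rank <= 5
        hits10 += rank <= 10

    n = len(source_list)
    return np.mean(recip_ranks), hits1 / n, hits5 / n, hits10 / n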
models/spabert/experiments/entity_matching/src/unsupervised_wiki_location_allcand.py
DELETED
@@ -1,329 +0,0 @@
#!/usr/bin/env python
# coding: utf-8

import sys
import os
import numpy as np
import pdb
import json
import scipy.spatial as sp
import argparse

import torch
from torch.utils.data import DataLoader

from transformers import AdamW
from transformers import BertTokenizer
from transformers import BertModel
from tqdm import tqdm  # for our progress bar

sys.path.append('../../../')
from datasets.usgs_os_sample_loader import USGS_MapDataset
from datasets.wikidata_sample_loader import Wikidata_Geocoord_Dataset, Wikidata_Random_Dataset
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from utils.find_closest import find_ref_closest_match, sort_ref_closest_match
from utils.common_utils import load_spatial_bert_pretrained_weights, get_spatialbert_embedding, get_bert_embedding, write_to_csv
from utils.baseline_utils import get_baseline_model


MODEL_OPTIONS = ['spatial_bert-base', 'spatial_bert-large', 'bert-base', 'bert-large', 'roberta-base', 'roberta-large',
                 'spanbert-base', 'spanbert-large', 'luke-base', 'luke-large',
                 'simcse-bert-base', 'simcse-bert-large', 'simcse-roberta-base', 'simcse-roberta-large']

MAP_TYPES = ['usgs']
CANDSET_MODES = ['all_map']  # candidate set is constructed from all maps or from one map


def disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode='all_map', if_use_distance=True, select_indices=None):

    if select_indices is None:
        select_indices = range(0, len(usgs_dataset))

    assert candset_mode in ['all_map', 'per_map']

    wikidata_emb_list = [wikidata_dict['wikidata_emb_list'] for wikidata_dict in wikidata_dict_list]
    wikidata_uri_list = [wikidata_dict['wikidata_uri_list'] for wikidata_dict in wikidata_dict_list]

    if candset_mode == 'all_map':
        wikidata_emb_list = [item for sublist in wikidata_emb_list for item in sublist]  # flatten
        wikidata_uri_list = [item for sublist in wikidata_uri_list for item in sublist]  # flatten

    ret_list = []
    for i in select_indices:

        usgs_entity = usgs_dataset[i]

        if candset_mode == 'per_map':
            # use per-map locals so the outer candidate lists are not clobbered
            # across iterations (the original reassigned the outer names here)
            cand_emb_list = wikidata_emb_list[i]
            cand_uri_list = wikidata_uri_list[i]
        elif candset_mode == 'all_map':
            cand_emb_list = wikidata_emb_list
            cand_uri_list = wikidata_uri_list
        else:
            raise NotImplementedError

        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            usgs_emb = get_spatialbert_embedding(usgs_entity, model, use_distance=if_use_distance)
        else:
            usgs_emb = get_bert_embedding(usgs_entity, model)

        # cosine similarity between the map entity and every Wikidata candidate
        sim_matrix = 1 - sp.distance.cdist(np.array(cand_emb_list), np.array([usgs_emb]), 'cosine')

        closest_match_uri = sort_ref_closest_match(sim_matrix, cand_uri_list)

        sorted_sim_matrix = np.sort(sim_matrix, axis=0)[::-1]  # descending order

        ret_dict = dict()
        ret_dict['pivot_name'] = usgs_entity['pivot_name']
        ret_dict['sorted_match_uri'] = [a[0] for a in closest_match_uri]
        ret_dict['sorted_sim_matrix'] = [a[0] for a in sorted_sim_matrix]

        ret_list.append(ret_dict)

    return ret_list


def entity_linking_func(args):

    model_name = args.model_name
    map_type = args.map_type
    candset_mode = args.candset_mode

    usgs_distance_norm_factor = args.usgs_distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    sep_between_neighbors = args.sep_between_neighbors

    spatial_bert_weight_dir = args.spatial_bert_weight_dir
    spatial_bert_weight_name = args.spatial_bert_weight_name

    if_no_spatial_distance = args.no_spatial_distance

    assert model_name in MODEL_OPTIONS
    assert map_type in MAP_TYPES
    assert candset_mode in CANDSET_MODES

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    if args.out_dir is None:

        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':

            if sep_between_neighbors:
                spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor) + '_distfill' + str(spatial_dist_fill) + '_sep'
            else:
                spatialbert_output_dir_str = 'dnorm' + str(usgs_distance_norm_factor) + '_distfill' + str(spatial_dist_fill) + '_nosep'

            checkpoint_ep = spatial_bert_weight_name.split('_')[3]
            checkpoint_iter = spatial_bert_weight_name.split('_')[4]
            loss_val = spatial_bert_weight_name.split('_')[5][:-4]  # strip the '.pth' suffix

            if if_no_spatial_distance:
                linking_prediction_dir = 'linking_prediction_dir/abalation_no_distance/'
            else:
                linking_prediction_dir = 'linking_prediction_dir'

            if model_name == 'spatial_bert-base':
                out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val
            elif model_name == 'spatial_bert-large':
                freeze_str = spatial_bert_weight_dir.split('/')[-2].split('_')[1]  # either 'freeze' or 'nofreeze'
                out_dir = os.path.join('/data2/zekun/', linking_prediction_dir, spatialbert_output_dir_str) + '/' + map_type + '-' + model_name + '-' + checkpoint_ep + '-' + checkpoint_iter + '-' + loss_val + '-' + freeze_str

        else:
            out_dir = '/data2/zekun/baseline_linking_prediction_dir/' + map_type + '-' + model_name

    else:
        out_dir = args.out_dir

    print('out_dir', out_dir)

    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)

    if model_name == 'spatial_bert-base':
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        config = SpatialBertConfig()
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    elif model_name == 'spatial_bert-large':
        tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')

        config = SpatialBertConfig(hidden_size=1024, intermediate_size=4096, num_attention_heads=16, num_hidden_layers=24)
        model = SpatialBertModel(config)

        model.to(device)
        model.eval()

        # load pretrained weights
        weight_path = os.path.join(spatial_bert_weight_dir, spatial_bert_weight_name)
        model = load_spatial_bert_pretrained_weights(model, weight_path)

    else:
        model, tokenizer = get_baseline_model(model_name)
        model.to(device)
        model.eval()

    if map_type == 'usgs':
        map_name_list = ['USGS-15-CA-brawley-e1957-s1957-p1961',
                         'USGS-15-CA-paloalto-e1899-s1895-rp1911',
                         'USGS-15-CA-capesanmartin-e1921-s1917',
                         'USGS-15-CA-sanfrancisco-e1899-s1892-rp1911',
                         'USGS-30-CA-dardanelles-e1898-s1891-rp1912',
                         'USGS-30-CA-holtville-e1907-s1905-rp1946',
                         'USGS-30-CA-indiospecial-e1904-s1901-rp1910',
                         'USGS-30-CA-lompoc-e1943-s1903-ap1941-rv1941',
                         'USGS-30-CA-sanpedro-e1943-rv1944',
                         'USGS-60-CA-alturas-e1892-rp1904',
                         'USGS-60-CA-amboy-e1942',
                         'USGS-60-CA-amboy-e1943-rv1943',
                         'USGS-60-CA-modoclavabed-e1886-s1884',
                         'USGS-60-CA-saltonsea-e1943-ap1940-rv1942']

    print('processing wikidata...')

    wikidata_dict_list = []

    wikidata_random30k = Wikidata_Random_Dataset(
        data_file_path='../data_processing/wikidata_sample30k/wikidata_30k_neighbor_reformat.json',
        tokenizer=tokenizer,
        max_token_len=512,
        distance_norm_factor=0.0001,
        spatial_dist_fill=100,
        sep_between_neighbors=sep_between_neighbors,
    )

    # embed the 30k random Wikidata distractors, one candidate per entry
    for wikidata_cand in wikidata_random30k:
        if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
            wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
        else:
            wikidata_emb = get_bert_embedding(wikidata_cand, model)

        wikidata_dict = {}
        wikidata_dict['wikidata_emb_list'] = [wikidata_emb]
        wikidata_dict['wikidata_uri_list'] = [wikidata_cand['uri']]

        wikidata_dict_list.append(wikidata_dict)

    for map_name in map_name_list:

        print(map_name)

        wikidata_dict_per_map = {}
        wikidata_dict_per_map['wikidata_emb_list'] = []
        wikidata_dict_per_map['wikidata_uri_list'] = []

        wikidata_dataset_permap = Wikidata_Geocoord_Dataset(
            data_file_path='../data_processing/outputs/wikidata_reformat/wikidata_' + map_name + '.json',
            tokenizer=tokenizer,
            max_token_len=512,
            distance_norm_factor=0.0001,
            spatial_dist_fill=100,
            sep_between_neighbors=sep_between_neighbors)

        for i in range(0, len(wikidata_dataset_permap)):
            # get all candidates for phrases within the map
            wikidata_candidates = wikidata_dataset_permap[i]  # dataset for each map, list of [cand for each phrase]

            # process candidates for each phrase
            for wikidata_cand in wikidata_candidates:
                if model_name == 'spatial_bert-base' or model_name == 'spatial_bert-large':
                    wikidata_emb = get_spatialbert_embedding(wikidata_cand, model)
                else:
                    wikidata_emb = get_bert_embedding(wikidata_cand, model)

                wikidata_dict_per_map['wikidata_emb_list'].append(wikidata_emb)
                wikidata_dict_per_map['wikidata_uri_list'].append(wikidata_cand['uri'])

        wikidata_dict_list.append(wikidata_dict_per_map)

    for map_name in map_name_list:

        print(map_name)

        usgs_dataset = USGS_MapDataset(
            data_file_path='../data_processing/outputs/alignment_dir/map_' + map_name + '.json',
            tokenizer=tokenizer,
            distance_norm_factor=usgs_distance_norm_factor,
            spatial_dist_fill=spatial_dist_fill,
            sep_between_neighbors=sep_between_neighbors)

        ret_list = disambiguify(model, model_name, usgs_dataset, wikidata_dict_list, candset_mode=candset_mode, if_use_distance=not if_no_spatial_distance, select_indices=None)

        write_to_csv(out_dir, map_name, ret_list)

    print('Done')


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='spatial_bert-base')
    parser.add_argument('--out_dir', type=str, default=None)
    parser.add_argument('--map_type', type=str, default='usgs')
    parser.add_argument('--candset_mode', type=str, default='all_map')

    parser.add_argument('--usgs_distance_norm_factor', type=float, default=1)
    parser.add_argument('--spatial_dist_fill', type=float, default=100)

    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--spatial_bert_weight_dir', type=str, default=None)
    parser.add_argument('--spatial_bert_weight_name', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if out_dir is given but does not exist yet, create it
    if args.out_dir is not None and not os.path.isdir(args.out_dir):
        os.makedirs(args.out_dir)

    entity_linking_func(args)


if __name__ == '__main__':
    main()
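The ranking helper `sort_ref_closest_match` lives in utils/find_closest.py, which is not part of this diff. Judging from how the caller unpacks `a[0]`, a minimal sketch of the behavior this script relies on might look like the following; this is an assumption about the helper, not the repository's implementation.

import numpy as np

def sort_ref_closest_match_sketch(sim_matrix, ref_list):
    # sim_matrix has shape (num_candidates, 1); ref_list holds one URI per
    # candidate. Return the URIs ordered from most to least similar, each
    # wrapped in a single-element container so the caller can take a[0].
    order = np.argsort(sim_matrix[:, 0])[::-1]
    return [[ref_list[j]] for j in order]

# e.g. sort_ref_closest_match_sketch(np.array([[0.2], [0.9]]), ['Q1', 'Q2'])
# -> [['Q2'], ['Q1']]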
models/spabert/experiments/semantic_typing/__init__.py
DELETED
File without changes
models/spabert/experiments/semantic_typing/data_processing/merge_osm_json.py
DELETED
@@ -1,97 +0,0 @@
import os
import json
import math
import glob
import re
import pdb

'''
NO LONGER NEEDED

Process the california, london, and minnesota OSM data and prepare pseudo-sentence spatial context.

Load the raw output files generated by SQL,
unify the JSON by changing the structure of the dictionary,
and save the output into two files: one for training + validation, the other for testing.
'''

region_list = ['california', 'london', 'minnesota']

input_json_dir = '../data/sql_output/sub_files/'
output_json_dir = '../data/sql_output/'

for region_name in region_list:
    file_list = glob.glob(os.path.join(input_json_dir, 'spatialbert-osm-point-' + region_name + '*.json'))
    file_list = sorted(file_list)
    print('found %d files for region %s' % (len(file_list), region_name))

    num_test_files = int(math.ceil(len(file_list) * 0.2))
    num_train_val_files = len(file_list) - num_test_files

    print('%d files for train-val' % num_train_val_files)
    print('%d files for test' % num_test_files)

    train_val_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_train_val.json')
    test_output_path = os.path.join(output_json_dir + 'osm-point-' + region_name + '_test.json')

    # refresh (truncate) the output files
    with open(train_val_output_path, 'w') as f:
        pass
    with open(test_output_path, 'w') as f:
        pass

    for idx in range(len(file_list)):

        if idx < num_train_val_files:
            output_path = train_val_output_path
        else:
            output_path = test_output_path

        file_path = file_list[idx]

        print(file_path)

        with open(file_path, 'r') as f:
            data = f.readlines()

        line = data[0]

        # strip literal newlines, escaped newlines, and stray plus signs left by the SQL export
        line = re.sub(r'\n', '', line)
        line = re.sub(r'\\n', '', line)
        line = re.sub(r'\\+', '', line)
        line = re.sub(r'\+', '', line)

        line_dict_list = json.loads(line)

        for line_dict in line_dict_list:

            line_dict = line_dict['json_build_object']

            if not line_dict['name'][0].isalpha():  # discard the record if the first character is not a letter
                continue

            neighbor_name_list = line_dict['neighbor_info'][0]['name_list']
            neighbor_geom_list = line_dict['neighbor_info'][0]['geometry_list']

            # names and geometries must align (the original asserted geom_list against itself)
            assert len(neighbor_name_list) == len(neighbor_geom_list)

            temp_dict = \
                {'info': {'name': line_dict['name'],
                          'geometry': {'coordinates': line_dict['geometry']},
                          'class': line_dict['class']
                          },
                 'neighbor_info': {'name_list': neighbor_name_list,
                                   'geometry_list': neighbor_geom_list
                                   }
                 }

            with open(output_path, 'a') as f:
                json.dump(temp_dict, f)
                f.write('\n')
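The reshaping step above is easier to see on a single record. Below is a hypothetical `json_build_object` entry (field values invented for illustration) and the unified record this script would write for it, one JSON object per line.

# hypothetical SQL output record, shaped like the entries this script reads
line_dict = {
    'name': 'Griffith Observatory',
    'geometry': [-118.3004, 34.1184],
    'class': 'education',
    'neighbor_info': [{
        'name_list': ['Greek Theatre', 'Vermont Canyon Rd'],
        'geometry_list': [[-118.2964, 34.1194], [-118.2935, 34.1157]],
    }],
}

# unified structure written to the train/val or test file
temp_dict = {
    'info': {
        'name': line_dict['name'],
        'geometry': {'coordinates': line_dict['geometry']},
        'class': line_dict['class'],
    },
    'neighbor_info': {
        'name_list': line_dict['neighbor_info'][0]['name_list'],
        'geometry_list': line_dict['neighbor_info'][0]['geometry_list'],
    },
}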
models/spabert/experiments/semantic_typing/src/__init__.py
DELETED
File without changes
models/spabert/experiments/semantic_typing/src/run_baseline_test.py
DELETED
@@ -1,82 +0,0 @@
import os
import pdb
import argparse
import numpy as np
import time

MODEL_OPTIONS = ['bert-base', 'bert-large', 'roberta-base', 'roberta-large',
                 'spanbert-base', 'spanbert-large', 'luke-base', 'luke-large',
                 'simcse-bert-base', 'simcse-bert-large', 'simcse-roberta-base', 'simcse-roberta-large']


def execute_command(command, if_print_command):
    t1 = time.time()

    if if_print_command:
        print(command)
    os.system(command)

    t2 = time.time()
    time_usage = t2 - t1
    return time_usage


def run_test(args):
    weight_dir = args.weight_dir
    backbone_option = args.backbone_option
    gpu_id = str(args.gpu_id)
    if_print_command = args.print_command
    sep_between_neighbors = args.sep_between_neighbors

    assert backbone_option in MODEL_OPTIONS

    if sep_between_neighbors:
        sep_str = '_sep'
    else:
        sep_str = ''

    if 'large' in backbone_option:
        checkpoint_dir = os.path.join(weight_dir, 'typing_lr1e-06_%s_nofreeze%s_london_california_bsize12' % (backbone_option, sep_str))
    else:
        checkpoint_dir = os.path.join(weight_dir, 'typing_lr5e-05_%s_nofreeze%s_london_california_bsize12' % (backbone_option, sep_str))
    weight_files = os.listdir(checkpoint_dir)

    # pick the checkpoint with the smallest validation loss; parse the trailing
    # 'val<loss>.pth' token as a float (the original compared the raw strings,
    # which mis-orders losses with different digit counts, e.g. '10.0' sorts
    # before '9.0')
    val_loss_list = [float(weight_file.split('_')[-1].replace('val', '').replace('.pth', ''))
                     for weight_file in weight_files]
    min_loss_weight = weight_files[np.argmin(val_loss_list)]

    checkpoint_path = os.path.join(checkpoint_dir, min_loss_weight)

    if sep_between_neighbors:
        command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --sep_between_neighbors --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)
    else:
        command = 'CUDA_VISIBLE_DEVICES=%s python3 test_cls_baseline.py --backbone_option=%s --batch_size=8 --with_type --checkpoint_path=%s ' % (gpu_id, backbone_option, checkpoint_path)

    execute_command(command, if_print_command)


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--weight_dir', type=str, default='/data2/zekun/spatial_bert_baseline_weights/')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--backbone_option', type=str, default=None)
    parser.add_argument('--gpu_id', type=int, default=0)

    parser.add_argument('--print_command', default=False, action='store_true')

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    run_test(args)


if __name__ == '__main__':
    main()
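The checkpoints this script scans are saved by train_cls_baseline.py (later in this diff) under names of the form `ep<epoch>_iter<iter>_<train loss>_val<val loss>.pth`. A small self-contained sketch of the min-validation-loss selection, run on hypothetical filenames in that format:

import numpy as np

# hypothetical checkpoint names, following the training script's save format
weight_files = ['ep0_iter02000_0.4879_val0.5123.pth',
                'ep1_iter04000_0.2936_val0.3171.pth',
                'ep2_iter06000_0.2661_val0.2284.pth']

val_losses = [float(w.split('_')[-1].replace('val', '').replace('.pth', ''))
              for w in weight_files]
best = weight_files[np.argmin(val_losses)]
print(best)  # ep2_iter06000_0.2661_val0.2284.pth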
models/spabert/experiments/semantic_typing/src/test_cls_ablation_spatialbert.py
DELETED
@@ -1,209 +0,0 @@
import os
import sys
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm  # for our progress bar
from transformers import AdamW

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

sys.path.append('../../../')
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from transformers.models.bert.modeling_bert import BertForMaskedLM

from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb


DEBUG = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


def testing(args):

    max_token_len = args.max_token_len
    batch_size = args.batch_size
    num_workers = args.num_workers
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    checkpoint_path = args.checkpoint_path
    if_no_spatial_distance = args.no_spatial_distance

    bert_option = args.bert_option
    num_neighbor_limit = args.num_neighbor_limit

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    if bert_option == 'bert-base':
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance, num_semantic_types=len(CLASS_9_LIST))
    elif bert_option == 'bert-large':
        tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance, hidden_size=1024, intermediate_size=4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(CLASS_9_LIST))
    else:
        raise NotImplementedError

    model = SpatialBertForSemanticTyping(config)

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    london_dataset = PbfMapDataset(data_file_path=london_file_path,
                                   tokenizer=tokenizer,
                                   max_token_len=max_token_len,
                                   distance_norm_factor=distance_norm_factor,
                                   spatial_dist_fill=spatial_dist_fill,
                                   with_type=with_type,
                                   sep_between_neighbors=sep_between_neighbors,
                                   label_encoder=label_encoder,
                                   num_neighbor_limit=num_neighbor_limit,
                                   mode='test')

    california_dataset = PbfMapDataset(data_file_path=california_file_path,
                                       tokenizer=tokenizer,
                                       max_token_len=max_token_len,
                                       distance_norm_factor=distance_norm_factor,
                                       spatial_dist_fill=spatial_dist_fill,
                                       with_type=with_type,
                                       sep_between_neighbors=sep_between_neighbors,
                                       label_encoder=label_encoder,
                                       num_neighbor_limit=num_neighbor_limit,
                                       mode='test')

    test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])

    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    model.load_state_dict(torch.load(checkpoint_path))

    model.eval()

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # pull all tensor batches required for testing
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_list_x = batch['norm_lng_list'].to(device)
        position_list_y = batch['norm_lat_list'].to(device)
        sent_position_ids = batch['sent_position_ids'].to(device)

        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x, position_list_y=position_list_y, labels=labels, pivot_len_list=pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=9)

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average=None)
    precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average='micro')

    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--max_token_len', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--num_workers', type=int, default=5)

    parser.add_argument('--num_neighbor_limit', type=int, default=None)

    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--prediction_save_dir', type=str, default=None)

    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if prediction_save_dir is given but does not exist yet, create it
    if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
        os.makedirs(args.prediction_save_dir)

    testing(args)


if __name__ == '__main__':
    main()
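The variable is called `mrr` because, with exactly one relevant label per sample, sklearn's label ranking average precision reduces to the mean reciprocal rank of the true class. A quick numeric check of that equivalence:

import numpy as np
from sklearn.metrics import label_ranking_average_precision_score

# two samples, three classes; one-hot ground truth
y_true = np.array([[0, 1, 0],
                   [1, 0, 0]])
# logits: true class ranked 1st for sample 0, 2nd for sample 1
y_score = np.array([[0.1, 0.8, 0.3],
                    [0.4, 0.5, 0.2]])

# reciprocal ranks are 1/1 and 1/2, so the mean is 0.75
print(label_ranking_average_precision_score(y_true, y_score))  # 0.75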
models/spabert/experiments/semantic_typing/src/test_cls_baseline.py
DELETED
@@ -1,189 +0,0 @@
import os
import sys
from tqdm import tqdm  # for our progress bar
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb

import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import torch.nn.functional as F

sys.path.append('../../../')
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from utils.baseline_utils import get_baseline_model
from models.baseline_typing_model import BaselineForSemanticTyping

from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_fscore_support

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


MODEL_OPTIONS = ['bert-base', 'bert-large', 'roberta-base', 'roberta-large',
                 'spanbert-base', 'spanbert-large', 'luke-base', 'luke-large',
                 'simcse-bert-base', 'simcse-bert-large', 'simcse-roberta-base', 'simcse-roberta-large']


def testing(args):

    num_workers = args.num_workers
    batch_size = args.batch_size
    max_token_len = args.max_token_len

    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone

    backbone_option = args.backbone_option

    checkpoint_path = args.checkpoint_path

    assert backbone_option in MODEL_OPTIONS

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    backbone_model, tokenizer = get_baseline_model(backbone_option)
    model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))

    model.load_state_dict(torch.load(checkpoint_path))  # loads sentence position embedding weights as well

    model.to(device)
    model.eval()  # the original called model.train() here, which leaves dropout active during testing

    london_dataset = PbfMapDataset(data_file_path=london_file_path,
                                   tokenizer=tokenizer,
                                   max_token_len=max_token_len,
                                   distance_norm_factor=distance_norm_factor,
                                   spatial_dist_fill=spatial_dist_fill,
                                   with_type=with_type,
                                   sep_between_neighbors=sep_between_neighbors,
                                   label_encoder=label_encoder,
                                   mode='test')

    california_dataset = PbfMapDataset(data_file_path=california_file_path,
                                       tokenizer=tokenizer,
                                       max_token_len=max_token_len,
                                       distance_norm_factor=distance_norm_factor,
                                       spatial_dist_fill=spatial_dist_fill,
                                       with_type=with_type,
                                       sep_between_neighbors=sep_between_neighbors,
                                       label_encoder=label_encoder,
                                       mode='test')

    test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])

    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # pull all tensor batches required for testing
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_ids = batch['sent_position_ids'].to(device)

        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                        labels=labels, pivot_len_list=pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=9)

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average=None)
    precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average='micro')
    print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
    print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
    print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
    print('supports:\n', supports)
    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--max_token_len', type=int, default=300)

    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--freeze_backbone', default=False, action='store_true')

    parser.add_argument('--backbone_option', type=str, default='bert-base')
    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    testing(args)


if __name__ == '__main__':
    main()
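Since every sample here has exactly one gold class and one predicted class, the micro-averaged precision, recall, and F1 printed above all collapse to plain accuracy. A quick numeric check:

import numpy as np
from sklearn.metrics import precision_recall_fscore_support

y_true = np.array([0, 1, 2, 2, 1])
y_pred = np.array([0, 2, 2, 2, 1])  # 4 of 5 correct -> accuracy 0.8

p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='micro')
print(p, r, f1)  # 0.8 0.8 0.8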
models/spabert/experiments/semantic_typing/src/test_cls_spatialbert.py
DELETED
@@ -1,214 +0,0 @@
import os
import sys
from transformers import RobertaTokenizer, BertTokenizer
from tqdm import tqdm  # for our progress bar
from transformers import AdamW

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

sys.path.append('../../../')
from models.spatial_bert_model import SpatialBertModel
from models.spatial_bert_model import SpatialBertConfig
from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from transformers.models.bert.modeling_bert import BertForMaskedLM

from sklearn.metrics import label_ranking_average_precision_score
from sklearn.metrics import precision_recall_fscore_support
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb


DEBUG = False

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


def testing(args):

    max_token_len = args.max_token_len
    batch_size = args.batch_size
    num_workers = args.num_workers
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    checkpoint_path = args.checkpoint_path
    if_no_spatial_distance = args.no_spatial_distance

    bert_option = args.bert_option

    if args.num_classes == 9:
        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
        TYPE_LIST = CLASS_9_LIST
        type_key_str = 'class'
    elif args.num_classes == 74:
        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
        TYPE_LIST = CLASS_74_LIST
        type_key_str = 'fine_class'
    else:
        raise NotImplementedError

    if bert_option == 'bert-base':
        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance, num_semantic_types=len(TYPE_LIST))
    elif bert_option == 'bert-large':
        tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
        config = SpatialBertConfig(use_spatial_distance_embedding=not if_no_spatial_distance, hidden_size=1024, intermediate_size=4096, num_attention_heads=16, num_hidden_layers=24, num_semantic_types=len(TYPE_LIST))
    else:
        raise NotImplementedError

    model = SpatialBertForSemanticTyping(config)

    label_encoder = LabelEncoder()
    label_encoder.fit(TYPE_LIST)

    london_dataset = PbfMapDataset(data_file_path=london_file_path,
                                   tokenizer=tokenizer,
                                   max_token_len=max_token_len,
                                   distance_norm_factor=distance_norm_factor,
                                   spatial_dist_fill=spatial_dist_fill,
                                   with_type=with_type,
                                   type_key_str=type_key_str,
                                   sep_between_neighbors=sep_between_neighbors,
                                   label_encoder=label_encoder,
                                   mode='test')

    california_dataset = PbfMapDataset(data_file_path=california_file_path,
                                       tokenizer=tokenizer,
                                       max_token_len=max_token_len,
                                       distance_norm_factor=distance_norm_factor,
                                       spatial_dist_fill=spatial_dist_fill,
                                       with_type=with_type,
                                       type_key_str=type_key_str,
                                       sep_between_neighbors=sep_between_neighbors,
                                       label_encoder=label_encoder,
                                       mode='test')

    test_dataset = torch.utils.data.ConcatDataset([london_dataset, california_dataset])

    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers,
                             shuffle=False, pin_memory=True, drop_last=False)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)

    model.load_state_dict(torch.load(checkpoint_path))

    model.eval()

    print('start testing...')

    # setup loop with TQDM and dataloader
    loop = tqdm(test_loader, leave=True)

    mrr_total = 0.
    sample_cnt = 0

    gt_list = []
    pred_list = []

    for batch in loop:
        # pull all tensor batches required for testing
        input_ids = batch['pseudo_sentence'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        position_list_x = batch['norm_lng_list'].to(device)
        position_list_y = batch['norm_lat_list'].to(device)
        sent_position_ids = batch['sent_position_ids'].to(device)

        labels = batch['pivot_type'].to(device)
        pivot_lens = batch['pivot_token_len'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask, sent_position_ids=sent_position_ids,
                        position_list_x=position_list_x, position_list_y=position_list_y, labels=labels, pivot_len_list=pivot_lens)

        onehot_labels = F.one_hot(labels, num_classes=len(TYPE_LIST))

        gt_list.extend(onehot_labels.cpu().detach().numpy())
        pred_list.extend(outputs.logits.cpu().detach().numpy())

        mrr = label_ranking_average_precision_score(onehot_labels.cpu().detach().numpy(), outputs.logits.cpu().detach().numpy())
        mrr_total += mrr * input_ids.shape[0]
        sample_cnt += input_ids.shape[0]

    precisions, recalls, fscores, supports = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average=None)
    precision, recall, f1, _ = precision_recall_fscore_support(np.argmax(np.array(gt_list), axis=1), np.argmax(np.array(pred_list), axis=1), average='micro')
    print('precisions:\n', ["{:.3f}".format(prec) for prec in precisions])
    print('recalls:\n', ["{:.3f}".format(rec) for rec in recalls])
    print('fscores:\n', ["{:.3f}".format(f1) for f1 in fscores])
    print('supports:\n', supports)
    print('micro P, micro R, micro F1', "{:.3f}".format(precision), "{:.3f}".format(recall), "{:.3f}".format(f1))


def main():

    parser = argparse.ArgumentParser()

    parser.add_argument('--max_token_len', type=int, default=512)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--num_workers', type=int, default=5)

    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=100)
    parser.add_argument('--num_classes', type=int, default=9)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
    parser.add_argument('--no_spatial_distance', default=False, action='store_true')

    parser.add_argument('--bert_option', type=str, default='bert-base')
    parser.add_argument('--prediction_save_dir', type=str, default=None)

    parser.add_argument('--checkpoint_path', type=str, default=None)

    args = parser.parse_args()
    print('\n')
    print(args)
    print('\n')

    # if prediction_save_dir is given but does not exist yet, create it
    if args.prediction_save_dir is not None and not os.path.isdir(args.prediction_save_dir):
        os.makedirs(args.prediction_save_dir)

    testing(args)


if __name__ == '__main__':
    main()
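One detail worth knowing when reusing these checkpoints: `LabelEncoder.fit` sorts the class list, so the integer ids depend only on the class names, not on the order of `CLASS_9_LIST` or `CLASS_74_LIST` (both defined in datasets/const.py, which is not shown in this diff). With hypothetical class names:

from sklearn.preprocessing import LabelEncoder

# hypothetical class names; the real lists live in datasets/const.py
label_encoder = LabelEncoder()
label_encoder.fit(['transportation', 'education', 'healthcare'])

print(list(label_encoder.classes_))            # ['education', 'healthcare', 'transportation']
print(label_encoder.transform(['education']))  # [0] -- ids follow sorted order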
models/spabert/experiments/semantic_typing/src/train_cls_baseline.py
DELETED
@@ -1,227 +0,0 @@
import os
import sys
from tqdm import tqdm  # for our progress bar
import numpy as np
import argparse
from sklearn.preprocessing import LabelEncoder
import pdb

import torch
from torch.utils.data import DataLoader
from transformers import AdamW

sys.path.append('../../../')
from datasets.osm_sample_loader import PbfMapDataset
from datasets.const import *
from utils.baseline_utils import get_baseline_model
from models.baseline_typing_model import BaselineForSemanticTyping


MODEL_OPTIONS = ['bert-base', 'bert-large', 'roberta-base', 'roberta-large',
                 'spanbert-base', 'spanbert-large', 'luke-base', 'luke-large',
                 'simcse-bert-base', 'simcse-bert-large', 'simcse-roberta-base', 'simcse-roberta-large']


def training(args):

    num_workers = args.num_workers
    batch_size = args.batch_size
    epochs = args.epochs
    lr = args.lr
    save_interval = args.save_interval
    max_token_len = args.max_token_len
    distance_norm_factor = args.distance_norm_factor
    spatial_dist_fill = args.spatial_dist_fill
    with_type = args.with_type
    sep_between_neighbors = args.sep_between_neighbors
    freeze_backbone = args.freeze_backbone

    backbone_option = args.backbone_option

    assert backbone_option in MODEL_OPTIONS

    london_file_path = '../data/sql_output/osm-point-london-typing.json'
    california_file_path = '../data/sql_output/osm-point-california-typing.json'

    if args.model_save_dir is None:
        freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
        sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
        model_save_dir = '/data2/zekun/spatial_bert_baseline_weights/typing_lr' + str("{:.0e}".format(lr)) + '_' + backbone_option + freeze_pathstr + sep_pathstr + '_london_california_bsize' + str(batch_size)

        if not os.path.isdir(model_save_dir):
            os.makedirs(model_save_dir)
    else:
        model_save_dir = args.model_save_dir

    print('model_save_dir', model_save_dir)
    print('\n')

    label_encoder = LabelEncoder()
    label_encoder.fit(CLASS_9_LIST)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    backbone_model, tokenizer = get_baseline_model(backbone_option)
    model = BaselineForSemanticTyping(backbone_model, backbone_model.config.hidden_size, len(CLASS_9_LIST))

    model.to(device)
    model.train()

    london_train_val_dataset = PbfMapDataset(data_file_path=london_file_path,
                                             tokenizer=tokenizer,
                                             max_token_len=max_token_len,
                                             distance_norm_factor=distance_norm_factor,
                                             spatial_dist_fill=spatial_dist_fill,
                                             with_type=with_type,
                                             sep_between_neighbors=sep_between_neighbors,
                                             label_encoder=label_encoder,
                                             mode='train')

    # 80/20 train/validation split for each region
    percent_80 = int(len(london_train_val_dataset) * 0.8)
    london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])

    california_train_val_dataset = PbfMapDataset(data_file_path=california_file_path,
                                                 tokenizer=tokenizer,
                                                 max_token_len=max_token_len,
                                                 distance_norm_factor=distance_norm_factor,
                                                 spatial_dist_fill=spatial_dist_fill,
                                                 with_type=with_type,
                                                 sep_between_neighbors=sep_between_neighbors,
                                                 label_encoder=label_encoder,
                                                 mode='train')
    percent_80 = int(len(california_train_val_dataset) * 0.8)
    california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])

    train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
    val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=num_workers,
                              shuffle=True, pin_memory=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=num_workers,
                            shuffle=False, pin_memory=True, drop_last=False)

    # initialize optimizer
    optim = AdamW(model.parameters(), lr=lr)

    print('start training...')

    for epoch in range(epochs):
        # setup loop with TQDM and dataloader
        loop = tqdm(train_loader, leave=True)
        iter_idx = 0
        for batch in loop:
            # initialize calculated gradients (from prev step)
            optim.zero_grad()
            # pull all tensor batches required for training
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                            labels=labels, pivot_len_list=pivot_lens)

            loss = outputs.loss
            loss.backward()
            optim.step()

            loop.set_description(f'Epoch {epoch}')
            loop.set_postfix({'loss': loss.item()})

            iter_idx += 1

            if iter_idx % save_interval == 0 or iter_idx == loop.total:
                loss_valid = validating(val_loader, model, device)
                print('validation loss', loss_valid)

                save_path = os.path.join(model_save_dir, 'ep' + str(epoch) + '_iter' + str(iter_idx).zfill(5)
                                         + '_' + str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) + '.pth')

                torch.save(model.state_dict(), save_path)
                print('saving model checkpoint to', save_path)


def validating(val_loader, model, device):

    with torch.no_grad():

        loss_valid = 0
        loop = tqdm(val_loader, leave=True)

        for batch in loop:
            input_ids = batch['pseudo_sentence'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            position_ids = batch['sent_position_ids'].to(device)

            labels = batch['pivot_type'].to(device)
            pivot_lens = batch['pivot_token_len'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, position_ids=position_ids,
                            labels=labels, pivot_len_list=pivot_lens)

            loss_valid += outputs.loss

        loss_valid /= len(val_loader)

        return loss_valid


def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_workers', type=int, default=5)
    parser.add_argument('--batch_size', type=int, default=12)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_interval', type=int, default=2000)
    parser.add_argument('--max_token_len', type=int, default=300)

    parser.add_argument('--lr', type=float, default=5e-5)
    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
    parser.add_argument('--spatial_dist_fill', type=float, default=20)

    parser.add_argument('--with_type', default=False, action='store_true')
    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
|
202 |
-
parser.add_argument('--freeze_backbone', default=False, action='store_true')
|
203 |
-
|
204 |
-
parser.add_argument('--backbone_option', type=str, default='bert-base')
|
205 |
-
parser.add_argument('--model_save_dir', type=str, default=None)
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
args = parser.parse_args()
|
210 |
-
print('\n')
|
211 |
-
print(args)
|
212 |
-
print('\n')
|
213 |
-
|
214 |
-
|
215 |
-
# out_dir not None, and out_dir does not exist, then create out_dir
|
216 |
-
if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
|
217 |
-
os.makedirs(args.model_save_dir)
|
218 |
-
|
219 |
-
training(args)
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
if __name__ == '__main__':
|
224 |
-
|
225 |
-
main()
|
226 |
-
|
227 |
-
|
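Note that the 80/20 split in the deleted trainer above calls torch.utils.data.random_split without an explicit generator, so the train/validation partition changes on every run and validation losses from different runs are not strictly comparable. A minimal, self-contained sketch of a seeded split follows; the stand-in TensorDataset is an assumption for illustration, taking the place of PbfMapDataset:

import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset for illustration; the deleted script uses PbfMapDataset here.
dataset = TensorDataset(torch.arange(100))

n_train = int(len(dataset) * 0.8)
train_set, val_set = random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(42),  # fixed seed -> identical split every run
)
print(len(train_set), len(val_set))  # 80 20

Passing a dedicated generator keeps the partition fixed without perturbing the global RNG used elsewhere during training.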
models/spabert/experiments/semantic_typing/src/train_cls_spatialbert.py
DELETED
@@ -1,276 +0,0 @@
-import os
-import sys
-from transformers import RobertaTokenizer, BertTokenizer
-from tqdm import tqdm  # for our progress bar
-from transformers import AdamW
-
-import torch
-from torch.utils.data import DataLoader
-
-sys.path.append('../../../')
-from models.spatial_bert_model import SpatialBertModel
-from models.spatial_bert_model import SpatialBertConfig
-from models.spatial_bert_model import SpatialBertForMaskedLM, SpatialBertForSemanticTyping
-from datasets.osm_sample_loader import PbfMapDataset
-from datasets.const import *
-
-from transformers.models.bert.modeling_bert import BertForMaskedLM
-
-import numpy as np
-import argparse
-from sklearn.preprocessing import LabelEncoder
-import pdb
-
-
-DEBUG = False
-
-
-def training(args):
-
-    num_workers = args.num_workers
-    batch_size = args.batch_size
-    epochs = args.epochs
-    lr = args.lr  # 1e-7 # 5e-5
-    save_interval = args.save_interval
-    max_token_len = args.max_token_len
-    distance_norm_factor = args.distance_norm_factor
-    spatial_dist_fill = args.spatial_dist_fill
-    with_type = args.with_type
-    sep_between_neighbors = args.sep_between_neighbors
-    freeze_backbone = args.freeze_backbone
-    mlm_checkpoint_path = args.mlm_checkpoint_path
-
-    if_no_spatial_distance = args.no_spatial_distance
-
-    bert_option = args.bert_option
-
-    assert bert_option in ['bert-base', 'bert-large']
-
-    if args.num_classes == 9:
-        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing.json'
-        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing.json'
-        TYPE_LIST = CLASS_9_LIST
-        type_key_str = 'class'
-    elif args.num_classes == 74:
-        london_file_path = '../../semantic_typing/data/sql_output/osm-point-london-typing-ranking.json'
-        california_file_path = '../../semantic_typing/data/sql_output/osm-point-california-typing-ranking.json'
-        TYPE_LIST = CLASS_74_LIST
-        type_key_str = 'fine_class'
-    else:
-        raise NotImplementedError
-
-    if args.model_save_dir is None:
-        checkpoint_basename = os.path.basename(mlm_checkpoint_path)
-        checkpoint_prefix = checkpoint_basename.replace("mlm_mem_keeppos_", "").strip('.pth')
-
-        sep_pathstr = '_sep' if sep_between_neighbors else '_nosep'
-        freeze_pathstr = '_freeze' if freeze_backbone else '_nofreeze'
-        if if_no_spatial_distance:
-            model_save_dir = '/data2/zekun/spatial_bert_weights_ablation/'
-        else:
-            model_save_dir = '/data2/zekun/spatial_bert_weights/'
-        model_save_dir = os.path.join(model_save_dir, 'typing_lr' + str("{:.0e}".format(lr)) + sep_pathstr + '_' + bert_option + freeze_pathstr + '_london_california_bsize' + str(batch_size))
-        model_save_dir = os.path.join(model_save_dir, checkpoint_prefix)
-
-        if not os.path.isdir(model_save_dir):
-            os.makedirs(model_save_dir)
-    else:
-        model_save_dir = args.model_save_dir
-
-    print('model_save_dir', model_save_dir)
-    print('\n')
-
-    if bert_option == 'bert-base':
-        tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
-        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, num_semantic_types = len(TYPE_LIST))
-    elif bert_option == 'bert-large':
-        tokenizer = BertTokenizer.from_pretrained("bert-large-uncased")
-        config = SpatialBertConfig(use_spatial_distance_embedding = not if_no_spatial_distance, hidden_size = 1024, intermediate_size = 4096, num_attention_heads = 16, num_hidden_layers = 24, num_semantic_types = len(TYPE_LIST))
-    else:
-        raise NotImplementedError
-
-    label_encoder = LabelEncoder()
-    label_encoder.fit(TYPE_LIST)
-
-    london_train_val_dataset = PbfMapDataset(data_file_path = london_file_path,
-                                             tokenizer = tokenizer,
-                                             max_token_len = max_token_len,
-                                             distance_norm_factor = distance_norm_factor,
-                                             spatial_dist_fill = spatial_dist_fill,
-                                             with_type = with_type,
-                                             type_key_str = type_key_str,
-                                             sep_between_neighbors = sep_between_neighbors,
-                                             label_encoder = label_encoder,
-                                             mode = 'train')
-
-    percent_80 = int(len(london_train_val_dataset) * 0.8)
-    london_train_dataset, london_val_dataset = torch.utils.data.random_split(london_train_val_dataset, [percent_80, len(london_train_val_dataset) - percent_80])
-
-    california_train_val_dataset = PbfMapDataset(data_file_path = california_file_path,
-                                                 tokenizer = tokenizer,
-                                                 max_token_len = max_token_len,
-                                                 distance_norm_factor = distance_norm_factor,
-                                                 spatial_dist_fill = spatial_dist_fill,
-                                                 with_type = with_type,
-                                                 type_key_str = type_key_str,
-                                                 sep_between_neighbors = sep_between_neighbors,
-                                                 label_encoder = label_encoder,
-                                                 mode = 'train')
-    percent_80 = int(len(california_train_val_dataset) * 0.8)
-    california_train_dataset, california_val_dataset = torch.utils.data.random_split(california_train_val_dataset, [percent_80, len(california_train_val_dataset) - percent_80])
-
-    train_dataset = torch.utils.data.ConcatDataset([london_train_dataset, california_train_dataset])
-    val_dataset = torch.utils.data.ConcatDataset([london_val_dataset, california_val_dataset])
-
-    if DEBUG:
-        train_loader = DataLoader(train_dataset, batch_size = batch_size, num_workers = num_workers,
-                                  shuffle = False, pin_memory = True, drop_last = True)
-        val_loader = DataLoader(val_dataset, batch_size = batch_size, num_workers = num_workers,
-                                shuffle = False, pin_memory = True, drop_last = False)
-    else:
-        train_loader = DataLoader(train_dataset, batch_size = batch_size, num_workers = num_workers,
-                                  shuffle = True, pin_memory = True, drop_last = True)
-        val_loader = DataLoader(val_dataset, batch_size = batch_size, num_workers = num_workers,
-                                shuffle = False, pin_memory = True, drop_last = False)
-
-    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
-
-    model = SpatialBertForSemanticTyping(config)
-    model.to(device)
-
-    model.load_state_dict(torch.load(mlm_checkpoint_path), strict = False)
-
-    model.train()
-
-    # initialize optimizer
-    optim = AdamW(model.parameters(), lr = lr)
-
-    print('start training...')
-
-    for epoch in range(epochs):
-        # set up loop with tqdm and dataloader
-        loop = tqdm(train_loader, leave=True)
-        iter = 0
-        for batch in loop:
-            # initialize calculated gradients (from the previous step)
-            optim.zero_grad()
-            # pull all tensor batches required for training
-            input_ids = batch['pseudo_sentence'].to(device)
-            attention_mask = batch['attention_mask'].to(device)
-            position_list_x = batch['norm_lng_list'].to(device)
-            position_list_y = batch['norm_lat_list'].to(device)
-            sent_position_ids = batch['sent_position_ids'].to(device)
-
-            #labels = batch['pseudo_sentence'].to(device)
-            labels = batch['pivot_type'].to(device)
-            pivot_lens = batch['pivot_token_len'].to(device)
-
-            outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
-                            position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
-
-            loss = outputs.loss
-            loss.backward()
-            optim.step()
-
-            loop.set_description(f'Epoch {epoch}')
-            loop.set_postfix({'loss': loss.item()})
-
-            if DEBUG:
-                print('ep' + str(epoch) + '_iter' + str(iter).zfill(5), loss.item())
-
-            iter += 1
-
-            if iter % save_interval == 0 or iter == loop.total:
-                loss_valid = validating(val_loader, model, device)
-
-                save_path = os.path.join(model_save_dir, 'keeppos_ep' + str(epoch) + '_iter' + str(iter).zfill(5)
-                                         + '_' + str("{:.4f}".format(loss.item())) + '_val' + str("{:.4f}".format(loss_valid)) + '.pth')
-
-                torch.save(model.state_dict(), save_path)
-                print('validation loss', loss_valid)
-                print('saving model checkpoint to', save_path)
-
-
-def validating(val_loader, model, device):
-
-    with torch.no_grad():
-
-        loss_valid = 0
-        loop = tqdm(val_loader, leave=True)
-
-        for batch in loop:
-            input_ids = batch['pseudo_sentence'].to(device)
-            attention_mask = batch['attention_mask'].to(device)
-            position_list_x = batch['norm_lng_list'].to(device)
-            position_list_y = batch['norm_lat_list'].to(device)
-            sent_position_ids = batch['sent_position_ids'].to(device)
-
-            labels = batch['pivot_type'].to(device)
-            pivot_lens = batch['pivot_token_len'].to(device)
-
-            outputs = model(input_ids, attention_mask = attention_mask, sent_position_ids = sent_position_ids,
-                            position_list_x = position_list_x, position_list_y = position_list_y, labels = labels, pivot_len_list = pivot_lens)
-
-            loss_valid += outputs.loss
-
-        loss_valid /= len(val_loader)
-
-        return loss_valid
-
-
-def main():
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--num_workers', type=int, default=5)
-    parser.add_argument('--batch_size', type=int, default=12)
-    parser.add_argument('--epochs', type=int, default=10)
-    parser.add_argument('--save_interval', type=int, default=2000)
-    parser.add_argument('--max_token_len', type=int, default=512)
-
-    parser.add_argument('--lr', type=float, default=5e-5)
-    parser.add_argument('--distance_norm_factor', type=float, default=0.0001)
-    parser.add_argument('--spatial_dist_fill', type=float, default=100)
-    parser.add_argument('--num_classes', type=int, default=9)
-
-    parser.add_argument('--with_type', default=False, action='store_true')
-    parser.add_argument('--sep_between_neighbors', default=False, action='store_true')
-    parser.add_argument('--freeze_backbone', default=False, action='store_true')
-    parser.add_argument('--no_spatial_distance', default=False, action='store_true')
-
-    parser.add_argument('--bert_option', type=str, default='bert-base')
-    parser.add_argument('--model_save_dir', type=str, default=None)
-
-    parser.add_argument('--mlm_checkpoint_path', type=str, default=None)
-
-    args = parser.parse_args()
-    print('\n')
-    print(args)
-    print('\n')
-
-    # if model_save_dir is given but does not exist yet, create it
-    if args.model_save_dir is not None and not os.path.isdir(args.model_save_dir):
-        os.makedirs(args.model_save_dir)
-
-    training(args)
-
-
-if __name__ == '__main__':
-
-    main()
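Two remarks on the trainer above. First, --freeze_backbone only influences the checkpoint directory name; no parameters are actually frozen before fine-tuning. A hedged sketch of what freezing the encoder could look like; the attribute name 'bert' is an assumption for illustration, not taken from SpatialBertForSemanticTyping:

import torch
from torch.optim import AdamW

def freeze_encoder(model: torch.nn.Module, encoder_attr: str = 'bert') -> None:
    # Disable gradients on the backbone so only the typing head is updated.
    # `encoder_attr` is a hypothetical attribute name; the real one depends on the model class.
    for param in getattr(model, encoder_attr).parameters():
        param.requires_grad = False

# The optimizer would then be built over the still-trainable parameters only, e.g.:
# optim = AdamW([p for p in model.parameters() if p.requires_grad], lr=5e-5)

Second, checkpoint_basename.replace("mlm_mem_keeppos_", "").strip('.pth') strips any of the characters '.', 'p', 't', 'h' from both ends of the string rather than removing the '.pth' suffix; os.path.splitext(checkpoint_basename)[0] would be the safe spelling. Note also that transformers.AdamW has been deprecated in recent releases of the library in favor of torch.optim.AdamW.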