Oisin Mac Aodha
commited on
Commit
·
505e401
1
Parent(s):
6570723
First model version
Browse files- LICENSE +21 -0
- README.md +73 -12
- app.py +180 -0
- data/masks/ocean_mask.npy +3 -0
- datasets.py +194 -0
- eval.py +362 -0
- images/sinr_traverse.gif +0 -0
- losses.py +146 -0
- models.py +85 -0
- paths.json +9 -0
- pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt +3 -0
- pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt +3 -0
- pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt +3 -0
- pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt +3 -0
- requirements.txt +6 -0
- setup.py +91 -0
- taxa_02_08_2023_names.txt +0 -0
- utils.py +143 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 Elijah Cole
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,73 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Spatial Implicit Neural Representations for Global-Scale Species Mapping - ICML 2023
|
2 |
+
|
3 |
+
Code for training and evaluating global-scale species range estimation models. This code enables the recreation of the results from our ICML 2023 paper [Spatial Implicit Neural Representations for Global-Scale Species Mapping](https://arxiv.org/abs/2306.02564).
|
4 |
+
|
5 |
+
## 🌍 Overview
|
6 |
+
Estimating the geographical range of a species from sparse observations is a challenging and important geospatial prediction problem. Given a set of locations where a species has been observed, the goal is to build a model to predict whether the species is present or absent at any location. In this work, we use Spatial Implicit Neural Representations (SINRs) to jointly estimate the geographical range of thousands of species simultaneously. SINRs scale gracefully, making better predictions as we increase the number of training species and the amount of training data per species. We introduce four new range estimation and spatial representation learning benchmarks, and we use them to demonstrate that noisy and biased crowdsourced data can be combined with implicit neural representations to approximate expert-developed range maps for many species.
|
7 |
+
|
8 |
+

|
9 |
+
<sup>Above we visualize predictions from one of our SINR models trained on data from [iNaturalist](inaturalist.org). On the left we show the learned species embedding space, where each point represents a different species. On the right we see the predicted range of the species corresponding to the red dot on the left.<sup>
|
10 |
+
|
11 |
+
## 🔍 Getting Started
|
12 |
+
|
13 |
+
#### Installing Required Packages
|
14 |
+
|
15 |
+
1. We recommend using an isolated Python environment to avoid dependency issues. Install the Anaconda Python 3.9 distribution for your operating system from [here](https://www.anaconda.com/download).
|
16 |
+
|
17 |
+
2. Create a new environment and activate it:
|
18 |
+
```bash
|
19 |
+
conda create -y --name sinr_icml python==3.9
|
20 |
+
conda activate sinr_icml
|
21 |
+
```
|
22 |
+
|
23 |
+
3. After activating the environment, install the required packages:
|
24 |
+
```bash
|
25 |
+
pip3 install -r requirements.txt
|
26 |
+
```
|
27 |
+
|
28 |
+
#### Data Download and Preparation
|
29 |
+
Instructions for downloading the data in `data/README.md`.
|
30 |
+
|
31 |
+
## 🗺️ Generating Predictions
|
32 |
+
To generate predictions for a model in the form of an image, run the following command:
|
33 |
+
```bash
|
34 |
+
python viz_map.py --taxa_id 130714
|
35 |
+
```
|
36 |
+
Here, `--taxa_id` is the id number for a species of interest from [iNaturalist](https://www.inaturalist.org/taxa/130714). If you want to generate predictions for a random species, add the `--rand_taxa` instead.
|
37 |
+
|
38 |
+
Note, before you run this command you need to first download the data as described in `data/README.md`. In addition, if you want to evaluate some of the pretrained models from the paper, you need to download those first and place them at `sinr/pretrained_models`. See `web_app/README.md` for more details.
|
39 |
+
|
40 |
+
There is also an interactive browser-based demo available in `web_app`.
|
41 |
+
|
42 |
+
## 🚅 Training and Evaluating Models
|
43 |
+
|
44 |
+
To train and evaluate a model, run the following command:
|
45 |
+
```bash
|
46 |
+
python train_and_evaluate_models.py
|
47 |
+
```
|
48 |
+
|
49 |
+
#### Hyperparameters
|
50 |
+
Common parameters of interest can be set within `train_and_evaluate_models.py`. All other parameters are exposed in `setup.py`.
|
51 |
+
|
52 |
+
#### Outputs
|
53 |
+
By default, trained models and evaluation results will be saved to a folder in the `experiments` directory. Evaluation results will also be printed to the command line.
|
54 |
+
|
55 |
+
#### Interactive Model Visualizer
|
56 |
+
To visualize range predictions from pretrained SINR models, please follow the instructions in `web_app/README.md`.
|
57 |
+
|
58 |
+
## 🙏 Acknowledgements
|
59 |
+
This project was enabled by data from the Cornell Lab of Ornithology, The International Union for the Conservation of Nature, iNaturalist, NASA, USGS, JAXA, CIESIN, and UC Merced. We are especially indebted to the [iNaturalist](inaturalist.org) and [eBird](https://ebird.org) communities for their data collection efforts. We also thank Matt Stimas-Mackey and Sam Heinrich for their help with data curation. This project was funded by the [Climate Change AI Innovation Grants](https://www.climatechange.ai/blog/2022-04-13-innovation-grants) program, hosted by Climate Change AI with the support of the Quadrature Climate Foundation, Schmidt Futures, and the Canada Hub of Future Earth. This work was also supported by the Caltech Resnick Sustainability Institute and an NSF Graduate Research Fellowship (grant number DGE1745301).
|
60 |
+
|
61 |
+
If you find our work useful in your research please consider citing our paper.
|
62 |
+
```
|
63 |
+
@inproceedings{SINR_icml23,
|
64 |
+
title = {{Spatial Implicit Neural Representations for Global-Scale Species Mapping}},
|
65 |
+
author = {Cole, Elijah and Van Horn, Grant and Lange, Christian and Shepard, Alexander and Leary, Patrick and Perona, Pietro and Loarie, Scott and Mac Aodha, Oisin},
|
66 |
+
booktitle = {ICML},
|
67 |
+
year = {2023}
|
68 |
+
}
|
69 |
+
```
|
70 |
+
|
71 |
+
## 📜 Disclaimer
|
72 |
+
Extreme care should be taken before making any decisions based on the outputs of models presented here. Our goal in this work is to demonstrate the promise of large-scale representation learning for species range estimation, not to provide definitive range maps. Our models are trained on biased data and have not been calibrated or validated beyond the experiments illustrated in the paper.
|
73 |
+
|
app.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
matplotlib.use('Agg')
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
|
10 |
+
import utils
|
11 |
+
import models
|
12 |
+
import datasets
|
13 |
+
|
14 |
+
|
15 |
+
def load_taxa_metadata(file_path):
|
16 |
+
taxa_names_file = open(file_path, "r")
|
17 |
+
data = taxa_names_file.read().split("\n")
|
18 |
+
data = [dd for dd in data if dd != '']
|
19 |
+
taxa_ids = []
|
20 |
+
taxa_names = []
|
21 |
+
for tt in range(len(data)):
|
22 |
+
id, nm = data[tt].split('\t')
|
23 |
+
taxa_ids.append(int(id))
|
24 |
+
taxa_names.append(nm)
|
25 |
+
taxa_names_file.close()
|
26 |
+
return dict(zip(taxa_ids, taxa_names))
|
27 |
+
|
28 |
+
|
29 |
+
def generate_prediction(taxa_id, selected_model, settings, threshold):
|
30 |
+
|
31 |
+
# select the model to use
|
32 |
+
if selected_model == 'AN_FULL max 10':
|
33 |
+
model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt'
|
34 |
+
elif selected_model == 'AN_FULL max 100':
|
35 |
+
model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt'
|
36 |
+
elif selected_model == 'AN_FULL max 1000':
|
37 |
+
model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt'
|
38 |
+
elif selected_model == 'Distilled env model':
|
39 |
+
model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt'
|
40 |
+
|
41 |
+
# load params
|
42 |
+
with open('paths.json', 'r') as f:
|
43 |
+
paths = json.load(f)
|
44 |
+
|
45 |
+
# configs
|
46 |
+
eval_params = {}
|
47 |
+
eval_params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
48 |
+
eval_params['model_path'] = model_path
|
49 |
+
eval_params['taxa_id'] = int(taxa_id)
|
50 |
+
eval_params['rand_taxa'] = 'Random taxa' in settings
|
51 |
+
eval_params['set_max_cmap_to_1'] = False
|
52 |
+
eval_params['disable_ocean_mask'] = 'Distilled env model' in settings
|
53 |
+
eval_params['threshold'] = threshold if 'Threshold' in settings else -1.0
|
54 |
+
|
55 |
+
# load model
|
56 |
+
train_params = torch.load(eval_params['model_path'], map_location='cpu')
|
57 |
+
model = models.get_model(train_params['params'])
|
58 |
+
model.load_state_dict(train_params['state_dict'], strict=True)
|
59 |
+
model = model.to(eval_params['device'])
|
60 |
+
model.eval()
|
61 |
+
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
|
62 |
+
raster = datasets.load_env(norm=train_params['params']['env_norm'])
|
63 |
+
else:
|
64 |
+
raster = None
|
65 |
+
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)
|
66 |
+
|
67 |
+
# user specified random taxa
|
68 |
+
if eval_params['rand_taxa']:
|
69 |
+
print('Selecting random taxa')
|
70 |
+
eval_params['taxa_id'] = np.random.choice(train_params['params']['class_to_taxa'])
|
71 |
+
|
72 |
+
# load taxa of interest
|
73 |
+
if eval_params['taxa_id'] in train_params['params']['class_to_taxa']:
|
74 |
+
class_of_interest = train_params['params']['class_to_taxa'].index(eval_params['taxa_id'])
|
75 |
+
else:
|
76 |
+
print(f'Error: Taxa specified that is not in the model: {eval_params["taxa_id"]}')
|
77 |
+
fig = plt.figure()
|
78 |
+
plt.imshow(np.zeros((1,1)), vmin=0, vmax=1.0, cmap=plt.cm.plasma)
|
79 |
+
plt.axis('off')
|
80 |
+
plt.tight_layout()
|
81 |
+
op_html = f'<h2><a href="https://www.inaturalist.org/taxa/{eval_params["taxa_id"]}" target="_blank">{eval_params["taxa_id"]}</a></h2> Error: specified taxa is not in the model.'
|
82 |
+
|
83 |
+
return op_html, fig, eval_params['taxa_id']
|
84 |
+
print(f'Loading taxa: {eval_params["taxa_id"]}')
|
85 |
+
|
86 |
+
# load ocean mask
|
87 |
+
mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
|
88 |
+
mask_inds = np.where(mask.reshape(-1) == 1)[0]
|
89 |
+
|
90 |
+
# generate input features
|
91 |
+
locs = utils.coord_grid(mask.shape)
|
92 |
+
if not eval_params['disable_ocean_mask']:
|
93 |
+
locs = locs[mask_inds, :]
|
94 |
+
locs = torch.from_numpy(locs)
|
95 |
+
locs_enc = enc.encode(locs).to(eval_params['device'])
|
96 |
+
|
97 |
+
# make prediction
|
98 |
+
with torch.no_grad():
|
99 |
+
preds = model(locs_enc, return_feats=False, class_of_interest=class_of_interest).cpu().numpy()
|
100 |
+
|
101 |
+
# threshold predictions
|
102 |
+
if eval_params['threshold'] > 0:
|
103 |
+
print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
|
104 |
+
preds[preds<eval_params['threshold']] = 0.0
|
105 |
+
preds[preds>=eval_params['threshold']] = 1.0
|
106 |
+
|
107 |
+
# mask data
|
108 |
+
if not eval_params['disable_ocean_mask']:
|
109 |
+
op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan # set to NaN
|
110 |
+
op_im[mask_inds] = preds
|
111 |
+
else:
|
112 |
+
op_im = preds
|
113 |
+
|
114 |
+
# reshape and create masked array for visualization
|
115 |
+
op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
|
116 |
+
op_im = np.ma.masked_invalid(op_im)
|
117 |
+
if eval_params['set_max_cmap_to_1']:
|
118 |
+
vmax = 1.0
|
119 |
+
else:
|
120 |
+
vmax = np.max(op_im)
|
121 |
+
|
122 |
+
# set color for masked values
|
123 |
+
cmap = plt.cm.plasma
|
124 |
+
cmap.set_bad(color='none')
|
125 |
+
|
126 |
+
plt.rcParams['figure.figsize'] = 24,12
|
127 |
+
fig = plt.figure()
|
128 |
+
plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap)
|
129 |
+
plt.axis('off')
|
130 |
+
plt.tight_layout()
|
131 |
+
|
132 |
+
# generate html for ouput display
|
133 |
+
taxa_name_str = taxa_names[eval_params['taxa_id']]
|
134 |
+
op_html = f'<h2><a href="https://www.inaturalist.org/taxa/{eval_params["taxa_id"]}" target="_blank">{taxa_name_str}</a></h2> (click for more info)'
|
135 |
+
return op_html, fig, gr.Number.update(value=eval_params['taxa_id'])
|
136 |
+
|
137 |
+
|
138 |
+
# load metadata
|
139 |
+
taxa_names = load_taxa_metadata('taxa_02_08_2023_names.txt')
|
140 |
+
|
141 |
+
|
142 |
+
with gr.Blocks(title="SINR Demo") as demo:
|
143 |
+
top_text = "Visualization code to explore species range predictions "\
|
144 |
+
"from Spatial Implicit Neural Representation (SINR) models from "\
|
145 |
+
"[our](https://arxiv.org/abs/2306.02564) ICML 2023 paper."
|
146 |
+
gr.Markdown("# SINR Visualization Demo")
|
147 |
+
gr.Markdown(top_text)
|
148 |
+
|
149 |
+
with gr.Row():
|
150 |
+
selected_taxa = gr.Number(label="Taxa ID", value=130714)
|
151 |
+
select_model = gr.Dropdown(["AN_FULL max 10", "AN_FULL max 100", "AN_FULL max 1000", "Distilled env model"],
|
152 |
+
value="AN_FULL max 1000", label="Model")
|
153 |
+
with gr.Row():
|
154 |
+
settings = gr.CheckboxGroup(["Random taxa", "Disable ocean mask", "Threshold"], label="Settings")
|
155 |
+
threshold = gr.Slider(0, 1, 0, label="Threshold")
|
156 |
+
|
157 |
+
with gr.Row():
|
158 |
+
submit_button = gr.Button("Run Model")
|
159 |
+
|
160 |
+
with gr.Row():
|
161 |
+
output_text = gr.HTML(label="Species Name:")
|
162 |
+
|
163 |
+
with gr.Row():
|
164 |
+
output_image = gr.Plot(label="Predicted occupancy")
|
165 |
+
|
166 |
+
end_text = "**Note:** Extreme care should be taken before making any decisions "\
|
167 |
+
"based on the outputs of models presented here. "\
|
168 |
+
"The goal of this work is to demonstrate the promise of large-scale "\
|
169 |
+
"representation learning for species range estimation. "\
|
170 |
+
"Our models are trained on biased data and have not been calibrated "\
|
171 |
+
"or validated beyondthe experiments illustrated in the paper."
|
172 |
+
gr.Markdown(end_text)
|
173 |
+
|
174 |
+
submit_button.click(
|
175 |
+
fn = generate_prediction,
|
176 |
+
inputs=[selected_taxa, select_model, settings, threshold],
|
177 |
+
outputs=[output_text, output_image, selected_taxa]
|
178 |
+
)
|
179 |
+
|
180 |
+
demo.launch()
|
data/masks/ocean_mask.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c41395204ddb6327089c27cc5bf14083d7d74176cf419ddf62551eb681c54974
|
3 |
+
size 2008136
|
datasets.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import json
|
4 |
+
import pandas as pd
|
5 |
+
from calendar import monthrange
|
6 |
+
import torch
|
7 |
+
import utils
|
8 |
+
|
9 |
+
class LocationDataset(torch.utils.data.Dataset):
|
10 |
+
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device):
|
11 |
+
|
12 |
+
# handle input encoding:
|
13 |
+
self.input_enc = input_enc
|
14 |
+
if self.input_enc in ['env', 'sin_cos_env']:
|
15 |
+
raster = load_env()
|
16 |
+
else:
|
17 |
+
raster = None
|
18 |
+
self.enc = utils.CoordEncoder(input_enc, raster)
|
19 |
+
|
20 |
+
# define some properties:
|
21 |
+
self.locs = locs
|
22 |
+
self.loc_feats = self.enc.encode(self.locs)
|
23 |
+
self.labels = labels
|
24 |
+
self.classes = classes
|
25 |
+
self.class_to_taxa = class_to_taxa
|
26 |
+
|
27 |
+
# useful numbers:
|
28 |
+
self.num_classes = len(np.unique(labels))
|
29 |
+
self.input_dim = self.loc_feats.shape[1]
|
30 |
+
|
31 |
+
if self.enc.raster is not None:
|
32 |
+
self.enc.raster = self.enc.raster.to(device)
|
33 |
+
|
34 |
+
def __len__(self):
|
35 |
+
return self.loc_feats.shape[0]
|
36 |
+
|
37 |
+
def __getitem__(self, index):
|
38 |
+
loc_feat = self.loc_feats[index, :]
|
39 |
+
loc = self.locs[index, :]
|
40 |
+
class_id = self.labels[index]
|
41 |
+
return loc_feat, loc, class_id
|
42 |
+
|
43 |
+
def load_env():
|
44 |
+
with open('paths.json', 'r') as f:
|
45 |
+
paths = json.load(f)
|
46 |
+
raster = load_context_feats(os.path.join(paths['env'],'bioclim_elevation_scaled.npy'))
|
47 |
+
return raster
|
48 |
+
|
49 |
+
def load_context_feats(data_path):
|
50 |
+
context_feats = np.load(data_path).astype(np.float32)
|
51 |
+
context_feats = torch.from_numpy(context_feats)
|
52 |
+
return context_feats
|
53 |
+
|
54 |
+
def load_inat_data(ip_file, taxa_of_interest=None):
|
55 |
+
|
56 |
+
print('\nLoading ' + ip_file)
|
57 |
+
data = pd.read_csv(ip_file)
|
58 |
+
|
59 |
+
# remove outliers
|
60 |
+
num_obs = data.shape[0]
|
61 |
+
data = data[((data['latitude'] <= 90) & (data['latitude'] >= -90) & (data['longitude'] <= 180) & (data['longitude'] >= -180) )]
|
62 |
+
if (num_obs - data.shape[0]) > 0:
|
63 |
+
print(num_obs - data.shape[0], 'items filtered due to invalid locations')
|
64 |
+
|
65 |
+
if 'accuracy' in data.columns:
|
66 |
+
data.drop(['accuracy'], axis=1, inplace=True)
|
67 |
+
|
68 |
+
if 'positional_accuracy' in data.columns:
|
69 |
+
data.drop(['positional_accuracy'], axis=1, inplace=True)
|
70 |
+
|
71 |
+
if 'geoprivacy' in data.columns:
|
72 |
+
data.drop(['geoprivacy'], axis=1, inplace=True)
|
73 |
+
|
74 |
+
if 'observed_on' in data.columns:
|
75 |
+
data.rename(columns = {'observed_on':'date'}, inplace=True)
|
76 |
+
|
77 |
+
num_obs_orig = data.shape[0]
|
78 |
+
data = data.dropna()
|
79 |
+
size_diff = num_obs_orig - data.shape[0]
|
80 |
+
if size_diff > 0:
|
81 |
+
print(size_diff, 'observation(s) with a NaN entry out of' , num_obs_orig, 'removed')
|
82 |
+
|
83 |
+
# keep only taxa of interest:
|
84 |
+
if taxa_of_interest is not None:
|
85 |
+
num_obs_orig = data.shape[0]
|
86 |
+
data = data[data['taxon_id'].isin(taxa_of_interest)]
|
87 |
+
print(num_obs_orig - data.shape[0], 'observation(s) out of' , num_obs_orig, 'from different taxa removed')
|
88 |
+
|
89 |
+
print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))
|
90 |
+
|
91 |
+
locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
|
92 |
+
taxa = data['taxon_id'].values.astype(np.int)
|
93 |
+
|
94 |
+
if 'user_id' in data.columns:
|
95 |
+
users = data['user_id'].values.astype(np.int)
|
96 |
+
_, users = np.unique(users, return_inverse=True)
|
97 |
+
elif 'observer_id' in data.columns:
|
98 |
+
users = data['observer_id'].values.astype(np.int)
|
99 |
+
_, users = np.unique(users, return_inverse=True)
|
100 |
+
else:
|
101 |
+
users = np.ones(taxa.shape[0], dtype=np.int)*-1
|
102 |
+
|
103 |
+
# Note - assumes that dates are in format YYYY-MM-DD
|
104 |
+
years = np.array([int(d_str[:4]) for d_str in data['date'].values])
|
105 |
+
months = np.array([int(d_str[5:7]) for d_str in data['date'].values])
|
106 |
+
days = np.array([int(d_str[8:10]) for d_str in data['date'].values])
|
107 |
+
days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
|
108 |
+
dates = days_per_month[months-1] + days-1
|
109 |
+
dates = np.round((dates) / 364.0, 4).astype(np.float32)
|
110 |
+
if 'id' in data.columns:
|
111 |
+
obs_ids = data['id'].values
|
112 |
+
elif 'observation_uuid' in data.columns:
|
113 |
+
obs_ids = data['observation_uuid'].values
|
114 |
+
|
115 |
+
return locs, taxa, users, dates, years, obs_ids
|
116 |
+
|
117 |
+
def choose_aux_species(current_species, num_aux_species, aux_species_seed):
|
118 |
+
if num_aux_species == 0:
|
119 |
+
return []
|
120 |
+
with open('paths.json', 'r') as f:
|
121 |
+
paths = json.load(f)
|
122 |
+
data_dir = paths['train']
|
123 |
+
taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
|
124 |
+
with open(taxa_file, 'r') as f:
|
125 |
+
inat_large_metadata = json.load(f)
|
126 |
+
aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata]
|
127 |
+
aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species)
|
128 |
+
print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates')
|
129 |
+
rng = np.random.default_rng(aux_species_seed)
|
130 |
+
idx_rand_aux_species = rng.permutation(len(aux_species_candidates))
|
131 |
+
aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]])
|
132 |
+
return aux_species
|
133 |
+
|
134 |
+
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file_snt=None):
|
135 |
+
if species_set == 'all':
|
136 |
+
return None
|
137 |
+
if species_set == 'snt_birds':
|
138 |
+
assert taxa_file_snt is not None
|
139 |
+
with open(taxa_file_snt, 'r') as f: #
|
140 |
+
taxa_subsets = json.load(f)
|
141 |
+
taxa_of_interest = list(taxa_subsets['snt_birds'])
|
142 |
+
else:
|
143 |
+
raise NotImplementedError
|
144 |
+
# optionally add some other species back in:
|
145 |
+
aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed)
|
146 |
+
taxa_of_interest.extend(aux_species)
|
147 |
+
return taxa_of_interest
|
148 |
+
|
149 |
+
def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123):
|
150 |
+
if hard_cap == -1:
|
151 |
+
return np.arange(len(labels))
|
152 |
+
print(f'subsampling (up to) {hard_cap} per class for the training set')
|
153 |
+
class_counts = {id: 0 for id in np.unique(labels)}
|
154 |
+
ss_rng = np.random.default_rng(hard_cap_seed)
|
155 |
+
idx_rand = ss_rng.permutation(len(labels))
|
156 |
+
idx_ss = []
|
157 |
+
for i in idx_rand:
|
158 |
+
class_id = labels[i]
|
159 |
+
if class_counts[class_id] < hard_cap:
|
160 |
+
idx_ss.append(i)
|
161 |
+
class_counts[class_id] += 1
|
162 |
+
idx_ss = np.sort(idx_ss)
|
163 |
+
print(f'final training set size: {len(idx_ss)}')
|
164 |
+
return idx_ss
|
165 |
+
|
166 |
+
def get_train_data(params):
|
167 |
+
with open('paths.json', 'r') as f:
|
168 |
+
paths = json.load(f)
|
169 |
+
data_dir = paths['train']
|
170 |
+
obs_file = os.path.join(data_dir, 'geo_prior_train.csv')
|
171 |
+
taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
|
172 |
+
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
|
173 |
+
|
174 |
+
taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'], params['aux_species_seed'], taxa_file_snt)
|
175 |
+
|
176 |
+
locs, labels, _, _, _, _ = load_inat_data(obs_file, taxa_of_interest)
|
177 |
+
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
|
178 |
+
class_to_taxa = unique_taxa.tolist()
|
179 |
+
|
180 |
+
# load class names
|
181 |
+
class_info_file = json.load(open(taxa_file, 'r'))
|
182 |
+
class_names_file = [cc['latin_name'] for cc in class_info_file]
|
183 |
+
taxa_ids_file = [cc['taxon_id'] for cc in class_info_file]
|
184 |
+
classes = dict(zip(taxa_ids_file, class_names_file))
|
185 |
+
|
186 |
+
idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'])
|
187 |
+
|
188 |
+
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor
|
189 |
+
|
190 |
+
labels = torch.from_numpy(np.array(class_ids)[idx_ss])
|
191 |
+
|
192 |
+
ds = LocationDataset(locs, labels, classes, class_to_taxa, params['input_enc'], params['device'])
|
193 |
+
|
194 |
+
return ds
|
eval.py
ADDED
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import pandas as pd
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import time
|
6 |
+
import os
|
7 |
+
import copy
|
8 |
+
import json
|
9 |
+
import tifffile
|
10 |
+
import h3
|
11 |
+
import setup
|
12 |
+
|
13 |
+
from sklearn.linear_model import RidgeCV
|
14 |
+
from sklearn.preprocessing import MinMaxScaler
|
15 |
+
from sklearn.metrics import average_precision_score
|
16 |
+
|
17 |
+
import utils
|
18 |
+
import models
|
19 |
+
import datasets
|
20 |
+
|
21 |
+
class EvaluatorSNT:
|
22 |
+
def __init__(self, train_params, eval_params):
|
23 |
+
self.train_params = train_params
|
24 |
+
self.eval_params = eval_params
|
25 |
+
with open('paths.json', 'r') as f:
|
26 |
+
paths = json.load(f)
|
27 |
+
D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
|
28 |
+
D = D.item()
|
29 |
+
self.loc_indices_per_species = D['loc_indices_per_species']
|
30 |
+
self.labels_per_species = D['labels_per_species']
|
31 |
+
self.taxa = D['taxa']
|
32 |
+
self.obs_locs = D['obs_locs']
|
33 |
+
self.obs_locs_idx = D['obs_locs_idx']
|
34 |
+
|
35 |
+
def get_labels(self, species):
|
36 |
+
species = str(species)
|
37 |
+
lat = []
|
38 |
+
lon = []
|
39 |
+
gt = []
|
40 |
+
for hx in self.data:
|
41 |
+
cur_lat, cur_lon = h3.h3_to_geo(hx)
|
42 |
+
if species in self.data[hx]:
|
43 |
+
cur_label = int(len(self.data[hx][species]) > 0)
|
44 |
+
gt.append(cur_label)
|
45 |
+
lat.append(cur_lat)
|
46 |
+
lon.append(cur_lon)
|
47 |
+
lat = np.array(lat).astype(np.float32)
|
48 |
+
lon = np.array(lon).astype(np.float32)
|
49 |
+
obs_locs = np.vstack((lon, lat)).T
|
50 |
+
gt = np.array(gt).astype(np.float32)
|
51 |
+
return obs_locs, gt
|
52 |
+
|
53 |
+
def run_evaluation(self, model, enc):
|
54 |
+
results = {}
|
55 |
+
|
56 |
+
# set seeds:
|
57 |
+
np.random.seed(self.eval_params['seed'])
|
58 |
+
random.seed(self.eval_params['seed'])
|
59 |
+
|
60 |
+
# evaluate the geo model for each taxon
|
61 |
+
results['mean_average_precision'] = np.zeros((len(self.taxa)), dtype=np.float32)
|
62 |
+
# get eval locations and apply input encoding
|
63 |
+
obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
|
64 |
+
loc_feat = enc.encode(obs_locs)
|
65 |
+
# get classes to eval
|
66 |
+
classes_of_interest = np.array([np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] for tt in self.taxa]).squeeze()
|
67 |
+
classes_of_interest = torch.from_numpy(classes_of_interest)
|
68 |
+
# generate model predictions for classes of interest at eval locations
|
69 |
+
with torch.no_grad():
|
70 |
+
loc_emb = model(loc_feat, return_feats=True)
|
71 |
+
wt = model.class_emb.weight[classes_of_interest, :]
|
72 |
+
pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()
|
73 |
+
|
74 |
+
split_rng = np.random.default_rng(self.eval_params['split_seed'])
|
75 |
+
|
76 |
+
for tt_id, tt in enumerate(self.taxa):
|
77 |
+
# generate ground truth labels for current taxa
|
78 |
+
cur_class_of_interest = np.where(self.taxa == tt)[0][0]
|
79 |
+
cur_loc_indices = np.array(self.loc_indices_per_species[cur_class_of_interest])
|
80 |
+
cur_labels = np.array(self.labels_per_species[cur_class_of_interest])
|
81 |
+
|
82 |
+
# apply per-species split:
|
83 |
+
assert self.eval_params['split'] in ['all', 'val', 'test']
|
84 |
+
if self.eval_params['split'] != 'all':
|
85 |
+
num_val = np.floor(len(cur_labels) * self.eval_params['val_frac']).astype(int)
|
86 |
+
idx_rand = split_rng.permutation(len(cur_labels))
|
87 |
+
if self.eval_params['split'] == 'val':
|
88 |
+
idx_sel = idx_rand[:num_val]
|
89 |
+
elif self.eval_params['split'] == 'test':
|
90 |
+
idx_sel = idx_rand[num_val:]
|
91 |
+
cur_loc_indices = cur_loc_indices[idx_sel]
|
92 |
+
cur_labels = cur_labels[idx_sel]
|
93 |
+
|
94 |
+
# extract model predictions for current taxa from prediction matrix:
|
95 |
+
pred = pred_mtx[cur_loc_indices, tt_id]
|
96 |
+
|
97 |
+
# compute the AP for each taxa
|
98 |
+
results['mean_average_precision'][tt_id] = average_precision_score((cur_labels > 0).astype(np.int32), pred)
|
99 |
+
|
100 |
+
|
101 |
+
valid_taxa = ~np.isnan(results['mean_average_precision'])
|
102 |
+
|
103 |
+
# store results
|
104 |
+
results['per_species_average_precision_all'] = copy.deepcopy(results['mean_average_precision'])
|
105 |
+
per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa]
|
106 |
+
results['mean_average_precision'] = per_species_average_precision_valid.mean()
|
107 |
+
results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
|
108 |
+
results['num_eval_species_total'] = len(self.taxa)
|
109 |
+
|
110 |
+
return results
|
111 |
+
|
112 |
+
def report(self, results):
|
113 |
+
for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
|
114 |
+
print(f'{field}: {results[field]}')
|
115 |
+
|
116 |
+
class EvaluatorIUCN:
|
117 |
+
|
118 |
+
def __init__(self, train_params, eval_params):
|
119 |
+
self.train_params = train_params
|
120 |
+
self.eval_params = eval_params
|
121 |
+
with open('paths.json', 'r') as f:
|
122 |
+
paths = json.load(f)
|
123 |
+
with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
|
124 |
+
self.data = json.load(f)
|
125 |
+
self.obs_locs = np.array(self.data['locs'], dtype=np.float32)
|
126 |
+
self.taxa = [int(tt) for tt in self.data['taxa_presence'].keys()]
|
127 |
+
|
128 |
+
def run_evaluation(self, model, enc):
|
129 |
+
results = {}
|
130 |
+
|
131 |
+
results['per_species_average_precision_all'] = np.zeros(len(self.taxa), dtype=np.float32)
|
132 |
+
# get eval locations and apply input encoding
|
133 |
+
obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
|
134 |
+
loc_feat = enc.encode(obs_locs)
|
135 |
+
|
136 |
+
# get classes to eval
|
137 |
+
classes_of_interest = torch.from_numpy(np.array([np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] for tt in self.taxa]).squeeze())
|
138 |
+
with torch.no_grad():
|
139 |
+
# generate model predictions for classes of interest at eval locations
|
140 |
+
loc_emb = model(loc_feat, return_feats=True)
|
141 |
+
wt = model.class_emb.weight[classes_of_interest, :]
|
142 |
+
pred_mtx = torch.matmul(loc_emb, wt.T)
|
143 |
+
|
144 |
+
for tt_id, tt in enumerate(self.taxa):
|
145 |
+
class_of_interest = np.where(np.array(self.train_params['class_to_taxa']) == tt)[0]
|
146 |
+
|
147 |
+
if len(class_of_interest) == 0:
|
148 |
+
# taxa of interest is not in the model
|
149 |
+
pred = None
|
150 |
+
else:
|
151 |
+
# extract model predictions for current taxa from prediction matrix
|
152 |
+
pred = pred_mtx[:, tt_id]
|
153 |
+
|
154 |
+
# evaluate accuracy
|
155 |
+
if pred is None:
|
156 |
+
results['per_species_average_precision_all'][tt_id] = np.nan
|
157 |
+
else:
|
158 |
+
gt = np.zeros(obs_locs.shape[0], dtype=np.float32)
|
159 |
+
gt[self.data['taxa_presence'][str(tt)]] = 1.0
|
160 |
+
# average precision score:
|
161 |
+
results['per_species_average_precision_all'][tt_id] = average_precision_score(gt, pred)
|
162 |
+
|
163 |
+
valid_taxa = ~np.isnan(results['per_species_average_precision_all'])
|
164 |
+
|
165 |
+
# store results
|
166 |
+
per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa]
|
167 |
+
results['mean_average_precision'] = per_species_average_precision_valid.mean()
|
168 |
+
results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
|
169 |
+
results['num_eval_species_total'] = len(self.taxa)
|
170 |
+
return results
|
171 |
+
|
172 |
+
def report(self, results):
|
173 |
+
for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
|
174 |
+
print(f'{field}: {results[field]}')
|
175 |
+
|
176 |
+
class EvaluatorGeoPrior:
|
177 |
+
|
178 |
+
def __init__(self, train_params, eval_params):
|
179 |
+
# store parameters:
|
180 |
+
self.train_params = train_params
|
181 |
+
self.eval_params = eval_params
|
182 |
+
with open('paths.json', 'r') as f:
|
183 |
+
paths = json.load(f)
|
184 |
+
# load vision model predictions:
|
185 |
+
self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
|
186 |
+
print('\n', self.data['probs'].shape[0], 'total test observations')
|
187 |
+
# load locations:
|
188 |
+
meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
|
189 |
+
self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
|
190 |
+
# taxonomic mapping:
|
191 |
+
self.taxon_map = self.find_mapping_between_models(self.data['model_to_taxa'], self.train_params['class_to_taxa'])
|
192 |
+
print(self.taxon_map.shape[0], 'out of', len(self.data['model_to_taxa']), 'taxa in both vision and geo models')
|
193 |
+
|
194 |
+
def find_mapping_between_models(self, vision_taxa, geo_taxa):
|
195 |
+
# this will output an array of size N_overlap X 2
|
196 |
+
# the first column will be the indices of the vision model, and the second is their
|
197 |
+
# corresponding index in the geo model
|
198 |
+
taxon_map = np.ones((vision_taxa.shape[0], 2), dtype=np.int32)*-1
|
199 |
+
taxon_map[:, 0] = np.arange(vision_taxa.shape[0])
|
200 |
+
geo_taxa_arr = np.array(geo_taxa)
|
201 |
+
for tt_id, tt in enumerate(vision_taxa):
|
202 |
+
ind = np.where(geo_taxa_arr==tt)[0]
|
203 |
+
if len(ind) > 0:
|
204 |
+
taxon_map[tt_id, 1] = ind[0]
|
205 |
+
inds = np.where(taxon_map[:, 1]>-1)[0]
|
206 |
+
taxon_map = taxon_map[inds, :]
|
207 |
+
return taxon_map
|
208 |
+
|
209 |
+
def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
|
210 |
+
# this is slow as we turn the sparse input back into the same size as the dense one
|
211 |
+
vision_pred = np.zeros((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
|
212 |
+
geo_pred = np.ones((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
|
213 |
+
vision_pred[np.arange(vision_pred.shape[0])[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
|
214 |
+
|
215 |
+
geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
|
216 |
+
|
217 |
+
return geo_pred, vision_pred
|
218 |
+
|
219 |
+
def run_evaluation(self, model, enc):
|
220 |
+
results = {}
|
221 |
+
|
222 |
+
# loop over in batches
|
223 |
+
batch_start = np.hstack((np.arange(0, self.data['probs'].shape[0], self.eval_params['batch_size']), self.data['probs'].shape[0]))
|
224 |
+
correct_pred = np.zeros(self.data['probs'].shape[0])
|
225 |
+
|
226 |
+
print('\nbid\t w geo\t wo geo')
|
227 |
+
for bb_id, bb in enumerate(range(len(batch_start)-1)):
|
228 |
+
batch_inds = np.arange(batch_start[bb], batch_start[bb+1])
|
229 |
+
|
230 |
+
vision_probs = self.data['probs'][batch_inds, :]
|
231 |
+
vision_inds = self.data['inds'][batch_inds, :]
|
232 |
+
gt = self.data['labels'][batch_inds]
|
233 |
+
|
234 |
+
obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(self.eval_params['device'])
|
235 |
+
loc_feat = enc.encode(obs_locs_batch)
|
236 |
+
|
237 |
+
with torch.no_grad():
|
238 |
+
geo_pred = model(loc_feat).cpu().numpy()
|
239 |
+
|
240 |
+
geo_pred, vision_pred = self.convert_to_inat_vision_order(geo_pred, vision_probs, vision_inds,
|
241 |
+
self.data['model_to_taxa'], self.taxon_map)
|
242 |
+
|
243 |
+
comb_pred = np.argmax(vision_pred*geo_pred, 1)
|
244 |
+
comb_pred = (comb_pred==gt)
|
245 |
+
correct_pred[batch_inds] = comb_pred
|
246 |
+
|
247 |
+
results['vision_only_top_1'] = float((self.data['inds'][:, -1] == self.data['labels']).mean())
|
248 |
+
results['vision_geo_top_1'] = float(correct_pred.mean())
|
249 |
+
return results
|
250 |
+
|
251 |
+
def report(self, results):
|
252 |
+
print('\nOverall accuracy vision only model', round(results['vision_only_top_1'], 3))
|
253 |
+
print('Overall accuracy of geo model ', round(results['vision_geo_top_1'], 3))
|
254 |
+
print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
|
255 |
+
|
256 |
+
class EvaluatorGeoFeature:
|
257 |
+
|
258 |
+
def __init__(self, train_params, eval_params):
|
259 |
+
self.train_params = train_params
|
260 |
+
self.eval_params = eval_params
|
261 |
+
with open('paths.json', 'r') as f:
|
262 |
+
paths = json.load(f)
|
263 |
+
self.data_path = paths['geo_feature']
|
264 |
+
self.country_mask = tifffile.imread(os.path.join(paths['masks'], 'USA_MASK.tif')) == 1
|
265 |
+
self.raster_names = ['ABOVE_GROUND_CARBON', 'ELEVATION', 'LEAF_AREA_INDEX', 'NON_TREE_VEGITATED', 'NOT_VEGITATED', 'POPULATION_DENSITY', 'SNOW_COVER', 'SOIL_MOISTURE', 'TREE_COVER']
|
266 |
+
self.raster_names_log_transform = ['POPULATION_DENSITY']
|
267 |
+
|
268 |
+
def load_raster(self, raster_name, log_transform=False):
|
269 |
+
raster = tifffile.imread(os.path.join(self.data_path, raster_name + '.tif')).astype(np.float32)
|
270 |
+
valid_mask = ~np.isnan(raster).copy() & self.country_mask
|
271 |
+
# log scaling:
|
272 |
+
if log_transform:
|
273 |
+
raster[valid_mask] = np.log1p(raster[valid_mask] - raster[valid_mask].min())
|
274 |
+
# 0/1 scaling:
|
275 |
+
raster[valid_mask] -= raster[valid_mask].min()
|
276 |
+
raster[valid_mask] /= raster[valid_mask].max()
|
277 |
+
|
278 |
+
return raster, valid_mask
|
279 |
+
|
280 |
+
def get_split_labels(self, raster, split_ids, split_of_interest):
|
281 |
+
# get the GT labels for a subset
|
282 |
+
inds_y, inds_x = np.where(split_ids==split_of_interest)
|
283 |
+
return raster[inds_y, inds_x]
|
284 |
+
|
285 |
+
def get_split_feats(self, model, enc, split_ids, split_of_interest):
|
286 |
+
locs = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
|
287 |
+
locs = torch.from_numpy(locs).to(self.eval_params['device'])
|
288 |
+
locs_enc = enc.encode(locs)
|
289 |
+
with torch.no_grad():
|
290 |
+
feats = model(locs_enc, return_feats=True).cpu().numpy()
|
291 |
+
return feats
|
292 |
+
|
293 |
+
def run_evaluation(self, model, enc):
|
294 |
+
results = {}
|
295 |
+
for raster_name in self.raster_names:
|
296 |
+
do_log_transform = raster_name in self.raster_names_log_transform
|
297 |
+
raster, valid_mask = self.load_raster(raster_name, do_log_transform)
|
298 |
+
split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])
|
299 |
+
feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1)
|
300 |
+
feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2)
|
301 |
+
labels_train = self.get_split_labels(raster, split_ids, 1)
|
302 |
+
labels_test = self.get_split_labels(raster, split_ids, 2)
|
303 |
+
scaler = MinMaxScaler()
|
304 |
+
feats_train_scaled = scaler.fit_transform(feats_train)
|
305 |
+
feats_test_scaled = scaler.transform(feats_test)
|
306 |
+
clf = RidgeCV(alphas=(0.1, 1.0, 10.0), normalize=False, cv=10, fit_intercept=True, scoring='r2').fit(feats_train_scaled, labels_train)
|
307 |
+
train_score = clf.score(feats_train_scaled, labels_train)
|
308 |
+
test_score = clf.score(feats_test_scaled, labels_test)
|
309 |
+
results[f'train_r2_{raster_name}'] = float(train_score)
|
310 |
+
results[f'test_r2_{raster_name}'] = float(test_score)
|
311 |
+
results[f'alpha_{raster_name}'] = float(clf.alpha_)
|
312 |
+
return results
|
313 |
+
|
314 |
+
def report(self, results):
|
315 |
+
report_fields = [x for x in results if 'test_r2' in x]
|
316 |
+
for field in report_fields:
|
317 |
+
print(f'{field}: {results[field]}')
|
318 |
+
print(np.mean([results[field] for field in report_fields]))
|
319 |
+
|
320 |
+
def launch_eval_run(overrides):
|
321 |
+
|
322 |
+
eval_params = setup.get_default_params_eval(overrides)
|
323 |
+
|
324 |
+
# set up model:
|
325 |
+
eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
|
326 |
+
train_params = torch.load(eval_params['model_path'], map_location='cpu')
|
327 |
+
model = models.get_model(train_params['params'])
|
328 |
+
model.load_state_dict(train_params['state_dict'], strict=True)
|
329 |
+
model = model.to(eval_params['device'])
|
330 |
+
model.eval()
|
331 |
+
|
332 |
+
# create input encoder:
|
333 |
+
if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
|
334 |
+
raster = datasets.load_env().to(eval_params['device'])
|
335 |
+
else:
|
336 |
+
raster = None
|
337 |
+
enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)
|
338 |
+
|
339 |
+
t = time.time()
|
340 |
+
if eval_params['eval_type'] == 'snt':
|
341 |
+
eval_params['split'] = 'test' # val, test, all
|
342 |
+
eval_params['val_frac'] = 0.50
|
343 |
+
eval_params['split_seed'] = 7499
|
344 |
+
evaluator = EvaluatorSNT(train_params['params'], eval_params)
|
345 |
+
results = evaluator.run_evaluation(model, enc)
|
346 |
+
evaluator.report(results)
|
347 |
+
elif eval_params['eval_type'] == 'iucn':
|
348 |
+
evaluator = EvaluatorIUCN(train_params['params'], eval_params)
|
349 |
+
results = evaluator.run_evaluation(model, enc)
|
350 |
+
evaluator.report(results)
|
351 |
+
elif eval_params['eval_type'] == 'geo_prior':
|
352 |
+
evaluator = EvaluatorGeoPrior(train_params['params'], eval_params)
|
353 |
+
results = evaluator.run_evaluation(model, enc)
|
354 |
+
evaluator.report(results)
|
355 |
+
elif eval_params['eval_type'] == 'geo_feature':
|
356 |
+
evaluator = EvaluatorGeoFeature(train_params['params'], eval_params)
|
357 |
+
results = evaluator.run_evaluation(model, enc)
|
358 |
+
evaluator.report(results)
|
359 |
+
else:
|
360 |
+
raise NotImplementedError('Eval type not implemented.')
|
361 |
+
print(f'evaluation completed in {np.around((time.time()-t)/60, 1)} min')
|
362 |
+
return results
|
images/sinr_traverse.gif
ADDED
![]() |
losses.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import utils
|
3 |
+
|
4 |
+
def get_loss_function(params):
|
5 |
+
if params['loss'] == 'an_full':
|
6 |
+
return an_full
|
7 |
+
elif params['loss'] == 'an_slds':
|
8 |
+
return an_slds
|
9 |
+
elif params['loss'] == 'an_ssdl':
|
10 |
+
return an_ssdl
|
11 |
+
elif params['loss'] == 'an_full_me':
|
12 |
+
return an_full_me
|
13 |
+
elif params['loss'] == 'an_slds_me':
|
14 |
+
return an_slds_me
|
15 |
+
elif params['loss'] == 'an_ssdl_me':
|
16 |
+
return an_ssdl_me
|
17 |
+
|
18 |
+
def neg_log(x):
|
19 |
+
return -torch.log(x + 1e-5)
|
20 |
+
|
21 |
+
def bernoulli_entropy(p):
|
22 |
+
entropy = p * neg_log(p) + (1-p) * neg_log(1-p)
|
23 |
+
return entropy
|
24 |
+
|
25 |
+
def an_ssdl(batch, model, params, loc_to_feats, neg_type='hard'):
|
26 |
+
|
27 |
+
inds = torch.arange(params['batch_size'])
|
28 |
+
|
29 |
+
loc_feat, _, class_id = batch
|
30 |
+
loc_feat = loc_feat.to(params['device'])
|
31 |
+
class_id = class_id.to(params['device'])
|
32 |
+
|
33 |
+
assert model.inc_bias == False
|
34 |
+
batch_size = loc_feat.shape[0]
|
35 |
+
|
36 |
+
# create random background samples and extract features
|
37 |
+
rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
|
38 |
+
rand_feat = loc_to_feats(rand_loc, normalize=False)
|
39 |
+
|
40 |
+
# get location embeddings
|
41 |
+
loc_cat = torch.cat((loc_feat, rand_feat), 0) # stack vertically
|
42 |
+
loc_emb_cat = model(loc_cat, return_feats=True)
|
43 |
+
loc_emb = loc_emb_cat[:batch_size, :]
|
44 |
+
loc_emb_rand = loc_emb_cat[batch_size:, :]
|
45 |
+
|
46 |
+
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
|
47 |
+
loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
|
48 |
+
|
49 |
+
# data loss
|
50 |
+
loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
|
51 |
+
if neg_type == 'hard':
|
52 |
+
loss_bg = neg_log(1.0 - loc_pred_rand[inds[:batch_size], class_id]) # assume negative
|
53 |
+
elif neg_type == 'entropy':
|
54 |
+
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand[inds[:batch_size], class_id]) # entropy
|
55 |
+
else:
|
56 |
+
raise NotImplementedError
|
57 |
+
|
58 |
+
# total loss
|
59 |
+
loss = loss_pos.mean() + loss_bg.mean()
|
60 |
+
|
61 |
+
return loss
|
62 |
+
|
63 |
+
def an_slds(batch, model, params, loc_to_feats, neg_type='hard'):
|
64 |
+
|
65 |
+
inds = torch.arange(params['batch_size'])
|
66 |
+
|
67 |
+
loc_feat, _, class_id = batch
|
68 |
+
loc_feat = loc_feat.to(params['device'])
|
69 |
+
class_id = class_id.to(params['device'])
|
70 |
+
|
71 |
+
assert model.inc_bias == False
|
72 |
+
batch_size = loc_feat.shape[0]
|
73 |
+
|
74 |
+
loc_emb = model(loc_feat, return_feats=True)
|
75 |
+
|
76 |
+
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
|
77 |
+
|
78 |
+
num_classes = loc_pred.shape[1]
|
79 |
+
bg_class = torch.randint(low=0, high=num_classes-1, size=(batch_size,), device=params['device'])
|
80 |
+
bg_class[bg_class >= class_id[:batch_size]] += 1
|
81 |
+
|
82 |
+
# data loss
|
83 |
+
loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
|
84 |
+
if neg_type == 'hard':
|
85 |
+
loss_bg = neg_log(1.0 - loc_pred[inds[:batch_size], bg_class]) # assume negative
|
86 |
+
elif neg_type == 'entropy':
|
87 |
+
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred[inds[:batch_size], bg_class]) # entropy
|
88 |
+
else:
|
89 |
+
raise NotImplementedError
|
90 |
+
|
91 |
+
# total loss
|
92 |
+
loss = loss_pos.mean() + loss_bg.mean()
|
93 |
+
|
94 |
+
return loss
|
95 |
+
|
96 |
+
def an_full(batch, model, params, loc_to_feats, neg_type='hard'):
|
97 |
+
|
98 |
+
inds = torch.arange(params['batch_size'])
|
99 |
+
|
100 |
+
loc_feat, _, class_id = batch
|
101 |
+
loc_feat = loc_feat.to(params['device'])
|
102 |
+
class_id = class_id.to(params['device'])
|
103 |
+
|
104 |
+
assert model.inc_bias == False
|
105 |
+
batch_size = loc_feat.shape[0]
|
106 |
+
|
107 |
+
# create random background samples and extract features
|
108 |
+
rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
|
109 |
+
rand_feat = loc_to_feats(rand_loc, normalize=False)
|
110 |
+
|
111 |
+
# get location embeddings
|
112 |
+
loc_cat = torch.cat((loc_feat, rand_feat), 0) # stack vertically
|
113 |
+
loc_emb_cat = model(loc_cat, return_feats=True)
|
114 |
+
loc_emb = loc_emb_cat[:batch_size, :]
|
115 |
+
loc_emb_rand = loc_emb_cat[batch_size:, :]
|
116 |
+
# get predictions for locations and background locations
|
117 |
+
loc_pred = torch.sigmoid(model.class_emb(loc_emb))
|
118 |
+
loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
|
119 |
+
|
120 |
+
# data loss
|
121 |
+
if neg_type == 'hard':
|
122 |
+
loss_pos = neg_log(1.0 - loc_pred) # assume negative
|
123 |
+
loss_bg = neg_log(1.0 - loc_pred_rand) # assume negative
|
124 |
+
elif neg_type == 'entropy':
|
125 |
+
loss_pos = -1 * bernoulli_entropy(1.0 - loc_pred) # entropy
|
126 |
+
loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand) # entropy
|
127 |
+
else:
|
128 |
+
raise NotImplementedError
|
129 |
+
loss_pos[inds[:batch_size], class_id] = params['pos_weight'] * neg_log(loc_pred[inds[:batch_size], class_id])
|
130 |
+
|
131 |
+
# total loss
|
132 |
+
loss = loss_pos.mean() + loss_bg.mean()
|
133 |
+
|
134 |
+
return loss
|
135 |
+
|
136 |
+
def an_full_me(batch, model, params, loc_to_feats):
|
137 |
+
|
138 |
+
return an_full(batch, model, params, loc_to_feats, neg_type='entropy')
|
139 |
+
|
140 |
+
def an_ssdl_me(batch, model, params, loc_to_feats):
|
141 |
+
|
142 |
+
return an_ssdl(batch, model, params, loc_to_feats, neg_type='entropy')
|
143 |
+
|
144 |
+
def an_slds_me(batch, model, params, loc_to_feats):
|
145 |
+
|
146 |
+
return an_slds(batch, model, params, loc_to_feats, neg_type='entropy')
|
models.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.utils.data
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
def get_model(params):
|
6 |
+
|
7 |
+
if params['model'] == 'ResidualFCNet':
|
8 |
+
return ResidualFCNet(params['input_dim'], params['num_classes'], params['num_filts'], params['depth'])
|
9 |
+
elif params['model'] == 'LinNet':
|
10 |
+
return LinNet(params['input_dim'], params['num_classes'])
|
11 |
+
else:
|
12 |
+
raise NotImplementedError('Invalid model specified.')
|
13 |
+
|
14 |
+
class ResLayer(nn.Module):
|
15 |
+
def __init__(self, linear_size):
|
16 |
+
super(ResLayer, self).__init__()
|
17 |
+
self.l_size = linear_size
|
18 |
+
self.nonlin1 = nn.ReLU(inplace=True)
|
19 |
+
self.nonlin2 = nn.ReLU(inplace=True)
|
20 |
+
self.dropout1 = nn.Dropout()
|
21 |
+
self.w1 = nn.Linear(self.l_size, self.l_size)
|
22 |
+
self.w2 = nn.Linear(self.l_size, self.l_size)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
y = self.w1(x)
|
26 |
+
y = self.nonlin1(y)
|
27 |
+
y = self.dropout1(y)
|
28 |
+
y = self.w2(y)
|
29 |
+
y = self.nonlin2(y)
|
30 |
+
out = x + y
|
31 |
+
return out
|
32 |
+
|
33 |
+
class ResidualFCNet(nn.Module):
|
34 |
+
|
35 |
+
def __init__(self, num_inputs, num_classes, num_filts, depth=4):
|
36 |
+
super(ResidualFCNet, self).__init__()
|
37 |
+
self.inc_bias = False
|
38 |
+
self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
|
39 |
+
layers = []
|
40 |
+
layers.append(nn.Linear(num_inputs, num_filts))
|
41 |
+
layers.append(nn.ReLU(inplace=True))
|
42 |
+
for i in range(depth):
|
43 |
+
layers.append(ResLayer(num_filts))
|
44 |
+
self.feats = torch.nn.Sequential(*layers)
|
45 |
+
|
46 |
+
def forward(self, x, class_of_interest=None, return_feats=False):
|
47 |
+
loc_emb = self.feats(x)
|
48 |
+
if return_feats:
|
49 |
+
return loc_emb
|
50 |
+
if class_of_interest is None:
|
51 |
+
class_pred = self.class_emb(loc_emb)
|
52 |
+
else:
|
53 |
+
class_pred = self.eval_single_class(loc_emb, class_of_interest)
|
54 |
+
return torch.sigmoid(class_pred)
|
55 |
+
|
56 |
+
def eval_single_class(self, x, class_of_interest):
|
57 |
+
if self.inc_bias:
|
58 |
+
return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T) + self.class_emb.bias[class_of_interest]
|
59 |
+
else:
|
60 |
+
return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T)
|
61 |
+
|
62 |
+
class LinNet(nn.Module):
|
63 |
+
def __init__(self, num_inputs, num_classes):
|
64 |
+
super(LinNet, self).__init__()
|
65 |
+
self.num_layers = 0
|
66 |
+
self.inc_bias = False
|
67 |
+
self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
|
68 |
+
self.feats = nn.Identity() # does not do anything
|
69 |
+
|
70 |
+
def forward(self, x, class_of_interest=None, return_feats=False):
|
71 |
+
loc_emb = self.feats(x)
|
72 |
+
if return_feats:
|
73 |
+
return loc_emb
|
74 |
+
if class_of_interest is None:
|
75 |
+
class_pred = self.class_emb(loc_emb)
|
76 |
+
else:
|
77 |
+
class_pred = self.eval_single_class(loc_emb, class_of_interest)
|
78 |
+
|
79 |
+
return torch.sigmoid(class_pred)
|
80 |
+
|
81 |
+
def eval_single_class(self, x, class_of_interest):
|
82 |
+
if self.inc_bias:
|
83 |
+
return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T) + self.class_emb.bias[class_of_interest]
|
84 |
+
else:
|
85 |
+
return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T)
|
paths.json
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"masks": "data/masks/",
|
3 |
+
"env": "data/env/",
|
4 |
+
"train": "data/train/",
|
5 |
+
"geo_prior": "data/eval/geo_prior/",
|
6 |
+
"snt": "data/eval/snt/",
|
7 |
+
"iucn": "data/eval/iucn/",
|
8 |
+
"geo_feature": "data/eval/geo_feature/"
|
9 |
+
}
|
pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e8408dbfdcdc3008cfce801318cba2263149aaec0451a656d52958bf81115547
|
3 |
+
size 50849971
|
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:438265b758df7cf58f2ed39410205be6a9fa944e559d7556d9d5c7c0f501c4ae
|
3 |
+
size 50850118
|
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:31811bf8e0a8bf8f59b9efc7fa56db015e49f02c594316a3a4389dc91ad6aae9
|
3 |
+
size 50850139
|
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7bf4dacb0f9b4cf8c5323e1c186b1e725267199053b9aa672d4b5dc1c3dbc235
|
3 |
+
size 50850160
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio==3.36.1
|
2 |
+
h3==3.7.6
|
3 |
+
matplotlib==3.7.1
|
4 |
+
numpy==1.25.0
|
5 |
+
pandas==2.0.3
|
6 |
+
torch==1.12.1
|
setup.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import torch
|
3 |
+
|
4 |
+
def apply_overrides(params, overrides):
|
5 |
+
params = copy.deepcopy(params)
|
6 |
+
for param_name in overrides:
|
7 |
+
if param_name not in params:
|
8 |
+
print(f'override failed: no parameter named {param_name}')
|
9 |
+
raise ValueError
|
10 |
+
params[param_name] = overrides[param_name]
|
11 |
+
return params
|
12 |
+
|
13 |
+
def get_default_params_train(overrides={}):
|
14 |
+
|
15 |
+
params = {}
|
16 |
+
|
17 |
+
'''
|
18 |
+
misc
|
19 |
+
'''
|
20 |
+
params['device'] = 'cuda' # cuda, cpu
|
21 |
+
params['save_base'] = './experiments/'
|
22 |
+
params['experiment_name'] = 'demo'
|
23 |
+
params['timestamp'] = False
|
24 |
+
|
25 |
+
'''
|
26 |
+
data
|
27 |
+
'''
|
28 |
+
params['species_set'] = 'all' # all, snt_birds
|
29 |
+
params['hard_cap_seed'] = 9472
|
30 |
+
params['hard_cap_num_per_class'] = -1 # -1 for no hard capping
|
31 |
+
params['aux_species_seed'] = 8099
|
32 |
+
params['num_aux_species'] = 0 # for snt_birds case, how many other species to add in
|
33 |
+
|
34 |
+
'''
|
35 |
+
model
|
36 |
+
'''
|
37 |
+
params['model'] = 'ResidualFCNet' # ResidualFCNet, LinNet
|
38 |
+
params['num_filts'] = 256 # embedding dimension
|
39 |
+
params['input_enc'] = 'sin_cos' # sin_cos, env, sin_cos_env
|
40 |
+
params['depth'] = 4
|
41 |
+
|
42 |
+
'''
|
43 |
+
loss
|
44 |
+
'''
|
45 |
+
params['loss'] = 'an_full' # an_full, an_ssdl, an_slds
|
46 |
+
params['pos_weight'] = 2048
|
47 |
+
|
48 |
+
'''
|
49 |
+
optimization
|
50 |
+
'''
|
51 |
+
params['batch_size'] = 2048
|
52 |
+
params['lr'] = 0.0005
|
53 |
+
params['lr_decay'] = 0.98
|
54 |
+
params['num_epochs'] = 10
|
55 |
+
|
56 |
+
'''
|
57 |
+
saving
|
58 |
+
'''
|
59 |
+
params['log_frequency'] = 512
|
60 |
+
|
61 |
+
params = apply_overrides(params, overrides)
|
62 |
+
|
63 |
+
return params
|
64 |
+
|
65 |
+
def get_default_params_eval(overrides={}):
|
66 |
+
|
67 |
+
params = {}
|
68 |
+
|
69 |
+
'''
|
70 |
+
misc
|
71 |
+
'''
|
72 |
+
params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
73 |
+
params['seed'] = 2022
|
74 |
+
params['exp_base'] = './experiments'
|
75 |
+
params['ckp_name'] = 'model.pt'
|
76 |
+
params['eval_type'] = 'snt' # snt, iucn, geo_prior, geo_feature
|
77 |
+
params['experiment_name'] = 'demo'
|
78 |
+
|
79 |
+
'''
|
80 |
+
geo prior
|
81 |
+
'''
|
82 |
+
params['batch_size'] = 2048
|
83 |
+
|
84 |
+
'''
|
85 |
+
geo feature
|
86 |
+
'''
|
87 |
+
params['cell_size'] = 25
|
88 |
+
|
89 |
+
params = apply_overrides(params, overrides)
|
90 |
+
|
91 |
+
return params
|
taxa_02_08_2023_names.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import datetime
|
5 |
+
|
6 |
+
class CoordEncoder:
|
7 |
+
|
8 |
+
def __init__(self, input_enc, raster=None):
|
9 |
+
self.input_enc = input_enc
|
10 |
+
self.raster = raster
|
11 |
+
|
12 |
+
def encode(self, locs, normalize=True):
|
13 |
+
# assumes lon, lat in range [-180, 180] and [-90, 90]
|
14 |
+
if normalize:
|
15 |
+
locs = normalize_coords(locs)
|
16 |
+
if self.input_enc == 'sin_cos': # sinusoidal encoding
|
17 |
+
loc_feats = encode_loc(locs)
|
18 |
+
elif self.input_enc == 'env': # bioclim variables
|
19 |
+
loc_feats = bilinear_interpolate(locs, self.raster)
|
20 |
+
elif self.input_enc == 'sin_cos_env': # sinusoidal encoding & bioclim variables
|
21 |
+
loc_feats = encode_loc(locs)
|
22 |
+
context_feats = bilinear_interpolate(locs, self.raster)
|
23 |
+
loc_feats = torch.cat((loc_feats, context_feats), 1)
|
24 |
+
else:
|
25 |
+
raise NotImplementedError('Unknown input encoding.')
|
26 |
+
return loc_feats
|
27 |
+
|
28 |
+
def normalize_coords(locs):
|
29 |
+
# locs is in lon {-180, 180}, lat {90, -90}
|
30 |
+
# output is in the range [-1, 1]
|
31 |
+
|
32 |
+
locs[:,0] /= 180.0
|
33 |
+
locs[:,1] /= 90.0
|
34 |
+
|
35 |
+
return locs
|
36 |
+
|
37 |
+
def encode_loc(loc_ip, concat_dim=1):
|
38 |
+
# assumes inputs location are in range -1 to 1
|
39 |
+
# location is lon, lat
|
40 |
+
feats = torch.cat((torch.sin(math.pi*loc_ip), torch.cos(math.pi*loc_ip)), concat_dim)
|
41 |
+
return feats
|
42 |
+
|
43 |
+
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
|
44 |
+
# loc is N x 2 vector, where each row is [lon,lat] entry
|
45 |
+
# each entry spans range [-1,1]
|
46 |
+
# data is H x W x C, height x width x channel data matrix
|
47 |
+
# op will be N x C matrix of interpolated features
|
48 |
+
|
49 |
+
assert data is not None
|
50 |
+
|
51 |
+
# map to [0,1], then scale to data size
|
52 |
+
loc = (loc_ip.clone() + 1) / 2.0
|
53 |
+
loc[:,1] = 1 - loc[:,1] # this is because latitude goes from +90 on top to bottom while
|
54 |
+
# longitude goes from -90 to 90 left to right
|
55 |
+
|
56 |
+
assert not torch.any(torch.isnan(loc))
|
57 |
+
|
58 |
+
if remove_nans_raster:
|
59 |
+
data[torch.isnan(data)] = 0.0 # replace with mean value (0 is mean post-normalization)
|
60 |
+
|
61 |
+
# cast locations into pixel space
|
62 |
+
loc[:, 0] *= (data.shape[1]-1)
|
63 |
+
loc[:, 1] *= (data.shape[0]-1)
|
64 |
+
|
65 |
+
loc_int = torch.floor(loc).long() # integer pixel coordinates
|
66 |
+
xx = loc_int[:, 0]
|
67 |
+
yy = loc_int[:, 1]
|
68 |
+
xx_plus = xx + 1
|
69 |
+
xx_plus[xx_plus > (data.shape[1]-1)] = data.shape[1]-1
|
70 |
+
yy_plus = yy + 1
|
71 |
+
yy_plus[yy_plus > (data.shape[0]-1)] = data.shape[0]-1
|
72 |
+
|
73 |
+
loc_delta = loc - torch.floor(loc) # delta values
|
74 |
+
dx = loc_delta[:, 0].unsqueeze(1)
|
75 |
+
dy = loc_delta[:, 1].unsqueeze(1)
|
76 |
+
|
77 |
+
interp_val = data[yy, xx, :]*(1-dx)*(1-dy) + data[yy, xx_plus, :]*dx*(1-dy) + \
|
78 |
+
data[yy_plus, xx, :]*(1-dx)*dy + data[yy_plus, xx_plus, :]*dx*dy
|
79 |
+
|
80 |
+
return interp_val
|
81 |
+
|
82 |
+
def rand_samples(batch_size, device, rand_type='uniform'):
|
83 |
+
# randomly sample background locations
|
84 |
+
|
85 |
+
if rand_type == 'spherical':
|
86 |
+
rand_loc = torch.rand(batch_size, 2).to(device)
|
87 |
+
theta1 = 2.0*math.pi*rand_loc[:, 0]
|
88 |
+
theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
|
89 |
+
lat = 1.0 - 2.0*theta2/math.pi
|
90 |
+
lon = (theta1/math.pi) - 1.0
|
91 |
+
rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
|
92 |
+
|
93 |
+
elif rand_type == 'uniform':
|
94 |
+
rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
|
95 |
+
|
96 |
+
return rand_loc
|
97 |
+
|
98 |
+
def get_time_stamp():
|
99 |
+
cur_time = str(datetime.datetime.now())
|
100 |
+
date, time = cur_time.split(' ')
|
101 |
+
h, m, s = time.split(':')
|
102 |
+
s = s.split('.')[0]
|
103 |
+
time_stamp = '{}-{}-{}-{}'.format(date, h, m, s)
|
104 |
+
return time_stamp
|
105 |
+
|
106 |
+
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
|
107 |
+
# generate a grid of locations spaced evenly in coordinate space
|
108 |
+
|
109 |
+
feats = np.zeros((grid_size[0], grid_size[1], 2), dtype=np.float32)
|
110 |
+
mg = np.meshgrid(np.linspace(-180, 180, feats.shape[1]), np.linspace(90, -90, feats.shape[0]))
|
111 |
+
feats[:, :, 0] = mg[0]
|
112 |
+
feats[:, :, 1] = mg[1]
|
113 |
+
if split_ids is None or split_of_interest is None:
|
114 |
+
# return feats for all locations
|
115 |
+
# this will be an N x 2 array
|
116 |
+
return feats.reshape(feats.shape[0]*feats.shape[1], 2)
|
117 |
+
else:
|
118 |
+
# only select a subset of locations
|
119 |
+
ind_y, ind_x = np.where(split_ids==split_of_interest)
|
120 |
+
|
121 |
+
# these will be N_subset x 2 in size
|
122 |
+
return feats[ind_y, ind_x, :]
|
123 |
+
|
124 |
+
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
|
125 |
+
# generates a checkerboard style train test split
|
126 |
+
# 0 is invalid, 1 is train, and 2 is test
|
127 |
+
# c_size is units of pixels
|
128 |
+
split_ids = np.ones((raster.shape[0], raster.shape[1]))
|
129 |
+
start = cell_size
|
130 |
+
for ii in np.arange(0, split_ids.shape[0], cell_size):
|
131 |
+
if start == 0:
|
132 |
+
start = cell_size
|
133 |
+
else:
|
134 |
+
start = 0
|
135 |
+
for jj in np.arange(start, split_ids.shape[1], cell_size*2):
|
136 |
+
split_ids[ii:ii+cell_size, jj:jj+cell_size] = 2
|
137 |
+
split_ids = split_ids*mask
|
138 |
+
if train_amt < 1.0:
|
139 |
+
# take a subset of the data
|
140 |
+
tr_y, tr_x = np.where(split_ids==1)
|
141 |
+
inds = np.random.choice(len(tr_y), int(len(tr_y)*(1.0-train_amt)), replace=False)
|
142 |
+
split_ids[tr_y[inds], tr_x[inds]] = 0
|
143 |
+
return split_ids
|