Oisin Mac Aodha committed on
Commit 505e401 · 1 Parent(s): 6570723

First model version
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Elijah Cole
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,73 @@
- ---
- title: Sinr
- emoji: 🏃
- colorFrom: green
- colorTo: red
- sdk: gradio
- sdk_version: 3.38.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Spatial Implicit Neural Representations for Global-Scale Species Mapping - ICML 2023
+
+ Code for training and evaluating global-scale species range estimation models. This code enables the recreation of the results from our ICML 2023 paper [Spatial Implicit Neural Representations for Global-Scale Species Mapping](https://arxiv.org/abs/2306.02564).
+
+ ## 🌍 Overview
+ Estimating the geographical range of a species from sparse observations is a challenging and important geospatial prediction problem. Given a set of locations where a species has been observed, the goal is to build a model that predicts whether the species is present or absent at any location. In this work, we use Spatial Implicit Neural Representations (SINRs) to jointly estimate the geographical ranges of thousands of species simultaneously. SINRs scale gracefully, making better predictions as we increase the number of training species and the amount of training data per species. We introduce four new range estimation and spatial representation learning benchmarks, and we use them to demonstrate that noisy and biased crowdsourced data can be combined with implicit neural representations to approximate expert-developed range maps for many species.
+
+ ![Model Prediction](images/sinr_traverse.gif)
+ <sup>Above we visualize predictions from one of our SINR models trained on data from [iNaturalist](https://inaturalist.org). On the left we show the learned species embedding space, where each point represents a different species. On the right we see the predicted range of the species corresponding to the red dot on the left.</sup>
+
+ ## 🔍 Getting Started
+
+ #### Installing Required Packages
+
+ 1. We recommend using an isolated Python environment to avoid dependency issues. Install the Anaconda Python 3.9 distribution for your operating system from [here](https://www.anaconda.com/download).
+
+ 2. Create a new environment and activate it:
+ ```bash
+ conda create -y --name sinr_icml python==3.9
+ conda activate sinr_icml
+ ```
+
+ 3. After activating the environment, install the required packages:
+ ```bash
+ pip3 install -r requirements.txt
+ ```
+
+ #### Data Download and Preparation
+ Instructions for downloading the data are in `data/README.md`.
+
+ ## 🗺️ Generating Predictions
+ To generate predictions for a model in the form of an image, run the following command:
+ ```bash
+ python viz_map.py --taxa_id 130714
+ ```
+ Here, `--taxa_id` is the id number for a species of interest from [iNaturalist](https://www.inaturalist.org/taxa/130714). If you want to generate predictions for a random species, add the `--rand_taxa` flag instead, as shown below.
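+ For example, a minimal invocation to visualize a randomly selected species (assuming the data has been downloaded as described below):
+ ```bash
+ python viz_map.py --rand_taxa
+ ```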
+
+ Note: before you run this command, you need to first download the data as described in `data/README.md`. In addition, if you want to evaluate some of the pretrained models from the paper, you need to download those first and place them at `sinr/pretrained_models`. See `web_app/README.md` for more details.
+
+ There is also an interactive browser-based demo available in `web_app`.
+
+ ## 🚅 Training and Evaluating Models
+
+ To train and evaluate a model, run the following command:
+ ```bash
+ python train_and_evaluate_models.py
+ ```
+
+ #### Hyperparameters
+ Common parameters of interest can be set within `train_and_evaluate_models.py`. All other parameters are exposed in `setup.py`, as sketched below.
+
52
+ #### Outputs
53
+ By default, trained models and evaluation results will be saved to a folder in the `experiments` directory. Evaluation results will also be printed to the command line.
54
+
55
+ #### Interactive Model Visualizer
56
+ To visualize range predictions from pretrained SINR models, please follow the instructions in `web_app/README.md`.
57
+
58
+ ## 🙏 Acknowledgements
59
+ This project was enabled by data from the Cornell Lab of Ornithology, The International Union for the Conservation of Nature, iNaturalist, NASA, USGS, JAXA, CIESIN, and UC Merced. We are especially indebted to the [iNaturalist](inaturalist.org) and [eBird](https://ebird.org) communities for their data collection efforts. We also thank Matt Stimas-Mackey and Sam Heinrich for their help with data curation. This project was funded by the [Climate Change AI Innovation Grants](https://www.climatechange.ai/blog/2022-04-13-innovation-grants) program, hosted by Climate Change AI with the support of the Quadrature Climate Foundation, Schmidt Futures, and the Canada Hub of Future Earth. This work was also supported by the Caltech Resnick Sustainability Institute and an NSF Graduate Research Fellowship (grant number DGE1745301).
60
+
61
+ If you find our work useful in your research please consider citing our paper.
62
+ ```
63
+ @inproceedings{SINR_icml23,
64
+ title = {{Spatial Implicit Neural Representations for Global-Scale Species Mapping}},
65
+ author = {Cole, Elijah and Van Horn, Grant and Lange, Christian and Shepard, Alexander and Leary, Patrick and Perona, Pietro and Loarie, Scott and Mac Aodha, Oisin},
66
+ booktitle = {ICML},
67
+ year = {2023}
68
+ }
69
+ ```
70
+
71
+ ## 📜 Disclaimer
72
+ Extreme care should be taken before making any decisions based on the outputs of models presented here. Our goal in this work is to demonstrate the promise of large-scale representation learning for species range estimation, not to provide definitive range maps. Our models are trained on biased data and have not been calibrated or validated beyond the experiments illustrated in the paper.
73
+
app.py ADDED
@@ -0,0 +1,180 @@
+ import gradio as gr
+ import numpy as np
+ import matplotlib
+ matplotlib.use('Agg')
+ import matplotlib.pyplot as plt
+ import json
+ import os
+ import torch
+
+ import utils
+ import models
+ import datasets
+
+
+ def load_taxa_metadata(file_path):
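+     # each line of the metadata file is expected to contain "<taxa_id>\t<taxa_name>"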
+     taxa_names_file = open(file_path, "r")
+     data = taxa_names_file.read().split("\n")
+     data = [dd for dd in data if dd != '']
+     taxa_ids = []
+     taxa_names = []
+     for tt in range(len(data)):
+         id, nm = data[tt].split('\t')
+         taxa_ids.append(int(id))
+         taxa_names.append(nm)
+     taxa_names_file.close()
+     return dict(zip(taxa_ids, taxa_names))
+
+
+ def generate_prediction(taxa_id, selected_model, settings, threshold):
+
+     # select the model to use
+     if selected_model == 'AN_FULL max 10':
+         model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt'
+     elif selected_model == 'AN_FULL max 100':
+         model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt'
+     elif selected_model == 'AN_FULL max 1000':
+         model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt'
+     elif selected_model == 'Distilled env model':
+         model_path = 'pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt'
+
+     # load params
+     with open('paths.json', 'r') as f:
+         paths = json.load(f)
+
+     # configs
+     eval_params = {}
+     eval_params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     eval_params['model_path'] = model_path
+     eval_params['taxa_id'] = int(taxa_id)
+     eval_params['rand_taxa'] = 'Random taxa' in settings
+     eval_params['set_max_cmap_to_1'] = False
+     eval_params['disable_ocean_mask'] = 'Disable ocean mask' in settings
+     eval_params['threshold'] = threshold if 'Threshold' in settings else -1.0
+
+     # load model
+     train_params = torch.load(eval_params['model_path'], map_location='cpu')
+     model = models.get_model(train_params['params'])
+     model.load_state_dict(train_params['state_dict'], strict=True)
+     model = model.to(eval_params['device'])
+     model.eval()
+     if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
+         raster = datasets.load_env()
+     else:
+         raster = None
+     enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)
+
+     # user specified random taxa
+     if eval_params['rand_taxa']:
+         print('Selecting random taxa')
+         eval_params['taxa_id'] = np.random.choice(train_params['params']['class_to_taxa'])
+
+     # load taxa of interest
+     if eval_params['taxa_id'] in train_params['params']['class_to_taxa']:
+         class_of_interest = train_params['params']['class_to_taxa'].index(eval_params['taxa_id'])
+     else:
+         print(f'Error: specified taxa is not in the model: {eval_params["taxa_id"]}')
+         fig = plt.figure()
+         plt.imshow(np.zeros((1, 1)), vmin=0, vmax=1.0, cmap=plt.cm.plasma)
+         plt.axis('off')
+         plt.tight_layout()
+         op_html = f'<h2><a href="https://www.inaturalist.org/taxa/{eval_params["taxa_id"]}" target="_blank">{eval_params["taxa_id"]}</a></h2> Error: specified taxa is not in the model.'
+         return op_html, fig, eval_params['taxa_id']
+     print(f'Loading taxa: {eval_params["taxa_id"]}')
+
+     # load ocean mask
+     mask = np.load(os.path.join(paths['masks'], 'ocean_mask.npy'))
+     mask_inds = np.where(mask.reshape(-1) == 1)[0]
+
+     # generate input features
+     locs = utils.coord_grid(mask.shape)
+     if not eval_params['disable_ocean_mask']:
+         locs = locs[mask_inds, :]
+     locs = torch.from_numpy(locs)
+     locs_enc = enc.encode(locs).to(eval_params['device'])
+
+     # make prediction
+     with torch.no_grad():
+         preds = model(locs_enc, return_feats=False, class_of_interest=class_of_interest).cpu().numpy()
+
+     # threshold predictions
+     if eval_params['threshold'] > 0:
+         print(f'Applying threshold of {eval_params["threshold"]} to the predictions.')
+         preds[preds < eval_params['threshold']] = 0.0
+         preds[preds >= eval_params['threshold']] = 1.0
+
+     # mask data
+     if not eval_params['disable_ocean_mask']:
+         op_im = np.ones((mask.shape[0] * mask.shape[1])) * np.nan  # set to NaN
+         op_im[mask_inds] = preds
+     else:
+         op_im = preds
+
+     # reshape and create masked array for visualization
+     op_im = op_im.reshape((mask.shape[0], mask.shape[1]))
+     op_im = np.ma.masked_invalid(op_im)
+     if eval_params['set_max_cmap_to_1']:
+         vmax = 1.0
+     else:
+         vmax = np.max(op_im)
+
+     # set color for masked values
+     cmap = plt.cm.plasma
+     cmap.set_bad(color='none')
+
+     plt.rcParams['figure.figsize'] = 24, 12
+     fig = plt.figure()
+     plt.imshow(op_im, vmin=0, vmax=vmax, cmap=cmap)
+     plt.axis('off')
+     plt.tight_layout()
+
+     # generate html for output display
+     taxa_name_str = taxa_names[eval_params['taxa_id']]
+     op_html = f'<h2><a href="https://www.inaturalist.org/taxa/{eval_params["taxa_id"]}" target="_blank">{taxa_name_str}</a></h2> (click for more info)'
+     return op_html, fig, gr.Number.update(value=eval_params['taxa_id'])
+
+
+ # load metadata
+ taxa_names = load_taxa_metadata('taxa_02_08_2023_names.txt')
+
+
+ with gr.Blocks(title="SINR Demo") as demo:
+     top_text = "Visualization code to explore species range predictions "\
+                "from Spatial Implicit Neural Representation (SINR) models, "\
+                "as described in [our ICML 2023 paper](https://arxiv.org/abs/2306.02564)."
+     gr.Markdown("# SINR Visualization Demo")
+     gr.Markdown(top_text)
+
+     with gr.Row():
+         selected_taxa = gr.Number(label="Taxa ID", value=130714)
+         select_model = gr.Dropdown(["AN_FULL max 10", "AN_FULL max 100", "AN_FULL max 1000", "Distilled env model"],
+                                    value="AN_FULL max 1000", label="Model")
+     with gr.Row():
+         settings = gr.CheckboxGroup(["Random taxa", "Disable ocean mask", "Threshold"], label="Settings")
+         threshold = gr.Slider(0, 1, 0, label="Threshold")
+
+     with gr.Row():
+         submit_button = gr.Button("Run Model")
+
+     with gr.Row():
+         output_text = gr.HTML(label="Species Name:")
+
+     with gr.Row():
+         output_image = gr.Plot(label="Predicted occupancy")
+
+     end_text = "**Note:** Extreme care should be taken before making any decisions "\
+                "based on the outputs of models presented here. "\
+                "The goal of this work is to demonstrate the promise of large-scale "\
+                "representation learning for species range estimation. "\
+                "Our models are trained on biased data and have not been calibrated "\
+                "or validated beyond the experiments illustrated in the paper."
+     gr.Markdown(end_text)
+
+     submit_button.click(
+         fn=generate_prediction,
+         inputs=[selected_taxa, select_model, settings, threshold],
+         outputs=[output_text, output_image, selected_taxa]
+     )
+
+ demo.launch()
data/masks/ocean_mask.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c41395204ddb6327089c27cc5bf14083d7d74176cf419ddf62551eb681c54974
+ size 2008136
datasets.py ADDED
@@ -0,0 +1,194 @@
+ import os
+ import numpy as np
+ import json
+ import pandas as pd
+ from calendar import monthrange
+ import torch
+ import utils
+
+ class LocationDataset(torch.utils.data.Dataset):
+     def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device):
+
+         # handle input encoding:
+         self.input_enc = input_enc
+         if self.input_enc in ['env', 'sin_cos_env']:
+             raster = load_env()
+         else:
+             raster = None
+         self.enc = utils.CoordEncoder(input_enc, raster)
+
+         # define some properties:
+         self.locs = locs
+         self.loc_feats = self.enc.encode(self.locs)
+         self.labels = labels
+         self.classes = classes
+         self.class_to_taxa = class_to_taxa
+
+         # useful numbers:
+         self.num_classes = len(np.unique(labels))
+         self.input_dim = self.loc_feats.shape[1]
+
+         if self.enc.raster is not None:
+             self.enc.raster = self.enc.raster.to(device)
+
+     def __len__(self):
+         return self.loc_feats.shape[0]
+
+     def __getitem__(self, index):
+         loc_feat = self.loc_feats[index, :]
+         loc = self.locs[index, :]
+         class_id = self.labels[index]
+         return loc_feat, loc, class_id
+
+ def load_env():
+     with open('paths.json', 'r') as f:
+         paths = json.load(f)
+     raster = load_context_feats(os.path.join(paths['env'], 'bioclim_elevation_scaled.npy'))
+     return raster
+
+ def load_context_feats(data_path):
+     context_feats = np.load(data_path).astype(np.float32)
+     context_feats = torch.from_numpy(context_feats)
+     return context_feats
+
+ def load_inat_data(ip_file, taxa_of_interest=None):
+
+     print('\nLoading ' + ip_file)
+     data = pd.read_csv(ip_file)
+
+     # remove outliers
+     num_obs = data.shape[0]
+     data = data[((data['latitude'] <= 90) & (data['latitude'] >= -90) & (data['longitude'] <= 180) & (data['longitude'] >= -180))]
+     if (num_obs - data.shape[0]) > 0:
+         print(num_obs - data.shape[0], 'items filtered due to invalid locations')
+
+     if 'accuracy' in data.columns:
+         data.drop(['accuracy'], axis=1, inplace=True)
+
+     if 'positional_accuracy' in data.columns:
+         data.drop(['positional_accuracy'], axis=1, inplace=True)
+
+     if 'geoprivacy' in data.columns:
+         data.drop(['geoprivacy'], axis=1, inplace=True)
+
+     if 'observed_on' in data.columns:
+         data.rename(columns={'observed_on': 'date'}, inplace=True)
+
+     num_obs_orig = data.shape[0]
+     data = data.dropna()
+     size_diff = num_obs_orig - data.shape[0]
+     if size_diff > 0:
+         print(size_diff, 'observation(s) with a NaN entry out of', num_obs_orig, 'removed')
+
+     # keep only taxa of interest:
+     if taxa_of_interest is not None:
+         num_obs_orig = data.shape[0]
+         data = data[data['taxon_id'].isin(taxa_of_interest)]
+         print(num_obs_orig - data.shape[0], 'observation(s) out of', num_obs_orig, 'from different taxa removed')
+
+     print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))
+
+     locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
+     taxa = data['taxon_id'].values.astype(np.int64)
+
+     if 'user_id' in data.columns:
+         users = data['user_id'].values.astype(np.int64)
+         _, users = np.unique(users, return_inverse=True)
+     elif 'observer_id' in data.columns:
+         users = data['observer_id'].values.astype(np.int64)
+         _, users = np.unique(users, return_inverse=True)
+     else:
+         users = np.ones(taxa.shape[0], dtype=np.int64) * -1
+
+     # Note - assumes that dates are in format YYYY-MM-DD
+     years = np.array([int(d_str[:4]) for d_str in data['date'].values])
+     months = np.array([int(d_str[5:7]) for d_str in data['date'].values])
+     days = np.array([int(d_str[8:10]) for d_str in data['date'].values])
+     days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
+     dates = days_per_month[months - 1] + days - 1
+     dates = np.round((dates) / 364.0, 4).astype(np.float32)
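+     # i.e. dates encodes each observation as an approximate day-of-year fraction in [0, 1],
+     # using 2018 (a non-leap year) as the template calendar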
+     if 'id' in data.columns:
+         obs_ids = data['id'].values
+     elif 'observation_uuid' in data.columns:
+         obs_ids = data['observation_uuid'].values
+
+     return locs, taxa, users, dates, years, obs_ids
+
+ def choose_aux_species(current_species, num_aux_species, aux_species_seed):
+     if num_aux_species == 0:
+         return []
+     with open('paths.json', 'r') as f:
+         paths = json.load(f)
+     data_dir = paths['train']
+     taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
+     with open(taxa_file, 'r') as f:
+         inat_large_metadata = json.load(f)
+     aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata]
+     aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species)
+     print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates')
+     rng = np.random.default_rng(aux_species_seed)
+     idx_rand_aux_species = rng.permutation(len(aux_species_candidates))
+     aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]])
+     return aux_species
+
+ def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file_snt=None):
+     if species_set == 'all':
+         return None
+     if species_set == 'snt_birds':
+         assert taxa_file_snt is not None
+         with open(taxa_file_snt, 'r') as f:
+             taxa_subsets = json.load(f)
+         taxa_of_interest = list(taxa_subsets['snt_birds'])
+     else:
+         raise NotImplementedError
+     # optionally add some other species back in:
+     aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed)
+     taxa_of_interest.extend(aux_species)
+     return taxa_of_interest
+
+ def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123):
+     if hard_cap == -1:
+         return np.arange(len(labels))
+     print(f'subsampling (up to) {hard_cap} per class for the training set')
+     class_counts = {id: 0 for id in np.unique(labels)}
+     ss_rng = np.random.default_rng(hard_cap_seed)
+     idx_rand = ss_rng.permutation(len(labels))
+     idx_ss = []
+     for i in idx_rand:
+         class_id = labels[i]
+         if class_counts[class_id] < hard_cap:
+             idx_ss.append(i)
+             class_counts[class_id] += 1
+     idx_ss = np.sort(idx_ss)
+     print(f'final training set size: {len(idx_ss)}')
+     return idx_ss
+
+ def get_train_data(params):
+     with open('paths.json', 'r') as f:
+         paths = json.load(f)
+     data_dir = paths['train']
+     obs_file = os.path.join(data_dir, 'geo_prior_train.csv')
+     taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
+     taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
+
+     taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'], params['aux_species_seed'], taxa_file_snt)
+
+     locs, labels, _, _, _, _ = load_inat_data(obs_file, taxa_of_interest)
+     unique_taxa, class_ids = np.unique(labels, return_inverse=True)
+     class_to_taxa = unique_taxa.tolist()
+
+     # load class names
+     class_info_file = json.load(open(taxa_file, 'r'))
+     class_names_file = [cc['latin_name'] for cc in class_info_file]
+     taxa_ids_file = [cc['taxon_id'] for cc in class_info_file]
+     classes = dict(zip(taxa_ids_file, class_names_file))
+
+     idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'])
+
+     locs = torch.from_numpy(np.array(locs)[idx_ss])  # convert to Tensor
+
+     labels = torch.from_numpy(np.array(class_ids)[idx_ss])
+
+     ds = LocationDataset(locs, labels, classes, class_to_taxa, params['input_enc'], params['device'])
+
+     return ds
eval.py ADDED
@@ -0,0 +1,362 @@
+ import numpy as np
+ import pandas as pd
+ import random
+ import torch
+ import time
+ import os
+ import copy
+ import json
+ import tifffile
+ import h3
+ import setup
+
+ from sklearn.linear_model import RidgeCV
+ from sklearn.preprocessing import MinMaxScaler
+ from sklearn.metrics import average_precision_score
+
+ import utils
+ import models
+ import datasets
+
+ class EvaluatorSNT:
+     def __init__(self, train_params, eval_params):
+         self.train_params = train_params
+         self.eval_params = eval_params
+         with open('paths.json', 'r') as f:
+             paths = json.load(f)
+         D = np.load(os.path.join(paths['snt'], 'snt_res_5.npy'), allow_pickle=True)
+         D = D.item()
+         self.loc_indices_per_species = D['loc_indices_per_species']
+         self.labels_per_species = D['labels_per_species']
+         self.taxa = D['taxa']
+         self.obs_locs = D['obs_locs']
+         self.obs_locs_idx = D['obs_locs_idx']
+
+     def get_labels(self, species):
+         species = str(species)
+         lat = []
+         lon = []
+         gt = []
+         for hx in self.data:
+             cur_lat, cur_lon = h3.h3_to_geo(hx)
+             if species in self.data[hx]:
+                 cur_label = int(len(self.data[hx][species]) > 0)
+                 gt.append(cur_label)
+                 lat.append(cur_lat)
+                 lon.append(cur_lon)
+         lat = np.array(lat).astype(np.float32)
+         lon = np.array(lon).astype(np.float32)
+         obs_locs = np.vstack((lon, lat)).T
+         gt = np.array(gt).astype(np.float32)
+         return obs_locs, gt
+
+     def run_evaluation(self, model, enc):
+         results = {}
+
+         # set seeds:
+         np.random.seed(self.eval_params['seed'])
+         random.seed(self.eval_params['seed'])
+
+         # evaluate the geo model for each taxon
+         results['mean_average_precision'] = np.zeros((len(self.taxa)), dtype=np.float32)
+         # get eval locations and apply input encoding
+         obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
+         loc_feat = enc.encode(obs_locs)
+         # get classes to eval
+         classes_of_interest = np.array([np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] for tt in self.taxa]).squeeze()
+         classes_of_interest = torch.from_numpy(classes_of_interest)
+         # generate model predictions for classes of interest at eval locations
+         with torch.no_grad():
+             loc_emb = model(loc_feat, return_feats=True)
+             wt = model.class_emb.weight[classes_of_interest, :]
+             pred_mtx = torch.matmul(loc_emb, wt.T).cpu().numpy()
+
+         split_rng = np.random.default_rng(self.eval_params['split_seed'])
+
+         for tt_id, tt in enumerate(self.taxa):
+             # generate ground truth labels for current taxa
+             cur_class_of_interest = np.where(self.taxa == tt)[0][0]
+             cur_loc_indices = np.array(self.loc_indices_per_species[cur_class_of_interest])
+             cur_labels = np.array(self.labels_per_species[cur_class_of_interest])
+
+             # apply per-species split:
+             assert self.eval_params['split'] in ['all', 'val', 'test']
+             if self.eval_params['split'] != 'all':
+                 num_val = np.floor(len(cur_labels) * self.eval_params['val_frac']).astype(int)
+                 idx_rand = split_rng.permutation(len(cur_labels))
+                 if self.eval_params['split'] == 'val':
+                     idx_sel = idx_rand[:num_val]
+                 elif self.eval_params['split'] == 'test':
+                     idx_sel = idx_rand[num_val:]
+                 cur_loc_indices = cur_loc_indices[idx_sel]
+                 cur_labels = cur_labels[idx_sel]
+
+             # extract model predictions for current taxa from prediction matrix:
+             pred = pred_mtx[cur_loc_indices, tt_id]
+
+             # compute the AP for each taxa
+             results['mean_average_precision'][tt_id] = average_precision_score((cur_labels > 0).astype(np.int32), pred)
+
+         valid_taxa = ~np.isnan(results['mean_average_precision'])
+
+         # store results
+         results['per_species_average_precision_all'] = copy.deepcopy(results['mean_average_precision'])
+         per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa]
+         results['mean_average_precision'] = per_species_average_precision_valid.mean()
+         results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
+         results['num_eval_species_total'] = len(self.taxa)
+
+         return results
+
+     def report(self, results):
+         for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
+             print(f'{field}: {results[field]}')
+
+ class EvaluatorIUCN:
+
+     def __init__(self, train_params, eval_params):
+         self.train_params = train_params
+         self.eval_params = eval_params
+         with open('paths.json', 'r') as f:
+             paths = json.load(f)
+         with open(os.path.join(paths['iucn'], 'iucn_res_5.json'), 'r') as f:
+             self.data = json.load(f)
+         self.obs_locs = np.array(self.data['locs'], dtype=np.float32)
+         self.taxa = [int(tt) for tt in self.data['taxa_presence'].keys()]
+
+     def run_evaluation(self, model, enc):
+         results = {}
+
+         results['per_species_average_precision_all'] = np.zeros(len(self.taxa), dtype=np.float32)
+         # get eval locations and apply input encoding
+         obs_locs = torch.from_numpy(self.obs_locs).to(self.eval_params['device'])
+         loc_feat = enc.encode(obs_locs)
+
+         # get classes to eval
+         classes_of_interest = torch.from_numpy(np.array([np.where(np.array(self.train_params['class_to_taxa']) == tt)[0] for tt in self.taxa]).squeeze())
+         with torch.no_grad():
+             # generate model predictions for classes of interest at eval locations
+             loc_emb = model(loc_feat, return_feats=True)
+             wt = model.class_emb.weight[classes_of_interest, :]
+             pred_mtx = torch.matmul(loc_emb, wt.T)
+
+         for tt_id, tt in enumerate(self.taxa):
+             class_of_interest = np.where(np.array(self.train_params['class_to_taxa']) == tt)[0]
+
+             if len(class_of_interest) == 0:
+                 # taxa of interest is not in the model
+                 pred = None
+             else:
+                 # extract model predictions for current taxa from prediction matrix
+                 pred = pred_mtx[:, tt_id]
+
+             # evaluate accuracy
+             if pred is None:
+                 results['per_species_average_precision_all'][tt_id] = np.nan
+             else:
+                 gt = np.zeros(obs_locs.shape[0], dtype=np.float32)
+                 gt[self.data['taxa_presence'][str(tt)]] = 1.0
+                 # average precision score:
+                 results['per_species_average_precision_all'][tt_id] = average_precision_score(gt, pred)
+
+         valid_taxa = ~np.isnan(results['per_species_average_precision_all'])
+
+         # store results
+         per_species_average_precision_valid = results['per_species_average_precision_all'][valid_taxa]
+         results['mean_average_precision'] = per_species_average_precision_valid.mean()
+         results['num_eval_species_w_valid_ap'] = valid_taxa.sum()
+         results['num_eval_species_total'] = len(self.taxa)
+         return results
+
+     def report(self, results):
+         for field in ['mean_average_precision', 'num_eval_species_w_valid_ap', 'num_eval_species_total']:
+             print(f'{field}: {results[field]}')
+
+ class EvaluatorGeoPrior:
+
+     def __init__(self, train_params, eval_params):
+         # store parameters:
+         self.train_params = train_params
+         self.eval_params = eval_params
+         with open('paths.json', 'r') as f:
+             paths = json.load(f)
+         # load vision model predictions:
+         self.data = np.load(os.path.join(paths['geo_prior'], 'geo_prior_model_preds.npz'))
+         print('\n', self.data['probs'].shape[0], 'total test observations')
+         # load locations:
+         meta = pd.read_csv(os.path.join(paths['geo_prior'], 'geo_prior_model_meta.csv'))
+         self.obs_locs = np.vstack((meta['longitude'].values, meta['latitude'].values)).T.astype(np.float32)
+         # taxonomic mapping:
+         self.taxon_map = self.find_mapping_between_models(self.data['model_to_taxa'], self.train_params['class_to_taxa'])
+         print(self.taxon_map.shape[0], 'out of', len(self.data['model_to_taxa']), 'taxa in both vision and geo models')
+
+     def find_mapping_between_models(self, vision_taxa, geo_taxa):
+         # this will output an array of size N_overlap X 2
+         # the first column will be the indices of the vision model, and the second is their
+         # corresponding index in the geo model
+         taxon_map = np.ones((vision_taxa.shape[0], 2), dtype=np.int32) * -1
+         taxon_map[:, 0] = np.arange(vision_taxa.shape[0])
+         geo_taxa_arr = np.array(geo_taxa)
+         for tt_id, tt in enumerate(vision_taxa):
+             ind = np.where(geo_taxa_arr == tt)[0]
+             if len(ind) > 0:
+                 taxon_map[tt_id, 1] = ind[0]
+         inds = np.where(taxon_map[:, 1] > -1)[0]
+         taxon_map = taxon_map[inds, :]
+         return taxon_map
+
+     def convert_to_inat_vision_order(self, geo_pred_ip, vision_top_k_prob, vision_top_k_inds, vision_taxa, taxon_map):
+         # this is slow as we turn the sparse input back into the same size as the dense one
+         vision_pred = np.zeros((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
+         geo_pred = np.ones((geo_pred_ip.shape[0], len(vision_taxa)), dtype=np.float32)
+         vision_pred[np.arange(vision_pred.shape[0])[..., np.newaxis], vision_top_k_inds] = vision_top_k_prob
+
+         geo_pred[:, taxon_map[:, 0]] = geo_pred_ip[:, taxon_map[:, 1]]
+
+         return geo_pred, vision_pred
+
+     def run_evaluation(self, model, enc):
+         results = {}
+
+         # loop over the data in batches
+         batch_start = np.hstack((np.arange(0, self.data['probs'].shape[0], self.eval_params['batch_size']), self.data['probs'].shape[0]))
+         correct_pred = np.zeros(self.data['probs'].shape[0])
+
+         print('\nbid\t w geo\t wo geo')
+         for bb_id, bb in enumerate(range(len(batch_start) - 1)):
+             batch_inds = np.arange(batch_start[bb], batch_start[bb + 1])
+
+             vision_probs = self.data['probs'][batch_inds, :]
+             vision_inds = self.data['inds'][batch_inds, :]
+             gt = self.data['labels'][batch_inds]
+
+             obs_locs_batch = torch.from_numpy(self.obs_locs[batch_inds, :]).to(self.eval_params['device'])
+             loc_feat = enc.encode(obs_locs_batch)
+
+             with torch.no_grad():
+                 geo_pred = model(loc_feat).cpu().numpy()
+
+             geo_pred, vision_pred = self.convert_to_inat_vision_order(geo_pred, vision_probs, vision_inds,
+                                                                       self.data['model_to_taxa'], self.taxon_map)
+
+             comb_pred = np.argmax(vision_pred * geo_pred, 1)
+             comb_pred = (comb_pred == gt)
+             correct_pred[batch_inds] = comb_pred
+
+         results['vision_only_top_1'] = float((self.data['inds'][:, -1] == self.data['labels']).mean())
+         results['vision_geo_top_1'] = float(correct_pred.mean())
+         return results
+
+     def report(self, results):
+         print('\nOverall accuracy vision only model', round(results['vision_only_top_1'], 3))
+         print('Overall accuracy vision + geo model', round(results['vision_geo_top_1'], 3))
+         print('Gain ', round(results['vision_geo_top_1'] - results['vision_only_top_1'], 3))
+
+ class EvaluatorGeoFeature:
+
+     def __init__(self, train_params, eval_params):
+         self.train_params = train_params
+         self.eval_params = eval_params
+         with open('paths.json', 'r') as f:
+             paths = json.load(f)
+         self.data_path = paths['geo_feature']
+         self.country_mask = tifffile.imread(os.path.join(paths['masks'], 'USA_MASK.tif')) == 1
+         self.raster_names = ['ABOVE_GROUND_CARBON', 'ELEVATION', 'LEAF_AREA_INDEX', 'NON_TREE_VEGITATED', 'NOT_VEGITATED', 'POPULATION_DENSITY', 'SNOW_COVER', 'SOIL_MOISTURE', 'TREE_COVER']
+         self.raster_names_log_transform = ['POPULATION_DENSITY']
+
+     def load_raster(self, raster_name, log_transform=False):
+         raster = tifffile.imread(os.path.join(self.data_path, raster_name + '.tif')).astype(np.float32)
+         valid_mask = ~np.isnan(raster).copy() & self.country_mask
+         # log scaling:
+         if log_transform:
+             raster[valid_mask] = np.log1p(raster[valid_mask] - raster[valid_mask].min())
+         # 0/1 scaling:
+         raster[valid_mask] -= raster[valid_mask].min()
+         raster[valid_mask] /= raster[valid_mask].max()
+
+         return raster, valid_mask
+
+     def get_split_labels(self, raster, split_ids, split_of_interest):
+         # get the GT labels for a subset
+         inds_y, inds_x = np.where(split_ids == split_of_interest)
+         return raster[inds_y, inds_x]
+
+     def get_split_feats(self, model, enc, split_ids, split_of_interest):
+         locs = utils.coord_grid(self.country_mask.shape, split_ids=split_ids, split_of_interest=split_of_interest)
+         locs = torch.from_numpy(locs).to(self.eval_params['device'])
+         locs_enc = enc.encode(locs)
+         with torch.no_grad():
+             feats = model(locs_enc, return_feats=True).cpu().numpy()
+         return feats
+
+     def run_evaluation(self, model, enc):
+         results = {}
+         for raster_name in self.raster_names:
+             do_log_transform = raster_name in self.raster_names_log_transform
+             raster, valid_mask = self.load_raster(raster_name, do_log_transform)
+             split_ids = utils.create_spatial_split(raster, valid_mask, cell_size=self.eval_params['cell_size'])
+             feats_train = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=1)
+             feats_test = self.get_split_feats(model, enc, split_ids=split_ids, split_of_interest=2)
+             labels_train = self.get_split_labels(raster, split_ids, 1)
+             labels_test = self.get_split_labels(raster, split_ids, 2)
+             scaler = MinMaxScaler()
+             feats_train_scaled = scaler.fit_transform(feats_train)
+             feats_test_scaled = scaler.transform(feats_test)
+             clf = RidgeCV(alphas=(0.1, 1.0, 10.0), cv=10, fit_intercept=True, scoring='r2').fit(feats_train_scaled, labels_train)
+             train_score = clf.score(feats_train_scaled, labels_train)
+             test_score = clf.score(feats_test_scaled, labels_test)
+             results[f'train_r2_{raster_name}'] = float(train_score)
+             results[f'test_r2_{raster_name}'] = float(test_score)
+             results[f'alpha_{raster_name}'] = float(clf.alpha_)
+         return results
+
+     def report(self, results):
+         report_fields = [x for x in results if 'test_r2' in x]
+         for field in report_fields:
+             print(f'{field}: {results[field]}')
+         print(np.mean([results[field] for field in report_fields]))
+
+ def launch_eval_run(overrides):
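+     # example usage (hypothetical override values):
+     #   launch_eval_run({'eval_type': 'iucn', 'experiment_name': 'demo'})
+     # evaluates the checkpoint at experiments/demo/model.pt on the IUCN benchmark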
+     eval_params = setup.get_default_params_eval(overrides)
+
+     # set up model:
+     eval_params['model_path'] = os.path.join(eval_params['exp_base'], eval_params['experiment_name'], eval_params['ckp_name'])
+     train_params = torch.load(eval_params['model_path'], map_location='cpu')
+     model = models.get_model(train_params['params'])
+     model.load_state_dict(train_params['state_dict'], strict=True)
+     model = model.to(eval_params['device'])
+     model.eval()
+
+     # create input encoder:
+     if train_params['params']['input_enc'] in ['env', 'sin_cos_env']:
+         raster = datasets.load_env().to(eval_params['device'])
+     else:
+         raster = None
+     enc = utils.CoordEncoder(train_params['params']['input_enc'], raster=raster)
+
+     t = time.time()
+     if eval_params['eval_type'] == 'snt':
+         eval_params['split'] = 'test'  # val, test, all
+         eval_params['val_frac'] = 0.50
+         eval_params['split_seed'] = 7499
+         evaluator = EvaluatorSNT(train_params['params'], eval_params)
+         results = evaluator.run_evaluation(model, enc)
+         evaluator.report(results)
+     elif eval_params['eval_type'] == 'iucn':
+         evaluator = EvaluatorIUCN(train_params['params'], eval_params)
+         results = evaluator.run_evaluation(model, enc)
+         evaluator.report(results)
+     elif eval_params['eval_type'] == 'geo_prior':
+         evaluator = EvaluatorGeoPrior(train_params['params'], eval_params)
+         results = evaluator.run_evaluation(model, enc)
+         evaluator.report(results)
+     elif eval_params['eval_type'] == 'geo_feature':
+         evaluator = EvaluatorGeoFeature(train_params['params'], eval_params)
+         results = evaluator.run_evaluation(model, enc)
+         evaluator.report(results)
+     else:
+         raise NotImplementedError('Eval type not implemented.')
+     print(f'evaluation completed in {np.around((time.time()-t)/60, 1)} min')
+     return results
images/sinr_traverse.gif ADDED
losses.py ADDED
@@ -0,0 +1,146 @@
+ import torch
+ import utils
+
+ def get_loss_function(params):
+     if params['loss'] == 'an_full':
+         return an_full
+     elif params['loss'] == 'an_slds':
+         return an_slds
+     elif params['loss'] == 'an_ssdl':
+         return an_ssdl
+     elif params['loss'] == 'an_full_me':
+         return an_full_me
+     elif params['loss'] == 'an_slds_me':
+         return an_slds_me
+     elif params['loss'] == 'an_ssdl_me':
+         return an_ssdl_me
+     else:
+         raise NotImplementedError('Invalid loss specified.')
+
+ def neg_log(x):
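+     # the small constant avoids taking log(0)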
+     return -torch.log(x + 1e-5)
+
+ def bernoulli_entropy(p):
+     entropy = p * neg_log(p) + (1 - p) * neg_log(1 - p)
+     return entropy
+
+ def an_ssdl(batch, model, params, loc_to_feats, neg_type='hard'):
+
+     inds = torch.arange(params['batch_size'])
+
+     loc_feat, _, class_id = batch
+     loc_feat = loc_feat.to(params['device'])
+     class_id = class_id.to(params['device'])
+
+     assert model.inc_bias == False
+     batch_size = loc_feat.shape[0]
+
+     # create random background samples and extract features
+     rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
+     rand_feat = loc_to_feats(rand_loc, normalize=False)
+
+     # get location embeddings
+     loc_cat = torch.cat((loc_feat, rand_feat), 0)  # stack vertically
+     loc_emb_cat = model(loc_cat, return_feats=True)
+     loc_emb = loc_emb_cat[:batch_size, :]
+     loc_emb_rand = loc_emb_cat[batch_size:, :]
+
+     loc_pred = torch.sigmoid(model.class_emb(loc_emb))
+     loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
+
+     # data loss
+     loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
+     if neg_type == 'hard':
+         loss_bg = neg_log(1.0 - loc_pred_rand[inds[:batch_size], class_id])  # assume negative
+     elif neg_type == 'entropy':
+         loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand[inds[:batch_size], class_id])  # entropy
+     else:
+         raise NotImplementedError
+
+     # total loss
+     loss = loss_pos.mean() + loss_bg.mean()
+
+     return loss
+
+ def an_slds(batch, model, params, loc_to_feats, neg_type='hard'):
+
+     inds = torch.arange(params['batch_size'])
+
+     loc_feat, _, class_id = batch
+     loc_feat = loc_feat.to(params['device'])
+     class_id = class_id.to(params['device'])
+
+     assert model.inc_bias == False
+     batch_size = loc_feat.shape[0]
+
+     loc_emb = model(loc_feat, return_feats=True)
+
+     loc_pred = torch.sigmoid(model.class_emb(loc_emb))
+
+     num_classes = loc_pred.shape[1]
+     bg_class = torch.randint(low=0, high=num_classes - 1, size=(batch_size,), device=params['device'])
+     bg_class[bg_class >= class_id[:batch_size]] += 1
+
+     # data loss
+     loss_pos = neg_log(loc_pred[inds[:batch_size], class_id])
+     if neg_type == 'hard':
+         loss_bg = neg_log(1.0 - loc_pred[inds[:batch_size], bg_class])  # assume negative
+     elif neg_type == 'entropy':
+         loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred[inds[:batch_size], bg_class])  # entropy
+     else:
+         raise NotImplementedError
+
+     # total loss
+     loss = loss_pos.mean() + loss_bg.mean()
+
+     return loss
+
+ def an_full(batch, model, params, loc_to_feats, neg_type='hard'):
+
+     inds = torch.arange(params['batch_size'])
+
+     loc_feat, _, class_id = batch
+     loc_feat = loc_feat.to(params['device'])
+     class_id = class_id.to(params['device'])
+
+     assert model.inc_bias == False
+     batch_size = loc_feat.shape[0]
+
+     # create random background samples and extract features
+     rand_loc = utils.rand_samples(batch_size, params['device'], rand_type='spherical')
+     rand_feat = loc_to_feats(rand_loc, normalize=False)
+
+     # get location embeddings
+     loc_cat = torch.cat((loc_feat, rand_feat), 0)  # stack vertically
+     loc_emb_cat = model(loc_cat, return_feats=True)
+     loc_emb = loc_emb_cat[:batch_size, :]
+     loc_emb_rand = loc_emb_cat[batch_size:, :]
+     # get predictions for locations and background locations
+     loc_pred = torch.sigmoid(model.class_emb(loc_emb))
+     loc_pred_rand = torch.sigmoid(model.class_emb(loc_emb_rand))
+
+     # data loss
+     if neg_type == 'hard':
+         loss_pos = neg_log(1.0 - loc_pred)  # assume negative
+         loss_bg = neg_log(1.0 - loc_pred_rand)  # assume negative
+     elif neg_type == 'entropy':
+         loss_pos = -1 * bernoulli_entropy(1.0 - loc_pred)  # entropy
+         loss_bg = -1 * bernoulli_entropy(1.0 - loc_pred_rand)  # entropy
+     else:
+         raise NotImplementedError
+ loss_pos[inds[:batch_size], class_id] = params['pos_weight'] * neg_log(loc_pred[inds[:batch_size], class_id])
130
+
131
+ # total loss
132
+ loss = loss_pos.mean() + loss_bg.mean()
133
+
134
+ return loss
135
+
136
+ def an_full_me(batch, model, params, loc_to_feats):
137
+
138
+ return an_full(batch, model, params, loc_to_feats, neg_type='entropy')
139
+
140
+ def an_ssdl_me(batch, model, params, loc_to_feats):
141
+
142
+ return an_ssdl(batch, model, params, loc_to_feats, neg_type='entropy')
143
+
144
+ def an_slds_me(batch, model, params, loc_to_feats):
145
+
146
+ return an_slds(batch, model, params, loc_to_feats, neg_type='entropy')
models.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ import torch.utils.data
+ import torch.nn as nn
+
+ def get_model(params):
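+     # example usage (hypothetical parameter values):
+     #   get_model({'model': 'ResidualFCNet', 'input_dim': 4, 'num_classes': 10, 'num_filts': 256, 'depth': 4})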
+
+     if params['model'] == 'ResidualFCNet':
+         return ResidualFCNet(params['input_dim'], params['num_classes'], params['num_filts'], params['depth'])
+     elif params['model'] == 'LinNet':
+         return LinNet(params['input_dim'], params['num_classes'])
+     else:
+         raise NotImplementedError('Invalid model specified.')
+
+ class ResLayer(nn.Module):
+     def __init__(self, linear_size):
+         super(ResLayer, self).__init__()
+         self.l_size = linear_size
+         self.nonlin1 = nn.ReLU(inplace=True)
+         self.nonlin2 = nn.ReLU(inplace=True)
+         self.dropout1 = nn.Dropout()
+         self.w1 = nn.Linear(self.l_size, self.l_size)
+         self.w2 = nn.Linear(self.l_size, self.l_size)
+
+     def forward(self, x):
+         y = self.w1(x)
+         y = self.nonlin1(y)
+         y = self.dropout1(y)
+         y = self.w2(y)
+         y = self.nonlin2(y)
+         out = x + y
+         return out
+
+ class ResidualFCNet(nn.Module):
+
+     def __init__(self, num_inputs, num_classes, num_filts, depth=4):
+         super(ResidualFCNet, self).__init__()
+         self.inc_bias = False
+         self.class_emb = nn.Linear(num_filts, num_classes, bias=self.inc_bias)
+         layers = []
+         layers.append(nn.Linear(num_inputs, num_filts))
+         layers.append(nn.ReLU(inplace=True))
+         for i in range(depth):
+             layers.append(ResLayer(num_filts))
+         self.feats = torch.nn.Sequential(*layers)
+
+     def forward(self, x, class_of_interest=None, return_feats=False):
+         loc_emb = self.feats(x)
+         if return_feats:
+             return loc_emb
+         if class_of_interest is None:
+             class_pred = self.class_emb(loc_emb)
+         else:
+             class_pred = self.eval_single_class(loc_emb, class_of_interest)
+         return torch.sigmoid(class_pred)
+
+     def eval_single_class(self, x, class_of_interest):
+         if self.inc_bias:
+             return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T) + self.class_emb.bias[class_of_interest]
+         else:
+             return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T)
+
+ class LinNet(nn.Module):
+     def __init__(self, num_inputs, num_classes):
+         super(LinNet, self).__init__()
+         self.num_layers = 0
+         self.inc_bias = False
+         self.class_emb = nn.Linear(num_inputs, num_classes, bias=self.inc_bias)
+         self.feats = nn.Identity()  # does not do anything
+
+     def forward(self, x, class_of_interest=None, return_feats=False):
+         loc_emb = self.feats(x)
+         if return_feats:
+             return loc_emb
+         if class_of_interest is None:
+             class_pred = self.class_emb(loc_emb)
+         else:
+             class_pred = self.eval_single_class(loc_emb, class_of_interest)
+
+         return torch.sigmoid(class_pred)
+
+     def eval_single_class(self, x, class_of_interest):
+         if self.inc_bias:
+             return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T) + self.class_emb.bias[class_of_interest]
+         else:
+             return torch.matmul(x, self.class_emb.weight[class_of_interest, :].T)
paths.json ADDED
@@ -0,0 +1,9 @@
+ {
+     "masks": "data/masks/",
+     "env": "data/env/",
+     "train": "data/train/",
+     "geo_prior": "data/eval/geo_prior/",
+     "snt": "data/eval/snt/",
+     "iucn": "data/eval/iucn/",
+     "geo_feature": "data/eval/geo_feature/"
+ }
pretrained_models/model_an_full_input_enc_sin_cos_distilled_from_env.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e8408dbfdcdc3008cfce801318cba2263149aaec0451a656d52958bf81115547
+ size 50849971
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_10.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:438265b758df7cf58f2ed39410205be6a9fa944e559d7556d9d5c7c0f501c4ae
+ size 50850118
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_100.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:31811bf8e0a8bf8f59b9efc7fa56db015e49f02c594316a3a4389dc91ad6aae9
+ size 50850139
pretrained_models/model_an_full_input_enc_sin_cos_hard_cap_num_per_class_1000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7bf4dacb0f9b4cf8c5323e1c186b1e725267199053b9aa672d4b5dc1c3dbc235
+ size 50850160
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio==3.36.1
+ h3==3.7.6
+ matplotlib==3.7.1
+ numpy==1.25.0
+ pandas==2.0.3
+ torch==1.12.1
setup.py ADDED
@@ -0,0 +1,91 @@
+ import copy
+ import torch
+
+ def apply_overrides(params, overrides):
+     params = copy.deepcopy(params)
+     for param_name in overrides:
+         if param_name not in params:
+             print(f'override failed: no parameter named {param_name}')
+             raise ValueError
+         params[param_name] = overrides[param_name]
+     return params
+
+ def get_default_params_train(overrides={}):
+
+     params = {}
+
+     '''
+     misc
+     '''
+     params['device'] = 'cuda'  # cuda, cpu
+     params['save_base'] = './experiments/'
+     params['experiment_name'] = 'demo'
+     params['timestamp'] = False
+
+     '''
+     data
+     '''
+     params['species_set'] = 'all'  # all, snt_birds
+     params['hard_cap_seed'] = 9472
+     params['hard_cap_num_per_class'] = -1  # -1 for no hard capping
+     params['aux_species_seed'] = 8099
+     params['num_aux_species'] = 0  # for snt_birds case, how many other species to add in
+
+     '''
+     model
+     '''
+     params['model'] = 'ResidualFCNet'  # ResidualFCNet, LinNet
+     params['num_filts'] = 256  # embedding dimension
+     params['input_enc'] = 'sin_cos'  # sin_cos, env, sin_cos_env
+     params['depth'] = 4
+
+     '''
+     loss
+     '''
+     params['loss'] = 'an_full'  # an_full, an_ssdl, an_slds
+     params['pos_weight'] = 2048
+
+     '''
+     optimization
+     '''
+     params['batch_size'] = 2048
+     params['lr'] = 0.0005
+     params['lr_decay'] = 0.98
+     params['num_epochs'] = 10
+
+     '''
+     saving
+     '''
+     params['log_frequency'] = 512
+
+     params = apply_overrides(params, overrides)
+
+     return params
+
+ def get_default_params_eval(overrides={}):
+
+     params = {}
+
+     '''
+     misc
+     '''
+     params['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     params['seed'] = 2022
+     params['exp_base'] = './experiments'
+     params['ckp_name'] = 'model.pt'
+     params['eval_type'] = 'snt'  # snt, iucn, geo_prior, geo_feature
+     params['experiment_name'] = 'demo'
+
+     '''
+     geo prior
+     '''
+     params['batch_size'] = 2048
+
+     '''
+     geo feature
+     '''
+     params['cell_size'] = 25
+
+     params = apply_overrides(params, overrides)
+
+     return params
taxa_02_08_2023_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,143 @@
+ import torch
+ import numpy as np
+ import math
+ import datetime
+
+ class CoordEncoder:
+
+     def __init__(self, input_enc, raster=None):
+         self.input_enc = input_enc
+         self.raster = raster
+
+     def encode(self, locs, normalize=True):
+         # assumes lon, lat in range [-180, 180] and [-90, 90]
+         if normalize:
+             locs = normalize_coords(locs)
+         if self.input_enc == 'sin_cos':  # sinusoidal encoding
+             loc_feats = encode_loc(locs)
+         elif self.input_enc == 'env':  # bioclim variables
+             loc_feats = bilinear_interpolate(locs, self.raster)
+         elif self.input_enc == 'sin_cos_env':  # sinusoidal encoding & bioclim variables
+             loc_feats = encode_loc(locs)
+             context_feats = bilinear_interpolate(locs, self.raster)
+             loc_feats = torch.cat((loc_feats, context_feats), 1)
+         else:
+             raise NotImplementedError('Unknown input encoding.')
+         return loc_feats
+
+ def normalize_coords(locs):
+     # locs is in lon [-180, 180], lat [-90, 90]
+     # output is in the range [-1, 1]
+
+     locs[:, 0] /= 180.0
+     locs[:, 1] /= 90.0
+
+     return locs
+
+ def encode_loc(loc_ip, concat_dim=1):
+     # assumes input locations are in range -1 to 1
+     # location is lon, lat
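+     # output layout: [sin(pi*lon), sin(pi*lat), cos(pi*lon), cos(pi*lat)]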
+     feats = torch.cat((torch.sin(math.pi * loc_ip), torch.cos(math.pi * loc_ip)), concat_dim)
+     return feats
+
+ def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
+     # loc is N x 2 vector, where each row is a [lon, lat] entry
+     # each entry spans range [-1, 1]
+     # data is H x W x C, height x width x channel data matrix
+     # op will be N x C matrix of interpolated features
+
+     assert data is not None
+
+     # map to [0,1], then scale to data size
+     loc = (loc_ip.clone() + 1) / 2.0
+     loc[:, 1] = 1 - loc[:, 1]  # this is because latitude goes from +90 at the top to -90 at the bottom,
+                                # while longitude goes from -180 to 180 left to right
+
+     assert not torch.any(torch.isnan(loc))
+
+     if remove_nans_raster:
+         data[torch.isnan(data)] = 0.0  # replace with mean value (0 is mean post-normalization)
+
+     # cast locations into pixel space
+     loc[:, 0] *= (data.shape[1] - 1)
+     loc[:, 1] *= (data.shape[0] - 1)
+
+     loc_int = torch.floor(loc).long()  # integer pixel coordinates
+     xx = loc_int[:, 0]
+     yy = loc_int[:, 1]
+     xx_plus = xx + 1
+     xx_plus[xx_plus > (data.shape[1] - 1)] = data.shape[1] - 1
+     yy_plus = yy + 1
+     yy_plus[yy_plus > (data.shape[0] - 1)] = data.shape[0] - 1
+
+     loc_delta = loc - torch.floor(loc)  # delta values
+     dx = loc_delta[:, 0].unsqueeze(1)
+     dy = loc_delta[:, 1].unsqueeze(1)
+
+     interp_val = data[yy, xx, :] * (1 - dx) * (1 - dy) + data[yy, xx_plus, :] * dx * (1 - dy) + \
+                  data[yy_plus, xx, :] * (1 - dx) * dy + data[yy_plus, xx_plus, :] * dx * dy
+
+     return interp_val
+
+ def rand_samples(batch_size, device, rand_type='uniform'):
+     # randomly sample background locations
+
+     if rand_type == 'spherical':
+ rand_loc = torch.rand(batch_size, 2).to(device)
87
+ theta1 = 2.0*math.pi*rand_loc[:, 0]
88
+ theta2 = torch.acos(2.0*rand_loc[:, 1] - 1.0)
89
+ lat = 1.0 - 2.0*theta2/math.pi
90
+ lon = (theta1/math.pi) - 1.0
91
+ rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
92
+
93
+ elif rand_type == 'uniform':
94
+ rand_loc = torch.rand(batch_size, 2).to(device)*2.0 - 1.0
95
+
96
+ return rand_loc
97
+
98
+ def get_time_stamp():
99
+ cur_time = str(datetime.datetime.now())
100
+ date, time = cur_time.split(' ')
101
+ h, m, s = time.split(':')
102
+ s = s.split('.')[0]
103
+ time_stamp = '{}-{}-{}-{}'.format(date, h, m, s)
104
+ return time_stamp
105
+
106
+ def coord_grid(grid_size, split_ids=None, split_of_interest=None):
107
+ # generate a grid of locations spaced evenly in coordinate space
108
+
109
+ feats = np.zeros((grid_size[0], grid_size[1], 2), dtype=np.float32)
110
+ mg = np.meshgrid(np.linspace(-180, 180, feats.shape[1]), np.linspace(90, -90, feats.shape[0]))
111
+ feats[:, :, 0] = mg[0]
112
+ feats[:, :, 1] = mg[1]
113
+ if split_ids is None or split_of_interest is None:
114
+ # return feats for all locations
115
+ # this will be an N x 2 array
116
+ return feats.reshape(feats.shape[0]*feats.shape[1], 2)
117
+ else:
118
+ # only select a subset of locations
119
+ ind_y, ind_x = np.where(split_ids==split_of_interest)
120
+
121
+ # these will be N_subset x 2 in size
122
+ return feats[ind_y, ind_x, :]
123
+
124
+ def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
125
+ # generates a checkerboard style train test split
126
+ # 0 is invalid, 1 is train, and 2 is test
127
+ # c_size is units of pixels
128
+ split_ids = np.ones((raster.shape[0], raster.shape[1]))
129
+ start = cell_size
130
+ for ii in np.arange(0, split_ids.shape[0], cell_size):
131
+ if start == 0:
132
+ start = cell_size
133
+ else:
134
+ start = 0
135
+ for jj in np.arange(start, split_ids.shape[1], cell_size*2):
136
+ split_ids[ii:ii+cell_size, jj:jj+cell_size] = 2
137
+ split_ids = split_ids*mask
138
+ if train_amt < 1.0:
139
+ # take a subset of the data
140
+ tr_y, tr_x = np.where(split_ids==1)
141
+ inds = np.random.choice(len(tr_y), int(len(tr_y)*(1.0-train_amt)), replace=False)
142
+ split_ids[tr_y[inds], tr_x[inds]] = 0
143
+ return split_ids