# sinr/datasets.py
import os
import numpy as np
import json
import pandas as pd
from calendar import monthrange
import torch
import utils
class LocationDataset(torch.utils.data.Dataset):
def __init__(self, locs, labels, classes, class_to_taxa, input_enc, device):
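        """Dataset of observation locations with encoded features.

        Args:
            locs: (N, 2) tensor of (lon, lat) coordinates.
            labels: (N,) tensor of class indices.
            classes: dict mapping taxon id to class (species) name.
            class_to_taxa: list mapping class index to taxon id.
            input_enc: name of the coordinate encoding (e.g. 'env', 'sin_cos_env').
            device: device the environmental raster (if any) is moved to.
        """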
# handle input encoding:
self.input_enc = input_enc
if self.input_enc in ['env', 'sin_cos_env']:
raster = load_env()
else:
raster = None
self.enc = utils.CoordEncoder(input_enc, raster)
# define some properties:
self.locs = locs
self.loc_feats = self.enc.encode(self.locs)
self.labels = labels
self.classes = classes
self.class_to_taxa = class_to_taxa
# useful numbers:
self.num_classes = len(np.unique(labels))
self.input_dim = self.loc_feats.shape[1]
if self.enc.raster is not None:
self.enc.raster = self.enc.raster.to(device)
def __len__(self):
return self.loc_feats.shape[0]
def __getitem__(self, index):
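        """Return the encoded location feature, the raw location, and the class id."""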
loc_feat = self.loc_feats[index, :]
loc = self.locs[index, :]
class_id = self.labels[index]
return loc_feat, loc, class_id
def load_env():
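    """Load the scaled bioclim + elevation raster used by the environmental encodings."""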
with open('paths.json', 'r') as f:
paths = json.load(f)
raster = load_context_feats(os.path.join(paths['env'],'bioclim_elevation_scaled.npy'))
return raster
def load_context_feats(data_path):
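    """Load environmental context features from a .npy file as a float32 tensor."""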
context_feats = np.load(data_path).astype(np.float32)
context_feats = torch.from_numpy(context_feats)
return context_feats
def load_inat_data(ip_file, taxa_of_interest=None):
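    """Load iNaturalist observations from a csv file, optionally restricted to
    taxa_of_interest, and return locations, taxon ids, user indices, normalized
    dates, years, and observation ids."""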
print('\nLoading ' + ip_file)
data = pd.read_csv(ip_file)
# remove outliers
num_obs = data.shape[0]
    data = data[(data['latitude'] <= 90) & (data['latitude'] >= -90) &
                (data['longitude'] <= 180) & (data['longitude'] >= -180)]
if (num_obs - data.shape[0]) > 0:
print(num_obs - data.shape[0], 'items filtered due to invalid locations')
    # drop metadata columns that are not needed:
    for col in ['accuracy', 'positional_accuracy', 'geoprivacy']:
        if col in data.columns:
            data.drop([col], axis=1, inplace=True)
if 'observed_on' in data.columns:
data.rename(columns = {'observed_on':'date'}, inplace=True)
num_obs_orig = data.shape[0]
data = data.dropna()
size_diff = num_obs_orig - data.shape[0]
if size_diff > 0:
        print(size_diff, 'observation(s) with a NaN entry out of', num_obs_orig, 'removed')
# keep only taxa of interest:
if taxa_of_interest is not None:
num_obs_orig = data.shape[0]
data = data[data['taxon_id'].isin(taxa_of_interest)]
        print(num_obs_orig - data.shape[0], 'observation(s) out of', num_obs_orig, 'from different taxa removed')
print('Number of unique classes {}'.format(np.unique(data['taxon_id'].values).shape[0]))
locs = np.vstack((data['longitude'].values, data['latitude'].values)).T.astype(np.float32)
    taxa = data['taxon_id'].values.astype(np.int64)
    if 'user_id' in data.columns:
        users = data['user_id'].values.astype(np.int64)
        _, users = np.unique(users, return_inverse=True)
    elif 'observer_id' in data.columns:
        users = data['observer_id'].values.astype(np.int64)
        _, users = np.unique(users, return_inverse=True)
    else:
        users = np.full(taxa.shape[0], -1, dtype=np.int64)
    # note: assumes dates are in the format YYYY-MM-DD
    years = np.array([int(d_str[:4]) for d_str in data['date'].values])
    months = np.array([int(d_str[5:7]) for d_str in data['date'].values])
    days = np.array([int(d_str[8:10]) for d_str in data['date'].values])
    # cumulative day count at the start of each month (2018 is an arbitrary non-leap year)
    days_per_month = np.cumsum([0] + [monthrange(2018, mm)[1] for mm in range(1, 12)])
    # zero-based day of year, normalized to [0, 1]
    dates = days_per_month[months - 1] + days - 1
    dates = np.round(dates / 364.0, 4).astype(np.float32)
    if 'id' in data.columns:
        obs_ids = data['id'].values
    elif 'observation_uuid' in data.columns:
        obs_ids = data['observation_uuid'].values
    else:
        # avoid an unbound obs_ids if neither id column is present
        raise ValueError('no observation id column found in ' + ip_file)
    return locs, taxa, users, dates, years, obs_ids
def choose_aux_species(current_species, num_aux_species, aux_species_seed):
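    """Randomly choose num_aux_species taxa that are not already in
    current_species, drawing from the full training metadata with a fixed seed."""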
if num_aux_species == 0:
return []
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
with open(taxa_file, 'r') as f:
inat_large_metadata = json.load(f)
aux_species_candidates = [x['taxon_id'] for x in inat_large_metadata]
aux_species_candidates = np.setdiff1d(aux_species_candidates, current_species)
print(f'choosing {num_aux_species} species to add from {len(aux_species_candidates)} candidates')
rng = np.random.default_rng(aux_species_seed)
idx_rand_aux_species = rng.permutation(len(aux_species_candidates))
aux_species = list(aux_species_candidates[idx_rand_aux_species[:num_aux_species]])
return aux_species
def get_taxa_of_interest(species_set='all', num_aux_species=0, aux_species_seed=123, taxa_file_snt=None):
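    """Return the list of taxon ids to train on, or None to keep all taxa."""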
if species_set == 'all':
return None
if species_set == 'snt_birds':
assert taxa_file_snt is not None
        with open(taxa_file_snt, 'r') as f:
taxa_subsets = json.load(f)
taxa_of_interest = list(taxa_subsets['snt_birds'])
else:
raise NotImplementedError
# optionally add some other species back in:
aux_species = choose_aux_species(taxa_of_interest, num_aux_species, aux_species_seed)
taxa_of_interest.extend(aux_species)
return taxa_of_interest
def get_idx_subsample_observations(labels, hard_cap=-1, hard_cap_seed=123):
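    """Return sorted indices that randomly cap each class at hard_cap
    observations; a hard_cap of -1 disables subsampling."""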
if hard_cap == -1:
return np.arange(len(labels))
print(f'subsampling (up to) {hard_cap} per class for the training set')
    class_counts = {class_id: 0 for class_id in np.unique(labels)}
ss_rng = np.random.default_rng(hard_cap_seed)
idx_rand = ss_rng.permutation(len(labels))
idx_ss = []
for i in idx_rand:
class_id = labels[i]
if class_counts[class_id] < hard_cap:
idx_ss.append(i)
class_counts[class_id] += 1
idx_ss = np.sort(idx_ss)
print(f'final training set size: {len(idx_ss)}')
return idx_ss
def get_train_data(params):
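    """Construct the training LocationDataset described by params."""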
with open('paths.json', 'r') as f:
paths = json.load(f)
data_dir = paths['train']
obs_file = os.path.join(data_dir, 'geo_prior_train.csv')
taxa_file = os.path.join(data_dir, 'geo_prior_train_meta.json')
taxa_file_snt = os.path.join(data_dir, 'taxa_subsets.json')
taxa_of_interest = get_taxa_of_interest(params['species_set'], params['num_aux_species'], params['aux_species_seed'], taxa_file_snt)
locs, labels, _, _, _, _ = load_inat_data(obs_file, taxa_of_interest)
unique_taxa, class_ids = np.unique(labels, return_inverse=True)
class_to_taxa = unique_taxa.tolist()
    # load class names
    with open(taxa_file, 'r') as f:
        class_info = json.load(f)
    class_names = [cc['latin_name'] for cc in class_info]
    taxa_ids = [cc['taxon_id'] for cc in class_info]
    classes = dict(zip(taxa_ids, class_names))
idx_ss = get_idx_subsample_observations(labels, params['hard_cap_num_per_class'], params['hard_cap_seed'])
locs = torch.from_numpy(np.array(locs)[idx_ss]) # convert to Tensor
labels = torch.from_numpy(np.array(class_ids)[idx_ss])
ds = LocationDataset(locs, labels, classes, class_to_taxa, params['input_enc'], params['device'])
return ds
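

# Minimal usage sketch (kept as a comment so importing this module has no side
# effects). It assumes 'paths.json' and the training files it references exist;
# the parameter values below are illustrative, not project defaults:
#
#   params = {
#       'species_set': 'all',
#       'num_aux_species': 0,
#       'aux_species_seed': 123,
#       'hard_cap_num_per_class': -1,
#       'hard_cap_seed': 123,
#       'input_enc': 'sin_cos_env',
#       'device': torch.device('cpu'),
#   }
#   ds = get_train_data(params)
#   loc_feat, loc, class_id = ds[0]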