File size: 5,220 Bytes
505e401 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import torch
import numpy as np
import math
import datetime
class CoordEncoder:
    """Maps raw (lon, lat) coordinates to model input features.

    Supported encodings: 'sin_cos' (sinusoidal position features),
    'env' (bilinearly sampled raster variables), or 'sin_cos_env' (both,
    concatenated). The 'env' variants require a raster at construction.
    """
    def __init__(self, input_enc, raster=None):
        # input_enc: which of the three encodings to produce
        # raster: H x W x C tensor of environmental data (e.g. bioclim)
        self.input_enc = input_enc
        self.raster = raster
    def encode(self, locs, normalize=True):
        """Return an N x D feature tensor for the N x 2 [lon, lat] rows in locs.

        Inputs are assumed to span [-180, 180] / [-90, 90]; set
        normalize=False if they are already scaled to [-1, 1].
        """
        if normalize:
            locs = normalize_coords(locs)
        if self.input_enc == 'sin_cos':  # sinusoidal encoding only
            return encode_loc(locs)
        if self.input_enc == 'env':  # raster (bioclim) features only
            return bilinear_interpolate(locs, self.raster)
        if self.input_enc == 'sin_cos_env':  # both, concatenated per row
            sinusoidal = encode_loc(locs)
            raster_feats = bilinear_interpolate(locs, self.raster)
            return torch.cat((sinusoidal, raster_feats), 1)
        raise NotImplementedError('Unknown input encoding.')
def normalize_coords(locs):
    """Scale an N x 2 [lon, lat] tensor from degrees to the range [-1, 1].

    Longitude is assumed to lie in [-180, 180] and latitude in [-90, 90].
    Returns a new tensor; the caller's input is left unmodified (the
    previous version divided the argument in place, silently corrupting
    the caller's coordinates for any later reuse).
    """
    normed = locs.clone()
    normed[:, 0] /= 180.0
    normed[:, 1] /= 90.0
    return normed
def encode_loc(loc_ip, concat_dim=1):
    """Sinusoidal encoding of [lon, lat] locations already scaled to [-1, 1].

    Concatenates sin(pi * x) and cos(pi * x) along concat_dim, doubling
    the coordinate dimension (N x 2 input -> N x 4 output for the default).
    """
    angles = math.pi * loc_ip
    return torch.cat((torch.sin(angles), torch.cos(angles)), concat_dim)
def bilinear_interpolate(loc_ip, data, remove_nans_raster=True):
    """Bilinearly sample per-pixel features at continuous locations.

    loc_ip: N x 2 tensor of [lon, lat] rows, each coordinate in [-1, 1].
    data:   H x W x C raster of per-pixel feature channels.
    Returns an N x C tensor of interpolated features.

    NOTE: when remove_nans_raster is True, NaNs in data are zeroed
    *in place* (0 is the post-normalization mean), so the scrub persists
    in the caller's raster and only happens once per raster.
    """
    assert data is not None
    if remove_nans_raster:
        data[torch.isnan(data)] = 0.0
    # Map [-1, 1] onto the unit square; flip the vertical axis because
    # latitude runs top-to-bottom in the raster while the coordinate
    # convention runs bottom-to-top.
    uv = (loc_ip.clone() + 1) / 2.0
    uv[:, 1] = 1 - uv[:, 1]
    assert not torch.any(torch.isnan(uv))
    # Scale to continuous pixel coordinates.
    uv[:, 0] *= (data.shape[1] - 1)
    uv[:, 1] *= (data.shape[0] - 1)
    base = torch.floor(uv).long()  # top-left pixel of each 2x2 neighborhood
    x0, y0 = base[:, 0], base[:, 1]
    # Neighbor pixels, clamped so edge samples reuse the border pixel.
    x1 = torch.clamp(x0 + 1, max=data.shape[1] - 1)
    y1 = torch.clamp(y0 + 1, max=data.shape[0] - 1)
    # Fractional offsets inside the pixel cell act as blend weights.
    frac = uv - torch.floor(uv)
    wx = frac[:, 0].unsqueeze(1)
    wy = frac[:, 1].unsqueeze(1)
    return (data[y0, x0, :] * (1 - wx) * (1 - wy)
            + data[y0, x1, :] * wx * (1 - wy)
            + data[y1, x0, :] * (1 - wx) * wy
            + data[y1, x1, :] * wx * wy)
def rand_samples(batch_size, device, rand_type='uniform'):
    """Randomly sample batch_size background locations in [-1, 1] x [-1, 1].

    rand_type='uniform' samples uniformly in normalized coordinate space;
    rand_type='spherical' samples uniformly on the sphere's surface and
    maps back to normalized [lon, lat] (inverse-CDF sampling, so latitudes
    near the poles are correctly under-represented).

    Raises NotImplementedError for an unknown rand_type (previously an
    unknown value fell through both branches and crashed at the return
    with UnboundLocalError).
    """
    if rand_type == 'spherical':
        # Allocate directly on the target device instead of CPU + copy.
        rand_loc = torch.rand(batch_size, 2, device=device)
        theta1 = 2.0 * math.pi * rand_loc[:, 0]          # longitude angle
        theta2 = torch.acos(2.0 * rand_loc[:, 1] - 1.0)  # polar angle
        lat = 1.0 - 2.0 * theta2 / math.pi
        lon = (theta1 / math.pi) - 1.0
        rand_loc = torch.cat((lon.unsqueeze(1), lat.unsqueeze(1)), 1)
    elif rand_type == 'uniform':
        rand_loc = torch.rand(batch_size, 2, device=device) * 2.0 - 1.0
    else:
        raise NotImplementedError('Unknown rand_type.')
    return rand_loc
def get_time_stamp():
    """Return the current local time formatted as 'YYYY-MM-DD-HH-MM-SS'."""
    # strftime zero-pads every field, matching the str(datetime) layout
    # the previous manual split-and-join produced.
    return datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
def coord_grid(grid_size, split_ids=None, split_of_interest=None):
    """Generate [lon, lat] coordinates evenly spaced over the globe.

    grid_size is (rows, cols); longitudes span [-180, 180] left to right
    and latitudes span [90, -90] top to bottom. With no split arguments
    this returns all rows*cols locations as an N x 2 float32 array;
    otherwise only the cells where split_ids == split_of_interest.
    """
    n_rows, n_cols = grid_size[0], grid_size[1]
    lon_grid, lat_grid = np.meshgrid(
        np.linspace(-180, 180, n_cols),
        np.linspace(90, -90, n_rows),
    )
    feats = np.stack((lon_grid, lat_grid), axis=-1).astype(np.float32)
    if split_ids is not None and split_of_interest is not None:
        # Subset: keep only the cells belonging to the requested split.
        sel_rows, sel_cols = np.where(split_ids == split_of_interest)
        return feats[sel_rows, sel_cols, :]  # N_subset x 2
    # Full grid flattened row-major to (rows*cols) x 2.
    return feats.reshape(n_rows * n_cols, 2)
def create_spatial_split(raster, mask, train_amt=1.0, cell_size=25):
    """Generate a checkerboard-style train/test split over a raster.

    Returns an H x W array where 0 = invalid, 1 = train, 2 = test.
    cell_size is the checkerboard cell edge in pixels; mask zeroes out
    invalid pixels. If train_amt < 1.0, a random fraction of the train
    pixels is demoted to invalid (uses the global numpy RNG).
    """
    split_ids = np.ones((raster.shape[0], raster.shape[1]))
    # Alternate the column offset per row-band to form the checkerboard;
    # the first band starts at column 0.
    offset = cell_size
    for row in np.arange(0, split_ids.shape[0], cell_size):
        offset = 0 if offset else cell_size
        for col in np.arange(offset, split_ids.shape[1], cell_size * 2):
            split_ids[row:row + cell_size, col:col + cell_size] = 2
    split_ids = split_ids * mask
    if train_amt < 1.0:
        # Randomly drop a (1 - train_amt) fraction of the train pixels.
        tr_rows, tr_cols = np.where(split_ids == 1)
        n_drop = int(len(tr_rows) * (1.0 - train_amt))
        drop = np.random.choice(len(tr_rows), n_drop, replace=False)
        split_ids[tr_rows[drop], tr_cols[drop]] = 0
    return split_ids
|