Spaces:
Running
Running
File size: 11,307 Bytes
b7eedf7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 |
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#define BLOCK_H 4
#define BLOCK_W 8
#define BLOCK_HW BLOCK_H * BLOCK_W
#define CHANNEL_STRIDE 32
__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
return h >= 0 && h < H && w >= 0 && w < W;
}
template <typename scalar_t>
__global__ void altcorr_forward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW];
__shared__ float x2s[BLOCK_HW];
__shared__ float y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fmap1[b][h1][w1][c+c1];
else
f1[c1][k1] = 0.0;
}
__syncthreads();
for (int n=0; n<N; n++) {
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
if (within_bounds(h1, w1, H1, W1)) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
}
float dx = x2s[tid] - floor(x2s[tid]);
float dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1])) - r + iy;
int w2 = static_cast<int>(floor(x2s[k1])) - r + ix;
int c2 = tid % CHANNEL_STRIDE;
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fmap2[b][h2][w2][c+c2];
else
f2[c2][k1] = static_cast<scalar_t>(0.0);
}
__syncthreads();
scalar_t s = 0.0;
for (int k=0; k<CHANNEL_STRIDE; k++)
s += f1[k][tid] * f2[k][tid];
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
// int ix_nw = ((iy-1) + rd*(ix-1));
// int ix_ne = ((iy-1) + rd*ix);
// int ix_sw = (iy + rd*(ix-1));
// int ix_se = (iy + rd*ix);
scalar_t nw = s * static_cast<scalar_t>((dy) * (dx));
scalar_t ne = s * static_cast<scalar_t>((dy) * (1-dx));
scalar_t sw = s * static_cast<scalar_t>((1-dy) * (dx));
scalar_t se = s * static_cast<scalar_t>((1-dy) * (1-dx));
// if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
// corr[b][n][ix_nw][h1][w1] += nw;
// if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
// corr[b][n][ix_ne][h1][w1] += ne;
// if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
// corr[b][n][ix_sw][h1][w1] += sw;
// if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
// corr[b][n][ix_se][h1][w1] += se;
scalar_t* corr_ptr = &corr[b][n][0][h1][w1];
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_nw) += nw;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_ne) += ne;
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_sw) += sw;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
*(corr_ptr + ix_se) += se;
}
}
}
}
}
template <typename scalar_t>
__global__ void altcorr_backward_kernel(
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1,
const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords,
const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> corr_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap1_grad,
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> fmap2_grad,
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> coords_grad,
int r)
{
const int b = blockIdx.x;
const int h0 = blockIdx.y * blockDim.x;
const int w0 = blockIdx.z * blockDim.y;
const int tid = threadIdx.x * blockDim.y + threadIdx.y;
const int H1 = fmap1.size(1);
const int W1 = fmap1.size(2);
const int H2 = fmap2.size(1);
const int W2 = fmap2.size(2);
const int N = coords.size(1);
const int C = fmap1.size(3);
__shared__ scalar_t f1[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f1_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t f2_grad[CHANNEL_STRIDE][BLOCK_HW+1];
__shared__ scalar_t x2s[BLOCK_HW];
__shared__ scalar_t y2s[BLOCK_HW];
for (int c=0; c<C; c+=CHANNEL_STRIDE) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
auto fptr = fmap1[b][h1][w1];
if (within_bounds(h1, w1, H1, W1))
f1[c1][k1] = fptr[c+c1];
else
f1[c1][k1] = 0.0;
f1_grad[c1][k1] = 0.0;
}
__syncthreads();
int h1 = h0 + threadIdx.x;
int w1 = w0 + threadIdx.y;
for (int n=0; n<N; n++) {
x2s[tid] = coords[b][n][h1][w1][0];
y2s[tid] = coords[b][n][h1][w1][1];
scalar_t dx = x2s[tid] - floor(x2s[tid]);
scalar_t dy = y2s[tid] - floor(y2s[tid]);
int rd = 2*r + 1;
for (int iy=0; iy<rd+1; iy++) {
for (int ix=0; ix<rd+1; ix++) {
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
auto fptr = fmap2[b][h2][w2];
if (within_bounds(h2, w2, H2, W2))
f2[c2][k1] = fptr[c+c2];
else
f2[c2][k1] = 0.0;
f2_grad[c2][k1] = 0.0;
}
__syncthreads();
const scalar_t* grad_ptr = &corr_grad[b][n][0][h1][w1];
scalar_t g = 0.0;
int ix_nw = H1*W1*((iy-1) + rd*(ix-1));
int ix_ne = H1*W1*((iy-1) + rd*ix);
int ix_sw = H1*W1*(iy + rd*(ix-1));
int ix_se = H1*W1*(iy + rd*ix);
if (iy > 0 && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_nw) * dy * dx;
if (iy > 0 && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_ne) * dy * (1-dx);
if (iy < rd && ix > 0 && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_sw) * (1-dy) * dx;
if (iy < rd && ix < rd && within_bounds(h1, w1, H1, W1))
g += *(grad_ptr + ix_se) * (1-dy) * (1-dx);
for (int k=0; k<CHANNEL_STRIDE; k++) {
f1_grad[k][tid] += g * f2[k][tid];
f2_grad[k][tid] += g * f1[k][tid];
}
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h2 = static_cast<int>(floor(y2s[k1]))-r+iy;
int w2 = static_cast<int>(floor(x2s[k1]))-r+ix;
int c2 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap2_grad[b][h2][w2][0];
if (within_bounds(h2, w2, H2, W2))
atomicAdd(fptr+c+c2, f2_grad[c2][k1]);
}
}
}
}
__syncthreads();
for (int k=0; k<BLOCK_HW; k+=BLOCK_HW/CHANNEL_STRIDE) {
int k1 = k + tid / CHANNEL_STRIDE;
int h1 = h0 + k1 / BLOCK_W;
int w1 = w0 + k1 % BLOCK_W;
int c1 = tid % CHANNEL_STRIDE;
scalar_t* fptr = &fmap1_grad[b][h1][w1][0];
if (within_bounds(h1, w1, H1, W1))
fptr[c+c1] += f1_grad[c1][k1];
}
}
}
std::vector<torch::Tensor> altcorr_cuda_forward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H = coords.size(2);
const auto W = coords.size(3);
const auto rd = 2 * radius + 1;
auto opts = fmap1.options();
auto corr = torch::zeros({B, N, rd*rd, H, W}, opts);
const dim3 blocks(B, (H+BLOCK_H-1)/BLOCK_H, (W+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.type(), "altcorr_forward_kernel", ([&] {
altcorr_forward_kernel<scalar_t><<<blocks, threads>>>(
fmap1.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
radius);
}));
return {corr};
}
std::vector<torch::Tensor> altcorr_cuda_backward(
torch::Tensor fmap1,
torch::Tensor fmap2,
torch::Tensor coords,
torch::Tensor corr_grad,
int radius)
{
const auto B = coords.size(0);
const auto N = coords.size(1);
const auto H1 = fmap1.size(1);
const auto W1 = fmap1.size(2);
const auto H2 = fmap2.size(1);
const auto W2 = fmap2.size(2);
const auto C = fmap1.size(3);
auto opts = fmap1.options();
auto fmap1_grad = torch::zeros({B, H1, W1, C}, opts);
auto fmap2_grad = torch::zeros({B, H2, W2, C}, opts);
auto coords_grad = torch::zeros({B, N, H1, W1, 2}, opts);
const dim3 blocks(B, (H1+BLOCK_H-1)/BLOCK_H, (W1+BLOCK_W-1)/BLOCK_W);
const dim3 threads(BLOCK_H, BLOCK_W);
altcorr_backward_kernel<float><<<blocks, threads>>>(
fmap1.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
corr_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
fmap1_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
fmap2_grad.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
coords_grad.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
radius);
return {fmap1_grad, fmap2_grad, coords_grad};
} |