Spaces:
Runtime error
Runtime error
File size: 2,741 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
# flake8: noqa
import torch
from torch.nn.modules.module import Module
from torch.autograd import Function
import correlation_cuda
class CorrelationFunction(Function):
    """Autograd bridge to the compiled ``correlation_cuda`` extension.

    The six configuration arguments come first and the two input tensors
    last, matching the order in which ``Correlation.forward`` passes them
    to ``CorrelationFunction.apply``.
    """

    @staticmethod
    def forward(ctx,
                pad_size,
                kernel_size,
                max_displacement,
                stride1,
                stride2,
                corr_multiply,
                input1,
                input2):
        """Compute the correlation of ``input1`` and ``input2``.

        Args:
            ctx: autograd context; keeps the inputs and the configuration
                for ``backward``.
            pad_size, kernel_size, max_displacement, stride1, stride2,
            corr_multiply: forwarded unchanged to ``correlation_cuda.forward``.
            input1, input2: input tensors; presumably CUDA tensors of the
                same shape since the extension is CUDA-only — TODO confirm.

        Returns:
            The tensor populated by ``correlation_cuda.forward``.
        """
        ctx.save_for_backward(input1, input2)
        ctx.pad_size = pad_size
        ctx.kernel_size = kernel_size
        ctx.max_displacement = max_displacement
        ctx.stride1 = stride1
        ctx.stride2 = stride2
        ctx.corr_multiply = corr_multiply

        # Make input1's GPU the current device while the extension runs.
        with torch.cuda.device_of(input1):
            # Empty tensors with the inputs' dtype/device; the extension
            # fills them in place (rbot* are presumably scratch buffers for
            # rearranged/padded inputs — TODO confirm).
            rbot1 = input1.new()
            rbot2 = input2.new()
            output = input1.new()
            correlation_cuda.forward(
                input1,
                input2,
                rbot1,
                rbot2,
                output,
                ctx.pad_size,
                ctx.kernel_size,
                ctx.max_displacement,
                ctx.stride1,
                ctx.stride2,
                ctx.corr_multiply)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate ``grad_output`` through the correlation.

        Returns:
            One entry per ``forward`` input, in the same order: ``None`` for
            the six non-tensor configuration arguments, then the gradients
            with respect to ``input1`` and ``input2``.
        """
        input1, input2 = ctx.saved_tensors
        # Autograd may hand us a non-contiguous gradient; the extension
        # presumably indexes raw memory, so densify it first (no-op when
        # it is already contiguous).
        grad_output = grad_output.contiguous()

        with torch.cuda.device_of(input1):
            rbot1 = input1.new()
            rbot2 = input2.new()
            grad_input1 = input1.new()
            grad_input2 = input2.new()
            correlation_cuda.backward(
                input1,
                input2,
                rbot1,
                rbot2,
                grad_output,
                grad_input1,
                grad_input2,
                ctx.pad_size,
                ctx.kernel_size,
                ctx.max_displacement,
                ctx.stride1,
                ctx.stride2,
                ctx.corr_multiply)

        # autograd requires exactly one return value per forward() input
        # (8 here). Returning only the two tensor gradients makes PyTorch
        # raise "returned an incorrect number of gradients".
        return None, None, None, None, None, None, grad_input1, grad_input2
class Correlation(Module):
    """``nn.Module`` front end for ``CorrelationFunction``.

    Stores the correlation hyper-parameters at construction time so that
    ``forward`` only needs the two input tensors.
    """

    def __init__(
            self,
            pad_size=0,
            kernel_size=0,
            max_displacement=0,
            stride1=1,
            stride2=2,
            corr_multiply=1):
        """Record the configuration later passed to ``CorrelationFunction``.

        Args:
            pad_size, kernel_size, max_displacement, stride1, stride2,
            corr_multiply: stored as attributes of the same name and
                forwarded unchanged on every ``forward`` call.
        """
        super().__init__()
        self.pad_size = pad_size
        self.kernel_size = kernel_size
        self.max_displacement = max_displacement
        self.stride1 = stride1
        self.stride2 = stride2
        self.corr_multiply = corr_multiply

    def forward(self, input1, input2):
        """Apply ``CorrelationFunction`` to ``input1`` and ``input2``."""
        config = (
            self.pad_size,
            self.kernel_size,
            self.max_displacement,
            self.stride1,
            self.stride2,
            self.corr_multiply,
        )
        # CorrelationFunction.forward takes the configuration first and the
        # two tensors last.
        return CorrelationFunction.apply(*config, input1, input2)
|