import re

import cupy
import torch

from modules.cupy_module.cupy_utils import cupy_launch

# Code from https://github.com/sniklaus/softmax-splatting/blob/master/softsplat.py
#
# Forward ("splatting") warping: every source pixel (x, y) is moved to the
# sub-pixel target (x + flow[:, 0], y + flow[:, 1]) and its value is
# distributed bilinearly onto the four integer pixels around that target.
# Targets that fall outside the image are dropped.
#
# The CUDA sources below are templates. Before compilation, cupy_kernel()
# expands SIZE_k(t), OFFSET_k(t, ...) and VALUE_k(t, ...) into literal sizes
# and stride arithmetic of the concrete tensors.
#
# All kernels take `const float*` / `float*` pointers, so only float32 data
# is supported. They index with a 32-bit `int`.

kernel_Softsplat_updateOutput = '''
    extern "C" __global__ void kernel_Softsplat_updateOutput(
        const int n,
        const float* input,
        const float* flow,
        float* output
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
        const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
        const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
        const int intX = ( intIndex ) % SIZE_3(output);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
        }
    } }
'''

kernel_Softsplat_updateGradInput = '''
    extern "C" __global__ void kernel_Softsplat_updateGradInput(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
        const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
        const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
        const int intX = ( intIndex ) % SIZE_3(gradInput);

        float fltGradInput = 0.0;

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
        }

        gradInput[intIndex] = fltGradInput;
    } }
'''

kernel_Softsplat_updateGradFlow = '''
    extern "C" __global__ void kernel_Softsplat_updateGradFlow(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        float fltGradFlow = 0.0;

        const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
        const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow);
        const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow);
        const int intX = ( intIndex ) % SIZE_3(gradFlow);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = 0.0;
        float fltNortheast = 0.0;
        float fltSouthwest = 0.0;
        float fltSoutheast = 0.0;

        if (intC == 0) {
            fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
            fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
            fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
            fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));

        } else if (intC == 1) {
            fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
            fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
            fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
            fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));

        }

        for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
            float fltInput = VALUE_4(input, intN, intChannel, intY, intX);

            if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
            }

            if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
            }

            if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
            }

            if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
            }
        }

        gradFlow[intIndex] = fltGradFlow;
    } }
'''

# Template macros: NAME_<rank>(<args>). Arguments may not contain ')' --
# nested grouping inside an index argument is written with '{' / '}' and
# rewritten to parentheses during expansion.
# Raw strings: '\(' in a normal string literal is an invalid escape sequence
# (SyntaxWarning on Python 3.12+).
_SIZE_PATTERN = re.compile(r'(SIZE_)([0-4])(\()([^\)]*)(\))')
_OFFSET_PATTERN = re.compile(r'(OFFSET_)([0-4])(\()([^\)]+)(\))')
_VALUE_PATTERN = re.compile(r'(VALUE_)([0-4])(\()([^\)]+)(\))')

_THREADS_PER_BLOCK = 512


def _strided_index(strArgs, intArgs, intStrides):
    """Return the C terms '((index_i)*stride_i)' for i in range(intArgs).

    strArgs[0] is the tensor name; strArgs[1:] are the index expressions.
    """
    return [
        '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')'
        for intArg in range(intArgs)
    ]


def cupy_kernel(strFunction, objVariables):
    """Expand the macros of the module-level kernel template named `strFunction`.

    Args:
        strFunction: name of a module-level kernel source string
            (looked up in this module's globals()).
        objVariables: mapping from tensor names used in the template to
            tensors. Only names appearing inside a macro are accessed, so
            other entries may be None.

    Returns:
        The CUDA source with, for every tensor ``t``:
          * ``SIZE_k(t)`` replaced by ``t.size(k)`` as a literal;
          * ``OFFSET_k(t, i0, ...)`` replaced by ``(i0*stride0 + ...)``;
          * ``VALUE_k(t, i0, ...)`` replaced by ``t[i0*stride0 + ...]``.
    """
    strKernel = globals()[strFunction]

    while True:
        objMatch = _SIZE_PATTERN.search(strKernel)
        if objMatch is None:
            break

        intArg = int(objMatch.group(2))
        strTensor = objMatch.group(4).strip()
        intSizes = objVariables[strTensor].size()

        # str.replace substitutes every identical occurrence at once.
        strKernel = strKernel.replace(objMatch.group(0), str(intSizes[intArg]))

    while True:
        objMatch = _OFFSET_PATTERN.search(strKernel)
        if objMatch is None:
            break

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0].strip()
        intStrides = objVariables[strTensor].stride()

        strIndex = _strided_index(strArgs, intArgs, intStrides)
        strKernel = strKernel.replace(objMatch.group(0), '(' + '+'.join(strIndex) + ')')

    while True:
        objMatch = _VALUE_PATTERN.search(strKernel)
        if objMatch is None:
            break

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')
        strTensor = strArgs[0].strip()
        intStrides = objVariables[strTensor].stride()

        strIndex = _strided_index(strArgs, intArgs, intStrides)
        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + '+'.join(strIndex) + ']')

    return strKernel


def _launch(strFunction, objVariables, n, args):
    """Compile kernel `strFunction` for `objVariables` and run it over `n` elements.

    Uses 1-D blocks of _THREADS_PER_BLOCK threads and enough blocks to cover n.
    `args` are the pointer arguments that follow the leading `const int n`.
    Does nothing when n == 0: a zero-sized grid is an invalid CUDA launch
    configuration, and there is no work to do.
    """
    if n == 0:
        return

    # NOTE(review): cupy_launch is defined in modules.cupy_module.cupy_utils
    # (not visible here); presumably it compiles and caches the source. No
    # stream is passed -- confirm it runs on the same CUDA stream as PyTorch.
    cupy_launch(strFunction, cupy_kernel(strFunction, objVariables))(
        grid=((n + _THREADS_PER_BLOCK - 1) // _THREADS_PER_BLOCK, 1, 1),
        block=(_THREADS_PER_BLOCK, 1, 1),
        args=[cupy.int32(n)] + list(args),
    )


class _FunctionSoftsplat(torch.autograd.Function):
    """Autograd wrapper around the summation-splatting CUDA kernels.

    forward(input, flow) -> output
        input:  (N, C, H, W) float32 CUDA tensor.
        flow:   (N, 2, H, W) float32 CUDA tensor on the same device. Channel 0
                is the x displacement and channel 1 the y displacement.
        output: (N, C, H, W). Each input pixel is bilinearly accumulated at
                its displaced position.
    """

    @staticmethod
    def forward(ctx, input, flow):
        intSamples = input.shape[0]
        intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
        intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]

        assert intFlowDepth == 2
        assert intInputHeight == intFlowHeight
        assert intInputWidth == intFlowWidth

        input = input.contiguous()
        assert input.is_cuda
        flow = flow.contiguous()
        assert flow.is_cuda

        # The kernels dereference raw float* pointers. Any other dtype would
        # be reinterpreted byte-wise, giving garbage values, or out-of-bounds
        # reads for 16-bit types.
        assert input.dtype == torch.float32 and flow.dtype == torch.float32, 'softsplat only supports float32 tensors'
        assert input.device == flow.device

        # Only reachable when asserts are stripped (python -O).
        if not input.is_cuda:
            raise NotImplementedError()

        # Zero-initialized because the kernel accumulates into it with atomicAdd.
        output = input.new_zeros([intSamples, intInputDepth, intInputHeight, intInputWidth])

        _launch(
            'kernel_Softsplat_updateOutput',
            {'input': input, 'flow': flow, 'output': output},
            output.nelement(),
            [input.data_ptr(), flow.data_ptr(), output.data_ptr()],
        )

        ctx.save_for_backward(input, flow)

        return output

    @staticmethod
    def backward(ctx, gradOutput):
        input, flow = ctx.saved_tensors

        intSamples = input.shape[0]
        intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3]
        intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3]

        assert intFlowDepth == 2
        assert intInputHeight == intFlowHeight
        assert intInputWidth == intFlowWidth

        gradOutput = gradOutput.contiguous()
        assert gradOutput.is_cuda

        # Only allocate and compute the gradients autograd actually needs.
        gradInput = input.new_zeros([intSamples, intInputDepth, intInputHeight, intInputWidth]) if ctx.needs_input_grad[0] else None
        gradFlow = input.new_zeros([intSamples, intFlowDepth, intFlowHeight, intFlowWidth]) if ctx.needs_input_grad[1] else None

        # Only reachable when asserts are stripped (python -O).
        if not input.is_cuda:
            raise NotImplementedError()

        # gradInput / gradFlow may be None here. Each kernel only references
        # the gradient it writes inside a macro, and the unused one is passed
        # as a null pointer.
        objVariables = {'input': input, 'flow': flow, 'gradOutput': gradOutput, 'gradInput': gradInput, 'gradFlow': gradFlow}

        if gradInput is not None:
            _launch(
                'kernel_Softsplat_updateGradInput',
                objVariables,
                gradInput.nelement(),
                [input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None],
            )

        if gradFlow is not None:
            _launch(
                'kernel_Softsplat_updateGradFlow',
                objVariables,
                gradFlow.nelement(),
                [input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr()],
            )

        return gradInput, gradFlow


def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Forward-warp `tenInput` by `tenFlow` with the chosen splatting mode.

    Args:
        tenInput: (N, C, H, W) float32 CUDA tensor to splat.
        tenFlow: (N, 2, H, W) displacement (x in channel 0, y in channel 1).
        tenMetric: (N, 1, H, W) importance metric, or None. It is required
            for 'linear' and 'softmax' and ignored otherwise.
        strType: one of
            * 'summation' -- plain sum of splatted contributions;
            * 'average'   -- sum divided by the splatted pixel count/weight;
            * 'linear'    -- weighted by tenMetric, then normalized;
            * 'softmax'   -- weighted by exp(tenMetric), then normalized.

    Returns:
        (N, C, H, W) splatted tensor. Where nothing landed, the normalizer is
        0 and is replaced by 1, so the output there is 0.
    """
    assert tenMetric is None or tenMetric.shape[1] == 1
    assert strType in ['summation', 'average', 'linear', 'softmax']

    if strType in ('linear', 'softmax'):
        assert tenMetric is not None, 'strType ' + repr(strType) + ' requires tenMetric'

    # Append a weight channel. It is splatted together with the weighted
    # input and used below as the normalizer.
    if strType == 'average':
        tenInput = torch.cat([tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3])], 1)

    elif strType == 'linear':
        tenInput = torch.cat([tenInput * tenMetric, tenMetric], 1)

    elif strType == 'softmax':
        # NOTE: exp() is unbounded; large metric values can overflow float32.
        tenExpMetric = tenMetric.exp()
        tenInput = torch.cat([tenInput * tenExpMetric, tenExpMetric], 1)

    tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)

    if strType != 'summation':
        tenNormalize = tenOutput[:, -1:, :, :]
        # Out-of-place: avoids mutating a view of the autograd Function's
        # output. Masked entries receive zero gradient, as with the former
        # in-place assignment.
        tenNormalize = tenNormalize.masked_fill(tenNormalize == 0.0, 1.0)
        tenOutput = tenOutput[:, :-1, :, :] / tenNormalize

    return tenOutput


class ModuleSoftsplat(torch.nn.Module):
    """nn.Module wrapper around FunctionSoftsplat with a fixed `strType`."""

    def __init__(self, strType):
        super().__init__()

        self.strType = strType

    def forward(self, tenInput, tenFlow, tenMetric=None):
        # tenMetric may be omitted for 'summation' and 'average'.
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType)