# Code Structure of CUDA operators

This folder contains all non-Python code for MMCV custom ops. Please follow the same architecture if you want to add new ops.

## Directory Tree

```folder
.
├── common
│   ├── box_iou_rotated_utils.hpp
│   ├── parrots_cpp_helper.hpp
│   ├── parrots_cuda_helper.hpp
│   ├── pytorch_cpp_helper.hpp
│   ├── pytorch_cuda_helper.hpp
│   ├── pytorch_device_registry.hpp
│   ├── cuda
│   │   ├── common_cuda_helper.hpp
│   │   ├── parrots_cudawarpfunction.cuh
│   │   ├── ...
│   │   └── ops_cuda_kernel.cuh
│   ├── mps
│   │   ├── MPSLibrary.h
│   │   ├── ...
│   │   └── MPSUtils.h
│   ├── mlu
│   │   └── ...
│   └── utils
│       └── ...
├── parrots
│   ├── ...
│   ├── ops.cpp
│   ├── ops_parrots.cpp
│   └── ops_pytorch.h
└── pytorch
    ├── info.cpp
    ├── pybind.cpp
    ├── ...
    ├── ops.cpp
    ├── cuda
    │   ├── ...
    │   └── ops_cuda.cu
    ├── cpu
    │   ├── ...
    │   └── ops.cpp
    ├── mps
    │   ├── ...
    │   └── op_mps.mm
    └── mlu
        ├── ...
        └── op_mlu.cpp
```

## Components

- `common`: This directory contains all tools and shared code.
  - `cuda`: The CUDA kernels that can be shared by all backends. **HIP** kernels also live here since HIP has similar syntax.
  - `mps`: The tools used to support MPS ops. **NOTE** that MPS support is **experimental**.
  - `mlu`: The MLU kernels used to support [Cambricon](https://www.cambricon.com/) devices.
  - `utils`: The kernels and utilities of spconv.
- `parrots`: **Parrots** is a deep learning framework for model training and inference. Parrots custom ops are placed in this directory.
- `pytorch`: **PyTorch** custom ops are supported by binding C++ to Python with **pybind11**. The op implementations and binding code are placed in this directory.
  - `cuda`: This directory contains the CUDA kernel launchers, which feed tensor memory pointers to the CUDA kernels in `common/cuda`. The launchers provide the C++ interface to the CUDA implementation of each custom op.
  - `cpu`: This directory contains the CPU implementations of the custom ops.
  - `mlu`: This directory contains the launchers for the MLU kernels.
  - `mps`: The MPS op implementations and launchers.

## How to add new PyTorch ops?

1. (Optional) Add a shared kernel in `common` to support specific hardware platforms.

   ```c++
   // src/common/cuda/new_ops_cuda_kernel.cuh

   template <typename T>
   __global__ void new_ops_forward_cuda_kernel(const T* input, T* output, ...) {
       // forward here
   }

   ```

   Add a CUDA kernel launcher in `pytorch/cuda`.

   ```c++
   // src/pytorch/cuda
   #include <new_ops_cuda_kernel.cuh>

   void NewOpsForwardCUDAKernelLauncher(Tensor input, Tensor output, ...){
       // initialize
       at::cuda::CUDAGuard device_guard(input.device());
       cudaStream_t stream = at::cuda::getCurrentCUDAStream();
       ...
       AT_DISPATCH_FLOATING_TYPES_AND_HALF(
           input.scalar_type(), "new_ops_forward_cuda_kernel", ([&] {
               new_ops_forward_cuda_kernel<scalar_t>
                   <<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
                       input.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(),...);
           }));
       AT_CUDA_CHECK(cudaGetLastError());
   }
   ```
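   `GET_BLOCKS` and `THREADS_PER_BLOCK` come from the common CUDA helpers. Conventionally, such a helper computes the grid size as the ceiling division of the element count by the block size, so every element is covered by exactly one launch grid. A minimal sketch of that computation (hypothetical function name, not MMCV's actual helper):

   ```cpp
   #include <cassert>

   // Ceiling division: smallest grid size whose blocks cover n elements.
   // Hypothetical sketch of a GET_BLOCKS-style helper.
   int get_blocks_sketch(int n, int threads_per_block) {
     return (n + threads_per_block - 1) / threads_per_block;
   }

   int main() {
     assert(get_blocks_sketch(1024, 512) == 2);  // exact multiple
     assert(get_blocks_sketch(1025, 512) == 3);  // one extra partial block
     assert(get_blocks_sketch(1, 512) == 1);     // never zero for n >= 1
     return 0;
   }
   ```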

2. Register the implementation for different devices.

   ```c++
   // src/pytorch/cuda/cudabind.cpp
   ...

   Tensor new_ops_forward_cuda(Tensor input, Tensor output, ...){
       // implement cuda forward here
       // use `NewOpsForwardCUDAKernelLauncher` here
   }
   // declare interface here.
   Tensor new_ops_forward_impl(Tensor input, Tensor output, ...);
   // register the implementation for given device (CUDA here).
   REGISTER_DEVICE_IMPL(new_ops_forward_impl, CUDA, new_ops_forward_cuda);
   ```

3. Add the op implementation in the `pytorch` directory, selecting among the device-specific implementations according to the device type of the input.

   ```c++
   // src/pytorch/new_ops.cpp
   Tensor new_ops_forward_impl(Tensor input, Tensor output, ...){
       // dispatch the implementation according to the device type of input.
       return DISPATCH_DEVICE_IMPL(new_ops_forward_impl, input, output, ...);
   }
   ...

   Tensor new_ops_forward(Tensor input, Tensor output, ...){
       return new_ops_forward_impl(input, output, ...);
   }
   ```
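   `REGISTER_DEVICE_IMPL` and `DISPATCH_DEVICE_IMPL` are defined in `common/pytorch_device_registry.hpp`. As a rough illustration of the idea behind such a registry (a hypothetical sketch, not MMCV's actual macro-based implementation), each op keeps a table from device type to implementation, and dispatch looks up the entry for the input's device:

   ```cpp
   #include <cassert>
   #include <functional>
   #include <map>
   #include <stdexcept>
   #include <string>

   // Hypothetical op signature; MMCV's registry is templated over the
   // real op signatures.
   using OpFn = std::function<int(int)>;

   // One registry per op: device name -> implementation.
   std::map<std::string, OpFn>& registry() {
     static std::map<std::string, OpFn> r;
     return r;
   }

   // Analogue of REGISTER_DEVICE_IMPL.
   void register_impl(const std::string& device, OpFn fn) {
     registry()[device] = fn;
   }

   // Analogue of DISPATCH_DEVICE_IMPL: pick the implementation that
   // matches the device, or fail loudly if none was registered.
   int dispatch(const std::string& device, int x) {
     auto it = registry().find(device);
     if (it == registry().end())
       throw std::runtime_error("no impl registered for device: " + device);
     return it->second(x);
   }

   int main() {
     register_impl("cpu", [](int x) { return x + 1; });
     register_impl("cuda", [](int x) { return x + 2; });
     assert(dispatch("cpu", 1) == 2);
     assert(dispatch("cuda", 1) == 3);
     return 0;
   }
   ```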

4. Bind the implementation in `pytorch/pybind.cpp`.

   ```c++
   // src/pytorch/pybind.cpp

   ...

   Tensor new_ops_forward(Tensor input, Tensor output, ...);

   ...

   // bind with pybind11
   m.def("new_ops_forward", &new_ops_forward, "new_ops_forward",
           py::arg("input"), py::arg("output"), ...);

   ...

   ```

5. Build MMCV again. Then you can use the new ops in Python.

   ```python
   from ..utils import ext_loader
   ext_module = ext_loader.load_ext('_ext', ['new_ops_forward'])

   ...

   ext_module.new_ops_forward(input, output, ...)

   ```
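   Conceptually, `ext_loader.load_ext` imports the compiled extension module and verifies that every requested op symbol is present. A minimal stand-in sketch of that pattern (hypothetical helper, not MMCV's actual code; a stdlib module is used here in place of the compiled `_ext`):

   ```python
   import importlib


   def load_ext_sketch(name, funcs):
       """Import a module and check that each expected op symbol exists.

       Hypothetical sketch of the ext_loader pattern, not MMCV's code.
       """
       ext = importlib.import_module(name)
       for fun in funcs:
           assert hasattr(ext, fun), f'{fun} not found in module {name}'
       return ext


   # Demo: a stdlib module stands in for the compiled '_ext'.
   ext_module = load_ext_sketch('math', ['sqrt'])
   print(ext_module.sqrt(9.0))  # 3.0
   ```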