#include <cuda_fp16.h>

typedef unsigned char uint8_t;

// NOTE(review): __B__, __D__ and __NSQ__ are not defined in this file; they
// are presumably substituted (bits per code, sub-vector dimension, number of
// sub-quantizers) before the source is compiled -- confirm against the caller.

// Disabled quantize kernel, kept for reference.
//
// extern "C"
// __global__ void quantize(
//     const half* __restrict__ codebook, // nsq x 2^b x d
//     const half* __restrict__ vectors, // n x (nsq * d)
//     uint8_t* __restrict__ codes, // nsq x n
//     int n
// ) {
//     extern __shared__ volatile half centroids[]; // 2^b x d
//     const int sq_id = blockIdx.x;
//     const int thread_id = threadIdx.x;
//     const int n_threads = blockDim.x;
//     const int n_floats_per_sq = (1 << __B__) * __D__;
//     #pragma unroll
//     for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
//         centroids[i] = codebook[sq_id * n_floats_per_sq + i];
//     }
//     __syncthreads();
//     half subvector[__D__];
//     for (int i = thread_id; i < n; i += n_threads) {
//         #pragma unroll
//         for (int j = 0; j < __D__; ++j) {
//             subvector[j] = vectors[(i * __NSQ__ + sq_id) * __D__ + j];
//         }
//         float min_dist = 1 << 16;
//         uint8_t min_idx;
//         #pragma unroll
//         for (int j = 0; j < (1 << __B__); ++j) {
//             float dist = 0;
//             #pragma unroll
//             for (int k = 0; k < __D__; ++k) {
//                 float tmp = __half2float(subvector[k]) - __half2float(centroids[j * __D__ + k]);
//                 dist += tmp * tmp;
//             }
//             min_dist = (dist <= min_dist) ? dist : min_dist;
//             min_idx = (dist == min_dist) ? j : min_idx;
//         }
//         // printf("%d %d %d %d\n", sq_id, n, i, min_idx);
//         codes[sq_id * n + i] = min_idx;
//     }
// }

// Reconstructs half-precision vectors from per-sub-quantizer codes.
//
// For the sub-quantizer selected by blockIdx.x, every code in that
// sub-quantizer's row of `codes` is replaced by the matching centroid and the
// centroid is written into that sub-quantizer's slice of each output vector.
//
// Parameters:
//   codebook  centroids, indexed as [sq][code][dim]  (nsq x 2^b x d)
//   codes     centroid indices, indexed as [sq][i]   (nsq x n)
//   vectors   output, indexed as [i][sq][dim]        (n x (nsq x d))
//   n         number of vectors
//
// Launch requirements:
//   * One block per sub-quantizer: blockIdx.x is used directly as the
//     sub-quantizer index, so gridDim.x is presumably nsq -- confirm at the
//     launch site. Only the x dimension of the block is used.
//   * Dynamic shared memory of at least (1 << __B__) * __D__ * sizeof(half)
//     bytes, holding the block's copy of its sub-quantizer's centroids.
//   * Codes are read as uint8_t, so each code is in [0, 255]; this assumes
//     __B__ <= 8 -- TODO confirm.
extern "C"
__global__ void dequantize(
    const half* __restrict__ codebook, // nsq x 2^b x d
    const uint8_t* __restrict__ codes, // nsq x n
    half* __restrict__ vectors, // n x (nsq x d)
    int n
) {
    extern __shared__ volatile half centroids[]; // 2^b x d

    const int sq_id = blockIdx.x;
    const int thread_id = threadIdx.x;
    const int n_threads = blockDim.x;
    // Number of half values in one sub-quantizer's codebook.
    const int n_floats_per_sq = (1 << __B__) * __D__;

    // Offsets into global memory are computed in 64 bits: products such as
    // i * __NSQ__ * __D__ and sq_id * n can exceed INT_MAX for large inputs.
    const long long codebook_base = (long long)sq_id * n_floats_per_sq;
    const long long codes_base = (long long)sq_id * n;

    // Cooperatively stage this sub-quantizer's centroids in shared memory.
    #pragma unroll
    for (int i = thread_id; i < n_floats_per_sq; i += n_threads) {
        centroids[i] = codebook[codebook_base + i];
    }
    // All centroids must be visible before any thread looks one up.
    __syncthreads();

    // Threads of the block stride over the n vectors.
    for (int i = thread_id; i < n; i += n_threads) {
        const uint8_t code = codes[codes_base + i];
        // Start of this sub-quantizer's slice inside output vector i.
        const long long out_base = ((long long)i * __NSQ__ + sq_id) * __D__;
        #pragma unroll
        for (int dim = 0; dim < __D__; ++dim) {
            vectors[out_base + dim] = centroids[__D__ * code + dim];
            // atomicAdd(vectors + (i * __NSQ__ + sq_id) * __D__ + dim, centroids[__D__ * code + dim]);
        }
    }
}