// decompress_residuals.cpp — CPU decompression of packed residual codes.
// (NOTE(review): the original upload carried web-page scrape residue here
// — "Spaces:", "Runtime error", file size, and a line-number gutter —
// which was not valid C++ and has been replaced by this header comment.)
#include <pthread.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <vector>

#include <torch/extension.h>
// Per-thread argument bundle for decompress(). All pointers are non-owning
// views into buffers owned by decompress_residuals(), which joins every
// worker thread before those buffers are released.
typedef struct decompress_args {
    int tid;       // this worker's index, in [0, nthreads)
    int nthreads;  // total number of worker threads
    int npids;     // number of document ids in `pids`
    int dim;       // embedding dimension of a decompressed token vector
    int packed_dim;             // packed residual bytes per token vector (dim / npacked_vals_per_byte)
    int npacked_vals_per_byte;  // residual values packed into one byte (8 / nbits)
    int* pids;         // document ids to decompress
    int64_t* lengths;  // per-document token counts, indexed by pid
    int64_t* offsets;  // per-document start offsets into codes/binary_residuals, indexed by pid
    float* bucket_weights;                // dequantization table: bucket index -> residual value
    uint8_t* reversed_bit_map;            // byte-level lookup applied to each residual byte (presumably reverses bit order — name-based)
    uint8_t* bucket_weight_combinations;  // maps a (remapped) byte to its (8 / nbits) bucket indices
    uint8_t* binary_residuals;  // packed residual bytes, packed_dim per token vector
    int* codes;        // per-token centroid ids
    float* centroids;  // centroid table, `dim` floats per centroid
    int64_t* cumulative_lengths;  // prefix sums of lengths over `pids`; entry i is document i's first output row
    float* output;     // destination buffer, one row of `dim` floats per token
} decompress_args_t;
void* decompress(void* args) {
decompress_args_t* decompress_args = (decompress_args_t*)args;
int npids_per_thread = (int)std::ceil(((float)decompress_args->npids) /
decompress_args->nthreads);
int start = decompress_args->tid * npids_per_thread;
int end = std::min((decompress_args->tid + 1) * npids_per_thread,
decompress_args->npids);
// Iterate over all documents
for (int i = start; i < end; i++) {
int pid = decompress_args->pids[i];
// Offset into packed list of token vectors for the given document
int64_t offset = decompress_args->offsets[pid];
// For each document, iterate over all token vectors
for (int j = 0; j < decompress_args->lengths[pid]; j++) {
const int code = decompress_args->codes[offset + j];
// For each token vector, iterate over the packed (8-bit) residual
// values
for (int k = 0; k < decompress_args->packed_dim; k++) {
uint8_t x =
decompress_args->binary_residuals
[(offset + j) * decompress_args->packed_dim + k];
x = decompress_args->reversed_bit_map[x];
// For each packed residual value, iterate over the bucket
// weight indices. If we use n-bit compression, that means there
// will be (8 / n) indices per packed value.
for (int l = 0; l < decompress_args->npacked_vals_per_byte;
l++) {
const int output_dim_idx =
k * decompress_args->npacked_vals_per_byte + l;
const int bucket_weight_idx =
decompress_args->bucket_weight_combinations
[x * decompress_args->npacked_vals_per_byte + l];
decompress_args
->output[(decompress_args->cumulative_lengths[i] + j) *
decompress_args->dim +
output_dim_idx] =
decompress_args->bucket_weights[bucket_weight_idx] +
decompress_args->centroids[code * decompress_args->dim +
output_dim_idx];
}
}
}
}
return NULL;
}
// Reconstructs full-precision embeddings for the given documents from their
// compressed representation (a centroid code plus nbits-packed residuals per
// token vector), parallelized across at::get_num_threads() pthreads.
//
// pids:    int32 [npids] document ids to decompress
// lengths / offsets: int64 per-document token counts and start offsets into
//          codes/binary_residuals, indexed by pid
// bucket_weights: float dequantization table (bucket index -> residual)
// reversed_bit_map / bucket_weight_combinations: uint8 lookup tables that
//          unpack each packed residual byte into (8 / nbits) bucket indices
// binary_residuals: uint8 packed residual bytes, (dim * nbits / 8) per token
// codes:   int32 per-token centroid ids
// centroids: float centroid table, `dim` floats per centroid
// dim:     embedding dimension; nbits: bits per residual value (divides 8)
//
// Returns a float32 tensor of shape [total tokens across pids, dim].
// Exits the process if a worker thread cannot be created.
torch::Tensor decompress_residuals(
    const torch::Tensor pids, const torch::Tensor lengths,
    const torch::Tensor offsets, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations,
    const torch::Tensor binary_residuals, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits) {
    const int npacked_vals_per_byte = (8 / nbits);
    const int packed_dim = (int)(dim / npacked_vals_per_byte);

    const int npids = pids.size(0);

    int* pids_a = pids.data_ptr<int>();
    int64_t* lengths_a = lengths.data_ptr<int64_t>();
    int64_t* offsets_a = offsets.data_ptr<int64_t>();
    float* bucket_weights_a = bucket_weights.data_ptr<float>();
    uint8_t* reversed_bit_map_a = reversed_bit_map.data_ptr<uint8_t>();
    uint8_t* bucket_weight_combinations_a =
        bucket_weight_combinations.data_ptr<uint8_t>();
    uint8_t* binary_residuals_a = binary_residuals.data_ptr<uint8_t>();
    int* codes_a = codes.data_ptr<int>();
    float* centroids_a = centroids.data_ptr<float>();

    // Prefix sums of the selected documents' lengths: entry i is document
    // i's first output row. Heap-allocated — the original used a
    // variable-length array (a non-standard extension) that could overflow
    // the stack for large npids. The running total is kept in int64 so the
    // output row count cannot overflow int.
    std::vector<int64_t> cumulative_lengths(npids + 1);
    int64_t noutputs = 0;
    cumulative_lengths[0] = 0;
    for (int i = 0; i < npids; i++) {
        noutputs += lengths_a[pids_a[i]];
        cumulative_lengths[i + 1] = noutputs;
    }

    auto options =
        torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
    torch::Tensor output = torch::zeros({noutputs, dim}, options);
    float* output_a = output.data_ptr<float>();

    const int nthreads = at::get_num_threads();
    // std::vector instead of VLAs (see note above); args must outlive the
    // workers, which pthread_join below guarantees.
    std::vector<pthread_t> threads(nthreads);
    std::vector<decompress_args_t> args(nthreads);
    for (int i = 0; i < nthreads; i++) {
        args[i].tid = i;
        args[i].nthreads = nthreads;
        args[i].npids = npids;
        args[i].dim = dim;
        args[i].packed_dim = packed_dim;
        args[i].npacked_vals_per_byte = npacked_vals_per_byte;
        args[i].pids = pids_a;
        args[i].lengths = lengths_a;
        args[i].offsets = offsets_a;
        args[i].bucket_weights = bucket_weights_a;
        args[i].reversed_bit_map = reversed_bit_map_a;
        args[i].bucket_weight_combinations = bucket_weight_combinations_a;
        args[i].binary_residuals = binary_residuals_a;
        args[i].codes = codes_a;
        args[i].centroids = centroids_a;
        args[i].cumulative_lengths = cumulative_lengths.data();
        args[i].output = output_a;
        int rc = pthread_create(&threads[i], NULL, decompress, (void*)&args[i]);
        if (rc) {
            fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
            std::exit(1);
        }
    }
    for (int i = 0; i < nthreads; i++) {
        pthread_join(threads[i], NULL);
    }

    return output;
}
// Python bindings: exposes decompress_residuals() to Python as
// decompress_residuals_cpp in the compiled torch extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("decompress_residuals_cpp", &decompress_residuals,
          "Decompress residuals");
}