File size: 5,844 Bytes
58627fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
#include <pthread.h>

#include <vector>

#include <torch/extension.h>

// Per-thread argument bundle for the pthread-based residual decompression.
// All pointer members are non-owning views into tensors owned by the caller
// (decompress_residuals); they must outlive the worker threads.
typedef struct decompress_args {
    int tid;       // this thread's index in [0, nthreads)
    int nthreads;  // total number of worker threads

    int npids;                 // number of document ids to decompress
    int dim;                   // embedding dimension of the output vectors
    int packed_dim;            // bytes of packed residuals per token vector
    int npacked_vals_per_byte; // residual values packed into each byte (8 / nbits)

    int* pids;                  // document ids to process
    int64_t* lengths;           // lengths[pid] = token-vector count for pid
    int64_t* offsets;           // offsets[pid] = start index into codes/residuals
    float* bucket_weights;      // quantization bucket centers
    uint8_t* reversed_bit_map;  // byte -> bit-reversed byte lookup table
    uint8_t* bucket_weight_combinations;  // packed byte -> per-value bucket indices
    uint8_t* binary_residuals;  // packed residual bytes, packed_dim per vector
    int* codes;                 // centroid code per token vector
    float* centroids;           // centroid vectors, dim floats each
    int64_t* cumulative_lengths;  // prefix sums of lengths over pids (output rows)

    float* output;  // decompressed vectors, shape (sum(lengths), dim), row-major
} decompress_args_t;

// pthread entry point: decompresses the residuals for a contiguous slice of
// the requested documents and writes the reconstructed embeddings into the
// shared output buffer. Threads write disjoint output rows, so no locking is
// needed. Always returns NULL.
void* decompress(void* args) {
    decompress_args_t* decompress_args = (decompress_args_t*)args;

    // Exact integer ceiling division. The previous float-based
    // std::ceil((float)npids / nthreads) could round incorrectly once npids
    // exceeds float precision (~2^24), dropping or duplicating documents.
    int npids_per_thread = (decompress_args->npids + decompress_args->nthreads - 1) /
                           decompress_args->nthreads;
    int start = decompress_args->tid * npids_per_thread;
    int end = std::min((decompress_args->tid + 1) * npids_per_thread,
                       decompress_args->npids);

    // Iterate over this thread's share of the documents
    for (int i = start; i < end; i++) {
        int pid = decompress_args->pids[i];

        // Offset into packed list of token vectors for the given document
        int64_t offset = decompress_args->offsets[pid];

        // For each document, iterate over all token vectors
        for (int j = 0; j < decompress_args->lengths[pid]; j++) {
            const int code = decompress_args->codes[offset + j];

            // For each token vector, iterate over the packed (8-bit) residual
            // values
            for (int k = 0; k < decompress_args->packed_dim; k++) {
                uint8_t x =
                    decompress_args->binary_residuals
                        [(offset + j) * decompress_args->packed_dim + k];
                // Bit-reverse the byte so packed values unpack in index order
                x = decompress_args->reversed_bit_map[x];

                // For each packed residual value, iterate over the bucket
                // weight indices. If we use n-bit compression, that means there
                // will be (8 / n) indices per packed value.
                for (int l = 0; l < decompress_args->npacked_vals_per_byte;
                     l++) {
                    const int output_dim_idx =
                        k * decompress_args->npacked_vals_per_byte + l;
                    const int bucket_weight_idx =
                        decompress_args->bucket_weight_combinations
                            [x * decompress_args->npacked_vals_per_byte + l];
                    // Reconstruction: centroid component + quantized residual
                    decompress_args
                        ->output[(decompress_args->cumulative_lengths[i] + j) *
                                     decompress_args->dim +
                                 output_dim_idx] =
                        decompress_args->bucket_weights[bucket_weight_idx] +
                        decompress_args->centroids[code * decompress_args->dim +
                                                   output_dim_idx];
                }
            }
        }
    }

    return NULL;
}

// Decompresses nbits-quantized residuals for the given document ids and
// returns a float32 tensor of shape (sum of lengths over pids, dim) holding
// the reconstructed token embeddings (centroid + residual). Work is split
// across at::get_num_threads() pthreads.
//
// Fixes vs. the previous version:
//  - C-style VLAs (cumulative_lengths, threads, args) replaced with
//    std::vector: VLAs are not standard C++ and large npids/nthreads could
//    overflow the stack.
//  - noutputs widened to int64_t to avoid int overflow on large corpora.
//  - Thread-creation failure now joins already-launched threads and raises a
//    catchable exception (via TORCH_CHECK) instead of std::exit(1), which
//    killed the host interpreter.
torch::Tensor decompress_residuals(
    const torch::Tensor pids, const torch::Tensor lengths,
    const torch::Tensor offsets, const torch::Tensor bucket_weights,
    const torch::Tensor reversed_bit_map,
    const torch::Tensor bucket_weight_combinations,
    const torch::Tensor binary_residuals, const torch::Tensor codes,
    const torch::Tensor centroids, const int dim, const int nbits) {
    const int npacked_vals_per_byte = (8 / nbits);
    const int packed_dim = (int)(dim / npacked_vals_per_byte);

    int npids = pids.size(0);
    int* pids_a = pids.data_ptr<int>();
    int64_t* lengths_a = lengths.data_ptr<int64_t>();
    int64_t* offsets_a = offsets.data_ptr<int64_t>();
    float* bucket_weights_a = bucket_weights.data_ptr<float>();
    uint8_t* reversed_bit_map_a = reversed_bit_map.data_ptr<uint8_t>();
    uint8_t* bucket_weight_combinations_a =
        bucket_weight_combinations.data_ptr<uint8_t>();
    uint8_t* binary_residuals_a = binary_residuals.data_ptr<uint8_t>();
    int* codes_a = codes.data_ptr<int>();
    float* centroids_a = centroids.data_ptr<float>();

    // Prefix sums over the requested documents: cumulative_lengths[i] is the
    // first output row belonging to document pids[i].
    std::vector<int64_t> cumulative_lengths(npids + 1);
    int64_t noutputs = 0;
    cumulative_lengths[0] = 0;
    for (int i = 0; i < npids; i++) {
        noutputs += lengths_a[pids_a[i]];
        cumulative_lengths[i + 1] =
            cumulative_lengths[i] + lengths_a[pids_a[i]];
    }

    auto options =
        torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false);
    torch::Tensor output = torch::zeros({noutputs, dim}, options);
    float* output_a = output.data_ptr<float>();

    auto nthreads = at::get_num_threads();

    std::vector<pthread_t> threads(nthreads);
    std::vector<decompress_args_t> args(nthreads);

    for (int i = 0; i < nthreads; i++) {
        args[i].tid = i;
        args[i].nthreads = nthreads;

        args[i].npids = npids;
        args[i].dim = dim;
        args[i].packed_dim = packed_dim;
        args[i].npacked_vals_per_byte = npacked_vals_per_byte;

        args[i].pids = pids_a;
        args[i].lengths = lengths_a;
        args[i].offsets = offsets_a;
        args[i].bucket_weights = bucket_weights_a;
        args[i].reversed_bit_map = reversed_bit_map_a;
        args[i].bucket_weight_combinations = bucket_weight_combinations_a;
        args[i].binary_residuals = binary_residuals_a;
        args[i].codes = codes_a;
        args[i].centroids = centroids_a;
        args[i].cumulative_lengths = cumulative_lengths.data();

        args[i].output = output_a;

        int rc = pthread_create(&threads[i], NULL, decompress, (void*)&args[i]);
        if (rc != 0) {
            // Join the threads that did launch before throwing, so they do
            // not outlive the stack-allocated state they point into.
            for (int j = 0; j < i; j++) {
                pthread_join(threads[j], NULL);
            }
            TORCH_CHECK(false, "Unable to create thread ", i, ": ", rc);
        }
    }

    for (int i = 0; i < nthreads; i++) {
        pthread_join(threads[i], NULL);
    }

    return output;
}

// Python bindings: exposes decompress_residuals to Python as
// decompress_residuals_cpp on the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("decompress_residuals_cpp", &decompress_residuals,
          "Decompress residuals");
}