zb12138 commited on
Commit
e389f7b
·
0 Parent(s):
README.md ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NumpyAc: Fast Autoregressive Arithmetic Coding
2
+
3
+ ## About
4
+
5
+ This is a modified version of the [torchac](https://github.com/fab-jul/torchac). NumpyAc takes numpy array as input and can decode in an autoregressive mode.
6
+
7
+ The backend is written in C++, the API is for PyTorch tensors. It will compile in the first run with ninja.
8
+
9
+ The implementation is based on [this blog post](https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html),
10
+ meaning that we implement _arithmetic coding_.
11
+ While it could be further optimized, it is already much faster than doing the equivalent thing in pure-Python (because of all the
12
+ bit-shifts etc.).
13
+
14
+ ### Set up conda environment
15
+
16
+ This library has been tested with
17
+
18
+ - PyTorch 1.5, 1.6, 1.7
19
+ - Python 3.8
20
+ And that's all you need. Other versions of Python may also work,
21
+ but on-the-fly ninja compilation only works for PyTorch 1.5+.
22
+
23
+ ### Example
24
+
25
+ ```python
26
+ import numpyAc
27
+ import numpy as np
28
+
29
+ # Generate random symbols and pdf.
30
+ dim = 128
31
+ symsNum = 2000
32
+ pdf = np.random.rand(symsNum,dim)
33
+ pdf = pdf / (np.sum(pdf,1,keepdims=True))
34
+ sym = np.random.randint(0,dim,symsNum,dtype=np.int16)
35
+ output_pdf = pdf
36
+
37
+ # Encode to bytestream.
38
+ codec = numpyAc.arithmeticCoding()
39
+ byte_stream,real_bits = codec.encode(pdf, sym,'out.b')
40
+
41
+ # Number of bits taken by the stream.
42
+ print('real_bits',real_bits)
43
+
44
+ # Theoretical bits number
45
+ print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
46
+
47
+ # Decode from bytestream.
48
+ decodec = numpyAc.arithmeticDeCoding(None,symsNum,dim,'out.b')
49
+
50
+ # Autoregressive decoding and output will be equal to the input.
51
+ for i,s in enumerate(sym):
52
+ assert decodec.decode(output_pdf[i:i+1,:]) == s
53
+ ```
54
+
55
+
56
+ ## Important Implementation Details
57
+
58
+ ### How we represent probability distributions
59
+
60
+ The probabilities are specified as [PDFs](https://en.wikipedia.org/wiki/Probability_density_function).
61
+ For each possible symbol, we need one PDF. This means that if there are `symsNum` possible symbols, and the values of them are distributed in `{0, ..., dim-1}`. The PDF ( shape (`symsNum,dim`) ) must specified the value for `symsNum` symbols.
62
+
63
+ **Example**:
64
+
65
+ ```
66
+ For a symsNum = 1 particular symbol, let's say we have dim = 3 possible values.
67
+ We can draw 4 CDF from 3 PDF to specify the symbols distribution:
68
+
69
+ symbol: 0 1 2
70
+ pdf: P(0) P(1) P(2)
71
+ cdf: C_0 C_1 C_2 C_3
72
+
73
+ This corresponds to the 3 probabilities
74
+
75
+ P(0) = C_1 - C_0
76
+ P(1) = C_2 - C_1
77
+ P(2) = C_3 - C_2
78
+
79
+ where PDF =[[ P(0), P(1) ,P(2) ]]
80
+ NOTE: The arithmetic coder assumes that P(0) + P(1) + P(2) = 1, C_0 = 0, C_3 = 1
81
+ ```
82
+ The theoretical bits number can estimated by Shannon’s source coding theorem:
83
+ $\sum_{s}-log_2P(s)$
84
+ ## Citation
85
+ Reference from [torchac](https://github.com/fab-jul/torchac), thanks!
numpyAc/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from numpyAc.numpyAc import arithmeticCoding,arithmeticDeCoding
numpyAc/backend/numpyAc_backend.cpp ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /**
2
+ * COPYRIGHT 2020 ETH Zurich
3
+ * BASED on
4
+ *
5
+ * https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
6
+ */
7
+
8
+ #include <torch/extension.h>
9
+
10
+ #include <iostream>
11
+ #include <vector>
12
+ #include <tuple>
13
+ #include <fstream>
14
+ #include <algorithm>
15
+ #include <string>
16
+ #include <chrono>
17
+ #include <numeric>
18
+ #include <iterator>
19
+
20
+ #include <bitset>
21
+
22
+ using cdf_t = uint16_t;
23
+
24
+ /** Encapsulates a pointer to a CDF tensor */
25
+ struct cdf_ptr {
26
+ cdf_t* data; // expected to be a N_sym x Lp matrix, stored in row major.
27
+ const int N_sym; // Number of symbols stored by `data`.
28
+ const int Lp; // == L+1, where L is the number of possible values a symbol can take.
29
+ cdf_ptr(cdf_t* data,
30
+ const int N_sym,
31
+ const int Lp) : data(data), N_sym(N_sym), Lp(Lp) {};
32
+ };
33
+
34
+
35
+
36
+ /** Class to save output bit by bit to a byte string */
37
+ class OutCacheString {
38
+ private:
39
+ public:
40
+ std::string out="";
41
+ uint8_t cache=0;
42
+ uint8_t count=0;
43
+ void append(const int bit) {
44
+ cache <<= 1;
45
+ cache |= bit;
46
+ count += 1;
47
+ if (count == 8) {
48
+ out.append(reinterpret_cast<const char *>(&cache), 1);
49
+ count = 0;
50
+ }
51
+ }
52
+ void flush() {
53
+ if (count > 0) {
54
+ for (int i = count; i < 8; ++i) {
55
+ append(0);
56
+ }
57
+ assert(count==0);
58
+ }
59
+ }
60
+ void append_bit_and_pending(const int bit, uint64_t &pending_bits) {
61
+ append(bit);
62
+ while (pending_bits > 0) {
63
+ append(!bit);
64
+ pending_bits -= 1;
65
+ }
66
+ }
67
+ };
68
+
69
+ /** Class to read byte string bit by bit */
70
+ class InCacheString {
71
+ private:
72
+ const std::string in_;
73
+
74
+ public:
75
+ explicit InCacheString(const std::string& in) : in_(in) {};
76
+
77
+ uint8_t cache=0;
78
+ uint8_t cached_bits=0;
79
+ size_t in_ptr=0;
80
+
81
+ void get(uint32_t& value) {
82
+
83
+ if (cached_bits == 0) {
84
+ if (in_ptr == in_.size()){
85
+ value <<= 1;
86
+ return;
87
+ }
88
+ /// Read 1 byte
89
+
90
+ cache = (uint8_t) in_[in_ptr];
91
+ in_ptr++;
92
+ cached_bits = 8;
93
+ }
94
+ value <<= 1;
95
+ value |= (cache >> (cached_bits - 1)) & 1;
96
+ cached_bits--;
97
+ }
98
+
99
+ void initialize(uint32_t& value) {
100
+ for (int i = 0; i < 32; ++i) {
101
+ get(value);
102
+ }
103
+ }
104
+ };
105
+
106
+
107
+ //------------------------------------------------------------------------------
108
+
109
+
110
+ cdf_t binsearch(py::list &cdf, cdf_t target, cdf_t max_sym,
111
+ const int offset) /* i * Lp */
112
+ {
113
+ cdf_t left = 0;
114
+ cdf_t right = max_sym + 1; // len(cdf) == max_sym + 2
115
+
116
+ while (left + 1 < right) { // ?
117
+ // Left and right will be < 0x10000 in practice,
118
+ // so left+right fits in uint16_t.
119
+ const auto m = static_cast<const cdf_t>((left + right) / 2);
120
+ const auto v = cdf[offset + m].cast<cdf_t>();
121
+ if (v < target) {
122
+ left = m;
123
+ } else if (v > target) {
124
+ right = m;
125
+ } else {
126
+ return m;
127
+ }
128
+ }
129
+
130
+ return left;
131
+ }
132
+
133
+
134
+ class decode
135
+ {
136
+ private:
137
+
138
+ public:
139
+ int dataID=0;
140
+ const int Lp;// To calculate offset
141
+ const int N_sym;// To know the # of syms to decode. Is encoded in the stream!
142
+ const int max_symbol;
143
+ uint32_t low = 0;
144
+ uint32_t high = 0xFFFFFFFFU;
145
+ const uint32_t c_count = 0x10000U;
146
+ const int precision = 16;
147
+ cdf_t sym_i = 0;
148
+ uint32_t value = 0;
149
+ InCacheString in_cache;
150
+ decode(const std::string &in, const int&sysNum_,const int&sysNumDim_):in_cache(in),N_sym(sysNum_),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
151
+ in_cache.initialize(value);
152
+
153
+ };
154
+
155
+ int16_t decodeAsym(py::list cdf) {
156
+
157
+
158
+ for (; dataID < N_sym; ++dataID) {
159
+
160
+ const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
161
+ // always < 0x10000 ???
162
+ const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
163
+
164
+ int offset = 0;
165
+
166
+ sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
167
+
168
+
169
+ if (dataID == N_sym-1) {
170
+ break;
171
+ }
172
+
173
+
174
+ const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
175
+ const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
176
+
177
+ high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
178
+ low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
179
+
180
+ while (true) {
181
+ if (low >= 0x80000000U || high < 0x80000000U) {
182
+ low <<= 1;
183
+ high <<= 1;
184
+ high |= 1;
185
+
186
+ in_cache.get(value);
187
+
188
+ } else if (low >= 0x40000000U && high < 0xC0000000U) {
189
+ /**
190
+ * 0100 0000 ... <= value < 1100 0000 ...
191
+ * <=>
192
+ * 0100 0000 ... <= value <= 1011 1111 ...
193
+ * <=>
194
+ * value starts with 01 or 10.
195
+ * 01 - 01 == 00 | 10 - 01 == 01
196
+ * i.e., with shifts
197
+ * 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
198
+ * near convergence
199
+ */
200
+ low <<= 1;
201
+ low &= 0x7FFFFFFFU; // make MSB 0
202
+ high <<= 1;
203
+ high |= 0x80000001U; // add 1 at the end, retain MSB = 1
204
+ value -= 0x40000000U;
205
+
206
+ in_cache.get(value);
207
+
208
+ } else {
209
+ break;
210
+ }
211
+ }
212
+
213
+ return (int16_t)sym_i;
214
+ }
215
+ }
216
+
217
+ };
218
+
219
+ const void check_sym(const torch::Tensor& sym) {
220
+ TORCH_CHECK(sym.sizes().size() == 1,
221
+ "Invalid size for sym. Expected just 1 dim.")
222
+ }
223
+
224
+ /** Get an instance of the `cdf_ptr` struct. */
225
+ const struct cdf_ptr get_cdf_ptr(const torch::Tensor& cdf)
226
+ {
227
+ TORCH_CHECK(!cdf.is_cuda(), "cdf must be on CPU!")
228
+ const auto s = cdf.sizes();
229
+ TORCH_CHECK(s.size() == 2, "Invalid size for cdf! Expected (N, Lp)")
230
+
231
+ const int N_sym = s[0];
232
+ const int Lp = s[1];
233
+ const auto cdf_acc = cdf.accessor<int16_t, 2>();
234
+ cdf_t* cdf_ptr = (uint16_t*)cdf_acc.data();
235
+
236
+ const struct cdf_ptr res(cdf_ptr, N_sym, Lp);
237
+ return res;
238
+ }
239
+
240
+
241
+ // -----------------------------------------------------------------------------
242
+
243
+
244
+ /** Encode symbols `sym` with CDF represented by `cdf_ptr`. NOTE: this is not exposted to python. */
245
+ py::bytes encode(
246
+ const cdf_ptr& cdf_ptr,
247
+ const torch::Tensor& sym){
248
+
249
+ OutCacheString out_cache;
250
+
251
+ uint32_t low = 0;
252
+ uint32_t high = 0xFFFFFFFFU;
253
+ uint64_t pending_bits = 0;
254
+
255
+ const int precision = 16;
256
+
257
+ const cdf_t* cdf = cdf_ptr.data;
258
+ const int N_sym = cdf_ptr.N_sym;
259
+ const int Lp = cdf_ptr.Lp;
260
+ const int max_symbol = Lp - 2;
261
+
262
+ auto sym_ = sym.accessor<int16_t, 1>();
263
+
264
+ for (int i = 0; i < N_sym; ++i) {
265
+ const int16_t sym_i = sym_[i];
266
+
267
+ const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
268
+
269
+ const int offset = i * Lp;
270
+ // Left boundary is at offset + sym_i
271
+ const uint32_t c_low = cdf[offset + sym_i];
272
+ // Right boundary is at offset + sym_i + 1, except for the `max_symbol`
273
+ // For which we hardcode the maxvalue. So if e.g.
274
+ // L == 4, it means that Lp == 5, and the allowed symbols are
275
+ // {0, 1, 2, 3}. The max symbol is thus Lp - 2 == 3. It's probability
276
+ // is then given by c_max - cdf[-2].
277
+ const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1];
278
+
279
+ high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
280
+ low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
281
+
282
+ while (true) {
283
+ if (high < 0x80000000U) {
284
+ out_cache.append_bit_and_pending(0, pending_bits);
285
+ low <<= 1;
286
+ high <<= 1;
287
+ high |= 1;
288
+ } else if (low >= 0x80000000U) {
289
+ out_cache.append_bit_and_pending(1, pending_bits);
290
+ low <<= 1;
291
+ high <<= 1;
292
+ high |= 1;
293
+ } else if (low >= 0x40000000U && high < 0xC0000000U) {
294
+ pending_bits++;
295
+ low <<= 1;
296
+ low &= 0x7FFFFFFF;
297
+ high <<= 1;
298
+ high |= 0x80000001;
299
+ } else {
300
+ break;
301
+ }
302
+ }
303
+ }
304
+
305
+ pending_bits += 1;
306
+
307
+ if (pending_bits) {
308
+ if (low < 0x40000000U) {
309
+ out_cache.append_bit_and_pending(0, pending_bits);
310
+ } else {
311
+ out_cache.append_bit_and_pending(1, pending_bits);
312
+ }
313
+ }
314
+
315
+ out_cache.flush();
316
+
317
+ #ifdef VERBOSE
318
+ std::chrono::steady_clock::time_point end= std::chrono::steady_clock::now();
319
+ std::cout << "Time difference (sec) = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) /1000000.0 <<std::endl;
320
+ #endif
321
+
322
+ return py::bytes(out_cache.out);
323
+ }
324
+
325
+
326
+ /** See torchac.py */
327
+ py::bytes encode_cdf(
328
+ const torch::Tensor& cdf, /* NHWLp, must be on CPU! */
329
+ const torch::Tensor& sym)
330
+ {
331
+ check_sym(sym);
332
+ const auto cdf_ptr = get_cdf_ptr(cdf);
333
+ return encode(cdf_ptr, sym);
334
+ }
335
+
336
+
337
+
338
+
339
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
340
+ m.def("encode_cdf", &encode_cdf, "Encode from CDF");
341
+
342
+ py::class_<decode>(m, "decode")
343
+ .def(py::init([] (const std::string in, const int&sysNum_,const int&sysNumDim_) {
344
+ return new decode(in,sysNum_,sysNumDim_);
345
+ }))
346
+ .def("decodeAsym", &decode::decodeAsym);
347
+ }
numpyAc/numpyAc.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from torch.autograd.grad_mode import F
5
+ from torch.utils.cpp_extension import load
6
+
7
+
8
+ PRECISION = 16 # DO NOT EDIT!
9
+
10
+
11
+ # Load on-the-fly with ninja.
12
+ torchac_dir = os.path.dirname(os.path.realpath(__file__))
13
+ backend_dir = os.path.join(torchac_dir, 'backend')
14
+ numpyAc_backend = load(
15
+ name="numpyAc_backend",
16
+ sources=[os.path.join(backend_dir, "numpyAc_backend.cpp")],
17
+ verbose=False)
18
+
19
+ def _encode_float_cdf(cdf_float,
20
+ sym,
21
+ needs_normalization=True,
22
+ check_input_bounds=False):
23
+ """Encode symbols `sym` with potentially unnormalized floating point CDF.
24
+
25
+ Check the README for more details.
26
+
27
+ :param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp).
28
+ :param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
29
+ :param needs_normalization: if True, assume `cdf_float` is un-normalized and
30
+ needs normalization. Otherwise only convert it, without normalizing.
31
+ :param check_input_bounds: if True, ensure inputs have valid values.
32
+ Important: may take significant time. Only enable to check.
33
+
34
+ :return: byte-string, encoding `sym`.
35
+ """
36
+ if check_input_bounds:
37
+ if cdf_float.min() < 0:
38
+ raise ValueError(f'cdf_float.min() == {cdf_float.min()}, should be >=0.!')
39
+ if cdf_float.max() > 1:
40
+ raise ValueError(f'cdf_float.max() == {cdf_float.max()}, should be <=1.!')
41
+ Lp = cdf_float.shape[-1]
42
+ if sym.max() >= Lp - 1:
43
+ raise ValueError(f'sym.max() == {sym.max()}, should be <=Lp - 1.!')
44
+ cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
45
+ return _encode_int16_normalized_cdf(cdf_int, sym)
46
+
47
+
48
+ def _encode_int16_normalized_cdf(cdf_int, sym):
49
+ """Encode symbols `sym` with a normalized integer cdf `cdf_int`.
50
+
51
+ Check the README for more details.
52
+
53
+ :param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp).
54
+ :param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
55
+
56
+ :return: byte-string, encoding `sym`
57
+ """
58
+ cdf_int, sym = _check_and_reshape_inputs(cdf_int, sym)
59
+ return numpyAc_backend.encode_cdf( torch.ShortTensor(cdf_int), torch.ShortTensor(sym))
60
+
61
+
62
+ def _check_and_reshape_inputs(cdf, sym=None):
63
+ """Check device, dtype, and shapes."""
64
+ if sym is not None and sym.dtype != np.int16:
65
+ raise ValueError('Symbols must be int16!')
66
+ if sym is not None:
67
+ if len(cdf.shape) != len(sym.shape) + 1 or cdf.shape[:-1] != sym.shape:
68
+ raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
69
+ 'The first m elements of cdf.shape must be equal to '
70
+ 'sym.shape, and cdf should only have one more dimension.')
71
+ Lp = cdf.shape[-1]
72
+ cdf = cdf.reshape(-1, Lp)
73
+ if sym is None:
74
+ return cdf
75
+ sym = sym.reshape(-1)
76
+ return cdf, sym
77
+
78
+
79
+ # def _reshape_output(cdf_shape, sym):
80
+ # """Reshape single dimension `sym` back to the correct spatial dimensions."""
81
+ # spatial_dimensions = cdf_shape[:-1]
82
+ # if len(sym) != np.prod(spatial_dimensions):
83
+ # raise ValueError()
84
+ # return sym.reshape(*spatial_dimensions)
85
+
86
+
87
+ def _convert_to_int_and_normalize(cdf_float, needs_normalization):
88
+ """Convert floatingpoint CDF to integers. See README for more info.
89
+
90
+ The idea is the following:
91
+ When we get the cdf here, it is (assumed to be) between 0 and 1, i.e,
92
+ cdf \in [0, 1)
93
+ (note that 1 should not be included.)
94
+ We now want to convert this to int16 but make sure we do not get
95
+ the same value twice, as this would break the arithmetic coder
96
+ (you need a strictly monotonically increasing function).
97
+ So, if needs_normalization==True, we multiply the input CDF
98
+ with 2**16 - (Lp - 1). This means that now,
99
+ cdf \in [0, 2**16 - (Lp - 1)].
100
+ Then, in a final step, we add an arange(Lp), which is just a line with
101
+ slope one. This ensure that for sure, we will get unique, strictly
102
+ monotonically increasing CDFs, which are \in [0, 2**16)
103
+ """
104
+ Lp = cdf_float.shape[-1]
105
+ factor = 2**PRECISION
106
+ new_max_value = factor
107
+ if needs_normalization:
108
+ new_max_value = new_max_value - (Lp - 1)
109
+ cdf_float = cdf_float*(new_max_value)
110
+ cdf_float = np.round(cdf_float)
111
+ cdf = cdf_float.astype(np.int16)
112
+ if needs_normalization:
113
+ r = np.arange(Lp)
114
+ cdf+=r
115
+ return cdf
116
+
117
+ def pdf_convert_to_cdf_and_normalize(pdf):
118
+ assert pdf.ndim==2
119
+ pdf = pdf / (np.sum(pdf,1,keepdims=True))/(1+10**(-10))
120
+ cdfF = np.cumsum( pdf, axis=1)
121
+ cdfF = np.hstack((np.zeros((pdf.shape[0],1)),cdfF))
122
+ return cdfF
123
+
124
+ class arithmeticCoding():
125
+ def __init__(self) -> None:
126
+ self.binfile = None
127
+ self.sysNum = None
128
+ self.byte_stream = None
129
+
130
+
131
+ def encode(self,pdf,sym,binfile=None):
132
+ assert pdf.shape[0]==sym.shape[0]
133
+ assert pdf.ndim==2 and sym.ndim==1
134
+
135
+ self.sysNum = sym.shape[0]
136
+
137
+ cdfF = pdf_convert_to_cdf_and_normalize(pdf)
138
+
139
+ # pdf = np.diff(cdfF)
140
+ # print( -np.log2(pdf[range(0,self.sysNum),sym]).sum())
141
+
142
+ self.byte_stream = _encode_float_cdf(cdfF, sym, check_input_bounds=True)
143
+ real_bits = len(self.byte_stream) * 8
144
+ # # Write to a file.
145
+ if binfile is not None:
146
+ with open(binfile, 'wb') as fout:
147
+ fout.write(self.byte_stream)
148
+ return self.byte_stream,real_bits
149
+
150
+ class arithmeticDeCoding():
151
+ """
152
+ Decoding class
153
+ byte_stream: the bin file stream.
154
+ sysNum: the Number of symbols that you are going to decode. This value should be
155
+ saved in other ways.
156
+ sysDim: the Number of the possible symbols.
157
+ binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
158
+ and copy to Cpp backend Class 'InCacheString'
159
+ """
160
+ def __init__(self,byte_stream,sysNum,symDim,binfile=None) -> None:
161
+ if binfile is not None:
162
+ with open(binfile, 'rb') as fin:
163
+ byte_stream = fin.read()
164
+ self.byte_stream = byte_stream
165
+ self.decoder = numpyAc_backend.decode(self.byte_stream,sysNum,symDim+1)
166
+
167
+ def decode(self,pdf):
168
+ cdfF = pdf_convert_to_cdf_and_normalize(pdf)
169
+ pro = _convert_to_int_and_normalize(cdfF,needs_normalization=True)
170
+ pro = pro.squeeze(0).astype(np.uint16).tolist()
171
+ sym_out = self.decoder.decodeAsym(pro)
172
+ return sym_out
test.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpyAc
2
+ import numpy as np
3
+
4
+ # Generate random symbols and pdf.
5
+ dim = 128
6
+ symsNum = 2000
7
+ pdf = np.random.rand(symsNum,dim)
8
+ pdf = pdf / (np.sum(pdf,1,keepdims=True))
9
+ sym = np.random.randint(0,dim,symsNum,dtype=np.int16)
10
+ output_pdf = pdf
11
+
12
+ # Encode to bytestream.
13
+ codec = numpyAc.arithmeticCoding()
14
+ byte_stream,real_bits = codec.encode(pdf, sym,'out.b')
15
+
16
+ # Number of bits taken by the stream.
17
+ print('real_bits',real_bits)
18
+
19
+ # Theoretical bits number
20
+ print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
21
+
22
+ # Decode from bytestream.
23
+ decodec = numpyAc.arithmeticDeCoding(None,symsNum,dim,'out.b')
24
+
25
+ # Autoregressive decoding and output will be equal to the input.
26
+ for i,s in enumerate(sym):
27
+ assert decodec.decode(output_pdf[i:i+1,:]) == s
testTorchac.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ LastEditors: fcy
3
+ '''
4
+ import torchac
5
+ import torch
6
+ import numpy as np
7
+ # Encode to bytestream.
8
+
9
+ seed=6
10
+ torch.manual_seed(seed)
11
+ np.random.seed(seed)
12
+
13
+ dim = 500
14
+ symsNum = 40000
15
+ pdf = np.random.rand(symsNum,dim)
16
+ pdf = pdf / (np.sum(pdf,1,keepdims=True))
17
+ sym = torch.ShortTensor(np.random.randint(0,dim,symsNum,dtype=np.int16))
18
+
19
+ def pdf_convert_to_cdf_and_normalize(pdf):
20
+ assert pdf.ndim==2
21
+ pdf = pdf / (np.sum(pdf,1,keepdims=True))/(1+10**(-10))
22
+ cdfF = np.cumsum( pdf, axis=1)
23
+ cdfF = np.hstack((np.zeros((pdf.shape[0],1)),cdfF))
24
+ return cdfF
25
+
26
+
27
+ output_cdf = torch.Tensor(pdf_convert_to_cdf_and_normalize(pdf)) # Get CDF from your model, shape B, C, H, W, Lp
28
+
29
+ byte_stream = torchac.encode_float_cdf(output_cdf, sym, check_input_bounds=True)
30
+
31
+
32
+ # pdf = np.diff(cdfF)
33
+ # print( -np.log2(pdf[range(0,oct_len),sym]).sum())
34
+
35
+ # Number of bits taken by the stream
36
+ real_bits = len(byte_stream) * 8
37
+ print(real_bits)
38
+ # Write to a file.
39
+ with open('outfile.b', 'wb') as fout:
40
+ fout.write(byte_stream)
41
+
42
+ # Read from a file.
43
+ with open('outfile.b', 'rb') as fin:
44
+ byte_stream = fin.read()
45
+
46
+ # Decode from bytestream.
47
+ sym_out = torchac.decode_float_cdf(output_cdf, byte_stream)
48
+
49
+ # Output will be equal to the input.
50
+ assert sym_out.equal(sym)