Spaces:
Sleeping
Sleeping
zb12138
commited on
Commit
·
e389f7b
0
Parent(s):
numpyAc
Browse files- README.md +85 -0
- numpyAc/__init__.py +1 -0
- numpyAc/backend/numpyAc_backend.cpp +347 -0
- numpyAc/numpyAc.py +172 -0
- test.py +27 -0
- testTorchac.py +50 -0
README.md
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# NumpyAc: Fast Autoregressive Arithmetic Coding
|
2 |
+
|
3 |
+
## About
|
4 |
+
|
5 |
+
This is a modified version of the [torchac](https://github.com/fab-jul/torchac). NumpyAc takes numpy array as input and can decode in an autoregressive mode.
|
6 |
+
|
7 |
+
The backend is written in C++, the API is for PyTorch tensors. It will compile in the first run with ninja.
|
8 |
+
|
9 |
+
The implementation is based on [this blog post](https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html),
|
10 |
+
meaning that we implement _arithmetic coding_.
|
11 |
+
While it could be further optimized, it is already much faster than doing the equivalent thing in pure-Python (because of all the
|
12 |
+
bit-shifts etc.).
|
13 |
+
|
14 |
+
### Set up conda environment
|
15 |
+
|
16 |
+
This library has been tested with
|
17 |
+
|
18 |
+
- PyTorch 1.5, 1.6, 1.7
|
19 |
+
- Python 3.8
|
20 |
+
And that's all you need. Other versions of Python may also work,
|
21 |
+
but on-the-fly ninja compilation only works for PyTorch 1.5+.
|
22 |
+
|
23 |
+
### Example
|
24 |
+
|
25 |
+
```python
|
26 |
+
import numpyAc
|
27 |
+
import numpy as np
|
28 |
+
|
29 |
+
# Generate random symbols and pdf.
|
30 |
+
dim = 128
|
31 |
+
symsNum = 2000
|
32 |
+
pdf = np.random.rand(symsNum,dim)
|
33 |
+
pdf = pdf / (np.sum(pdf,1,keepdims=True))
|
34 |
+
sym = np.random.randint(0,dim,symsNum,dtype=np.int16)
|
35 |
+
output_pdf = pdf
|
36 |
+
|
37 |
+
# Encode to bytestream.
|
38 |
+
codec = numpyAc.arithmeticCoding()
|
39 |
+
byte_stream,real_bits = codec.encode(pdf, sym,'out.b')
|
40 |
+
|
41 |
+
# Number of bits taken by the stream.
|
42 |
+
print('real_bits',real_bits)
|
43 |
+
|
44 |
+
# Theoretical bits number
|
45 |
+
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
46 |
+
|
47 |
+
# Decode from bytestream.
|
48 |
+
decodec = numpyAc.arithmeticDeCoding(None,symsNum,dim,'out.b')
|
49 |
+
|
50 |
+
# Autoregressive decoding and output will be equal to the input.
|
51 |
+
for i,s in enumerate(sym):
|
52 |
+
assert decodec.decode(output_pdf[i:i+1,:]) == s
|
53 |
+
```
|
54 |
+
|
55 |
+
|
56 |
+
## Important Implementation Details
|
57 |
+
|
58 |
+
### How we represent probability distributions
|
59 |
+
|
60 |
+
The probabilities are specified as [PDFs](https://en.wikipedia.org/wiki/Probability_density_function).
|
61 |
+
For each possible symbol, we need one PDF. This means that if there are `symsNum` possible symbols, and the values of them are distributed in `{0, ..., dim-1}`. The PDF ( shape (`symsNum,dim`) ) must specified the value for `symsNum` symbols.
|
62 |
+
|
63 |
+
**Example**:
|
64 |
+
|
65 |
+
```
|
66 |
+
For a symsNum = 1 particular symbol, let's say we have dim = 3 possible values.
|
67 |
+
We can draw 4 CDF from 3 PDF to specify the symbols distribution:
|
68 |
+
|
69 |
+
symbol: 0 1 2
|
70 |
+
pdf: P(0) P(1) P(2)
|
71 |
+
cdf: C_0 C_1 C_2 C_3
|
72 |
+
|
73 |
+
This corresponds to the 3 probabilities
|
74 |
+
|
75 |
+
P(0) = C_1 - C_0
|
76 |
+
P(1) = C_2 - C_1
|
77 |
+
P(2) = C_3 - C_2
|
78 |
+
|
79 |
+
where PDF =[[ P(0), P(1) ,P(2) ]]
|
80 |
+
NOTE: The arithmetic coder assumes that P(0) + P(1) + P(2) = 1, C_0 = 0, C_3 = 1
|
81 |
+
```
|
82 |
+
The theoretical bits number can estimated by Shannon’s source coding theorem:
|
83 |
+
$\sum_{s}-log_2P(s)$
|
84 |
+
## Citation
|
85 |
+
Reference from [torchac](https://github.com/fab-jul/torchac), thanks!
|
numpyAc/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from numpyAc.numpyAc import arithmeticCoding,arithmeticDeCoding
|
numpyAc/backend/numpyAc_backend.cpp
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/**
|
2 |
+
* COPYRIGHT 2020 ETH Zurich
|
3 |
+
* BASED on
|
4 |
+
*
|
5 |
+
* https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
|
6 |
+
*/
|
7 |
+
|
8 |
+
#include <torch/extension.h>
|
9 |
+
|
10 |
+
#include <iostream>
|
11 |
+
#include <vector>
|
12 |
+
#include <tuple>
|
13 |
+
#include <fstream>
|
14 |
+
#include <algorithm>
|
15 |
+
#include <string>
|
16 |
+
#include <chrono>
|
17 |
+
#include <numeric>
|
18 |
+
#include <iterator>
|
19 |
+
|
20 |
+
#include <bitset>
|
21 |
+
|
22 |
+
using cdf_t = uint16_t;
|
23 |
+
|
24 |
+
/** Encapsulates a pointer to a CDF tensor */
|
25 |
+
struct cdf_ptr {
|
26 |
+
cdf_t* data; // expected to be a N_sym x Lp matrix, stored in row major.
|
27 |
+
const int N_sym; // Number of symbols stored by `data`.
|
28 |
+
const int Lp; // == L+1, where L is the number of possible values a symbol can take.
|
29 |
+
cdf_ptr(cdf_t* data,
|
30 |
+
const int N_sym,
|
31 |
+
const int Lp) : data(data), N_sym(N_sym), Lp(Lp) {};
|
32 |
+
};
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
/** Class to save output bit by bit to a byte string */
|
37 |
+
class OutCacheString {
|
38 |
+
private:
|
39 |
+
public:
|
40 |
+
std::string out="";
|
41 |
+
uint8_t cache=0;
|
42 |
+
uint8_t count=0;
|
43 |
+
void append(const int bit) {
|
44 |
+
cache <<= 1;
|
45 |
+
cache |= bit;
|
46 |
+
count += 1;
|
47 |
+
if (count == 8) {
|
48 |
+
out.append(reinterpret_cast<const char *>(&cache), 1);
|
49 |
+
count = 0;
|
50 |
+
}
|
51 |
+
}
|
52 |
+
void flush() {
|
53 |
+
if (count > 0) {
|
54 |
+
for (int i = count; i < 8; ++i) {
|
55 |
+
append(0);
|
56 |
+
}
|
57 |
+
assert(count==0);
|
58 |
+
}
|
59 |
+
}
|
60 |
+
void append_bit_and_pending(const int bit, uint64_t &pending_bits) {
|
61 |
+
append(bit);
|
62 |
+
while (pending_bits > 0) {
|
63 |
+
append(!bit);
|
64 |
+
pending_bits -= 1;
|
65 |
+
}
|
66 |
+
}
|
67 |
+
};
|
68 |
+
|
69 |
+
/** Class to read byte string bit by bit */
|
70 |
+
class InCacheString {
|
71 |
+
private:
|
72 |
+
const std::string in_;
|
73 |
+
|
74 |
+
public:
|
75 |
+
explicit InCacheString(const std::string& in) : in_(in) {};
|
76 |
+
|
77 |
+
uint8_t cache=0;
|
78 |
+
uint8_t cached_bits=0;
|
79 |
+
size_t in_ptr=0;
|
80 |
+
|
81 |
+
void get(uint32_t& value) {
|
82 |
+
|
83 |
+
if (cached_bits == 0) {
|
84 |
+
if (in_ptr == in_.size()){
|
85 |
+
value <<= 1;
|
86 |
+
return;
|
87 |
+
}
|
88 |
+
/// Read 1 byte
|
89 |
+
|
90 |
+
cache = (uint8_t) in_[in_ptr];
|
91 |
+
in_ptr++;
|
92 |
+
cached_bits = 8;
|
93 |
+
}
|
94 |
+
value <<= 1;
|
95 |
+
value |= (cache >> (cached_bits - 1)) & 1;
|
96 |
+
cached_bits--;
|
97 |
+
}
|
98 |
+
|
99 |
+
void initialize(uint32_t& value) {
|
100 |
+
for (int i = 0; i < 32; ++i) {
|
101 |
+
get(value);
|
102 |
+
}
|
103 |
+
}
|
104 |
+
};
|
105 |
+
|
106 |
+
|
107 |
+
//------------------------------------------------------------------------------
|
108 |
+
|
109 |
+
|
110 |
+
cdf_t binsearch(py::list &cdf, cdf_t target, cdf_t max_sym,
|
111 |
+
const int offset) /* i * Lp */
|
112 |
+
{
|
113 |
+
cdf_t left = 0;
|
114 |
+
cdf_t right = max_sym + 1; // len(cdf) == max_sym + 2
|
115 |
+
|
116 |
+
while (left + 1 < right) { // ?
|
117 |
+
// Left and right will be < 0x10000 in practice,
|
118 |
+
// so left+right fits in uint16_t.
|
119 |
+
const auto m = static_cast<const cdf_t>((left + right) / 2);
|
120 |
+
const auto v = cdf[offset + m].cast<cdf_t>();
|
121 |
+
if (v < target) {
|
122 |
+
left = m;
|
123 |
+
} else if (v > target) {
|
124 |
+
right = m;
|
125 |
+
} else {
|
126 |
+
return m;
|
127 |
+
}
|
128 |
+
}
|
129 |
+
|
130 |
+
return left;
|
131 |
+
}
|
132 |
+
|
133 |
+
|
134 |
+
class decode
|
135 |
+
{
|
136 |
+
private:
|
137 |
+
|
138 |
+
public:
|
139 |
+
int dataID=0;
|
140 |
+
const int Lp;// To calculate offset
|
141 |
+
const int N_sym;// To know the # of syms to decode. Is encoded in the stream!
|
142 |
+
const int max_symbol;
|
143 |
+
uint32_t low = 0;
|
144 |
+
uint32_t high = 0xFFFFFFFFU;
|
145 |
+
const uint32_t c_count = 0x10000U;
|
146 |
+
const int precision = 16;
|
147 |
+
cdf_t sym_i = 0;
|
148 |
+
uint32_t value = 0;
|
149 |
+
InCacheString in_cache;
|
150 |
+
decode(const std::string &in, const int&sysNum_,const int&sysNumDim_):in_cache(in),N_sym(sysNum_),Lp(sysNumDim_),max_symbol(sysNumDim_-2){
|
151 |
+
in_cache.initialize(value);
|
152 |
+
|
153 |
+
};
|
154 |
+
|
155 |
+
int16_t decodeAsym(py::list cdf) {
|
156 |
+
|
157 |
+
|
158 |
+
for (; dataID < N_sym; ++dataID) {
|
159 |
+
|
160 |
+
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
161 |
+
// always < 0x10000 ???
|
162 |
+
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span;
|
163 |
+
|
164 |
+
int offset = 0;
|
165 |
+
|
166 |
+
sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset);
|
167 |
+
|
168 |
+
|
169 |
+
if (dataID == N_sym-1) {
|
170 |
+
break;
|
171 |
+
}
|
172 |
+
|
173 |
+
|
174 |
+
const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>();
|
175 |
+
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>();
|
176 |
+
|
177 |
+
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
178 |
+
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
179 |
+
|
180 |
+
while (true) {
|
181 |
+
if (low >= 0x80000000U || high < 0x80000000U) {
|
182 |
+
low <<= 1;
|
183 |
+
high <<= 1;
|
184 |
+
high |= 1;
|
185 |
+
|
186 |
+
in_cache.get(value);
|
187 |
+
|
188 |
+
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
189 |
+
/**
|
190 |
+
* 0100 0000 ... <= value < 1100 0000 ...
|
191 |
+
* <=>
|
192 |
+
* 0100 0000 ... <= value <= 1011 1111 ...
|
193 |
+
* <=>
|
194 |
+
* value starts with 01 or 10.
|
195 |
+
* 01 - 01 == 00 | 10 - 01 == 01
|
196 |
+
* i.e., with shifts
|
197 |
+
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in
|
198 |
+
* near convergence
|
199 |
+
*/
|
200 |
+
low <<= 1;
|
201 |
+
low &= 0x7FFFFFFFU; // make MSB 0
|
202 |
+
high <<= 1;
|
203 |
+
high |= 0x80000001U; // add 1 at the end, retain MSB = 1
|
204 |
+
value -= 0x40000000U;
|
205 |
+
|
206 |
+
in_cache.get(value);
|
207 |
+
|
208 |
+
} else {
|
209 |
+
break;
|
210 |
+
}
|
211 |
+
}
|
212 |
+
|
213 |
+
return (int16_t)sym_i;
|
214 |
+
}
|
215 |
+
}
|
216 |
+
|
217 |
+
};
|
218 |
+
|
219 |
+
const void check_sym(const torch::Tensor& sym) {
|
220 |
+
TORCH_CHECK(sym.sizes().size() == 1,
|
221 |
+
"Invalid size for sym. Expected just 1 dim.")
|
222 |
+
}
|
223 |
+
|
224 |
+
/** Get an instance of the `cdf_ptr` struct. */
|
225 |
+
const struct cdf_ptr get_cdf_ptr(const torch::Tensor& cdf)
|
226 |
+
{
|
227 |
+
TORCH_CHECK(!cdf.is_cuda(), "cdf must be on CPU!")
|
228 |
+
const auto s = cdf.sizes();
|
229 |
+
TORCH_CHECK(s.size() == 2, "Invalid size for cdf! Expected (N, Lp)")
|
230 |
+
|
231 |
+
const int N_sym = s[0];
|
232 |
+
const int Lp = s[1];
|
233 |
+
const auto cdf_acc = cdf.accessor<int16_t, 2>();
|
234 |
+
cdf_t* cdf_ptr = (uint16_t*)cdf_acc.data();
|
235 |
+
|
236 |
+
const struct cdf_ptr res(cdf_ptr, N_sym, Lp);
|
237 |
+
return res;
|
238 |
+
}
|
239 |
+
|
240 |
+
|
241 |
+
// -----------------------------------------------------------------------------
|
242 |
+
|
243 |
+
|
244 |
+
/** Encode symbols `sym` with CDF represented by `cdf_ptr`. NOTE: this is not exposted to python. */
|
245 |
+
py::bytes encode(
|
246 |
+
const cdf_ptr& cdf_ptr,
|
247 |
+
const torch::Tensor& sym){
|
248 |
+
|
249 |
+
OutCacheString out_cache;
|
250 |
+
|
251 |
+
uint32_t low = 0;
|
252 |
+
uint32_t high = 0xFFFFFFFFU;
|
253 |
+
uint64_t pending_bits = 0;
|
254 |
+
|
255 |
+
const int precision = 16;
|
256 |
+
|
257 |
+
const cdf_t* cdf = cdf_ptr.data;
|
258 |
+
const int N_sym = cdf_ptr.N_sym;
|
259 |
+
const int Lp = cdf_ptr.Lp;
|
260 |
+
const int max_symbol = Lp - 2;
|
261 |
+
|
262 |
+
auto sym_ = sym.accessor<int16_t, 1>();
|
263 |
+
|
264 |
+
for (int i = 0; i < N_sym; ++i) {
|
265 |
+
const int16_t sym_i = sym_[i];
|
266 |
+
|
267 |
+
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1;
|
268 |
+
|
269 |
+
const int offset = i * Lp;
|
270 |
+
// Left boundary is at offset + sym_i
|
271 |
+
const uint32_t c_low = cdf[offset + sym_i];
|
272 |
+
// Right boundary is at offset + sym_i + 1, except for the `max_symbol`
|
273 |
+
// For which we hardcode the maxvalue. So if e.g.
|
274 |
+
// L == 4, it means that Lp == 5, and the allowed symbols are
|
275 |
+
// {0, 1, 2, 3}. The max symbol is thus Lp - 2 == 3. It's probability
|
276 |
+
// is then given by c_max - cdf[-2].
|
277 |
+
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1];
|
278 |
+
|
279 |
+
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision);
|
280 |
+
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision);
|
281 |
+
|
282 |
+
while (true) {
|
283 |
+
if (high < 0x80000000U) {
|
284 |
+
out_cache.append_bit_and_pending(0, pending_bits);
|
285 |
+
low <<= 1;
|
286 |
+
high <<= 1;
|
287 |
+
high |= 1;
|
288 |
+
} else if (low >= 0x80000000U) {
|
289 |
+
out_cache.append_bit_and_pending(1, pending_bits);
|
290 |
+
low <<= 1;
|
291 |
+
high <<= 1;
|
292 |
+
high |= 1;
|
293 |
+
} else if (low >= 0x40000000U && high < 0xC0000000U) {
|
294 |
+
pending_bits++;
|
295 |
+
low <<= 1;
|
296 |
+
low &= 0x7FFFFFFF;
|
297 |
+
high <<= 1;
|
298 |
+
high |= 0x80000001;
|
299 |
+
} else {
|
300 |
+
break;
|
301 |
+
}
|
302 |
+
}
|
303 |
+
}
|
304 |
+
|
305 |
+
pending_bits += 1;
|
306 |
+
|
307 |
+
if (pending_bits) {
|
308 |
+
if (low < 0x40000000U) {
|
309 |
+
out_cache.append_bit_and_pending(0, pending_bits);
|
310 |
+
} else {
|
311 |
+
out_cache.append_bit_and_pending(1, pending_bits);
|
312 |
+
}
|
313 |
+
}
|
314 |
+
|
315 |
+
out_cache.flush();
|
316 |
+
|
317 |
+
#ifdef VERBOSE
|
318 |
+
std::chrono::steady_clock::time_point end= std::chrono::steady_clock::now();
|
319 |
+
std::cout << "Time difference (sec) = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) /1000000.0 <<std::endl;
|
320 |
+
#endif
|
321 |
+
|
322 |
+
return py::bytes(out_cache.out);
|
323 |
+
}
|
324 |
+
|
325 |
+
|
326 |
+
/** See torchac.py */
|
327 |
+
py::bytes encode_cdf(
|
328 |
+
const torch::Tensor& cdf, /* NHWLp, must be on CPU! */
|
329 |
+
const torch::Tensor& sym)
|
330 |
+
{
|
331 |
+
check_sym(sym);
|
332 |
+
const auto cdf_ptr = get_cdf_ptr(cdf);
|
333 |
+
return encode(cdf_ptr, sym);
|
334 |
+
}
|
335 |
+
|
336 |
+
|
337 |
+
|
338 |
+
|
339 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
340 |
+
m.def("encode_cdf", &encode_cdf, "Encode from CDF");
|
341 |
+
|
342 |
+
py::class_<decode>(m, "decode")
|
343 |
+
.def(py::init([] (const std::string in, const int&sysNum_,const int&sysNumDim_) {
|
344 |
+
return new decode(in,sysNum_,sysNumDim_);
|
345 |
+
}))
|
346 |
+
.def("decodeAsym", &decode::decodeAsym);
|
347 |
+
}
|
numpyAc/numpyAc.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from torch.autograd.grad_mode import F
|
5 |
+
from torch.utils.cpp_extension import load
|
6 |
+
|
7 |
+
|
8 |
+
PRECISION = 16 # DO NOT EDIT!
|
9 |
+
|
10 |
+
|
11 |
+
# Load on-the-fly with ninja.
|
12 |
+
torchac_dir = os.path.dirname(os.path.realpath(__file__))
|
13 |
+
backend_dir = os.path.join(torchac_dir, 'backend')
|
14 |
+
numpyAc_backend = load(
|
15 |
+
name="numpyAc_backend",
|
16 |
+
sources=[os.path.join(backend_dir, "numpyAc_backend.cpp")],
|
17 |
+
verbose=False)
|
18 |
+
|
19 |
+
def _encode_float_cdf(cdf_float,
|
20 |
+
sym,
|
21 |
+
needs_normalization=True,
|
22 |
+
check_input_bounds=False):
|
23 |
+
"""Encode symbols `sym` with potentially unnormalized floating point CDF.
|
24 |
+
|
25 |
+
Check the README for more details.
|
26 |
+
|
27 |
+
:param cdf_float: CDF tensor, float32, on CPU. Shape (N1, ..., Nm, Lp).
|
28 |
+
:param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
|
29 |
+
:param needs_normalization: if True, assume `cdf_float` is un-normalized and
|
30 |
+
needs normalization. Otherwise only convert it, without normalizing.
|
31 |
+
:param check_input_bounds: if True, ensure inputs have valid values.
|
32 |
+
Important: may take significant time. Only enable to check.
|
33 |
+
|
34 |
+
:return: byte-string, encoding `sym`.
|
35 |
+
"""
|
36 |
+
if check_input_bounds:
|
37 |
+
if cdf_float.min() < 0:
|
38 |
+
raise ValueError(f'cdf_float.min() == {cdf_float.min()}, should be >=0.!')
|
39 |
+
if cdf_float.max() > 1:
|
40 |
+
raise ValueError(f'cdf_float.max() == {cdf_float.max()}, should be <=1.!')
|
41 |
+
Lp = cdf_float.shape[-1]
|
42 |
+
if sym.max() >= Lp - 1:
|
43 |
+
raise ValueError(f'sym.max() == {sym.max()}, should be <=Lp - 1.!')
|
44 |
+
cdf_int = _convert_to_int_and_normalize(cdf_float, needs_normalization)
|
45 |
+
return _encode_int16_normalized_cdf(cdf_int, sym)
|
46 |
+
|
47 |
+
|
48 |
+
def _encode_int16_normalized_cdf(cdf_int, sym):
|
49 |
+
"""Encode symbols `sym` with a normalized integer cdf `cdf_int`.
|
50 |
+
|
51 |
+
Check the README for more details.
|
52 |
+
|
53 |
+
:param cdf_int: CDF tensor, int16, on CPU. Shape (N1, ..., Nm, Lp).
|
54 |
+
:param sym: The symbols to encode, int16, on CPU. Shape (N1, ..., Nm).
|
55 |
+
|
56 |
+
:return: byte-string, encoding `sym`
|
57 |
+
"""
|
58 |
+
cdf_int, sym = _check_and_reshape_inputs(cdf_int, sym)
|
59 |
+
return numpyAc_backend.encode_cdf( torch.ShortTensor(cdf_int), torch.ShortTensor(sym))
|
60 |
+
|
61 |
+
|
62 |
+
def _check_and_reshape_inputs(cdf, sym=None):
|
63 |
+
"""Check device, dtype, and shapes."""
|
64 |
+
if sym is not None and sym.dtype != np.int16:
|
65 |
+
raise ValueError('Symbols must be int16!')
|
66 |
+
if sym is not None:
|
67 |
+
if len(cdf.shape) != len(sym.shape) + 1 or cdf.shape[:-1] != sym.shape:
|
68 |
+
raise ValueError(f'Invalid shapes of cdf={cdf.shape}, sym={sym.shape}! '
|
69 |
+
'The first m elements of cdf.shape must be equal to '
|
70 |
+
'sym.shape, and cdf should only have one more dimension.')
|
71 |
+
Lp = cdf.shape[-1]
|
72 |
+
cdf = cdf.reshape(-1, Lp)
|
73 |
+
if sym is None:
|
74 |
+
return cdf
|
75 |
+
sym = sym.reshape(-1)
|
76 |
+
return cdf, sym
|
77 |
+
|
78 |
+
|
79 |
+
# def _reshape_output(cdf_shape, sym):
|
80 |
+
# """Reshape single dimension `sym` back to the correct spatial dimensions."""
|
81 |
+
# spatial_dimensions = cdf_shape[:-1]
|
82 |
+
# if len(sym) != np.prod(spatial_dimensions):
|
83 |
+
# raise ValueError()
|
84 |
+
# return sym.reshape(*spatial_dimensions)
|
85 |
+
|
86 |
+
|
87 |
+
def _convert_to_int_and_normalize(cdf_float, needs_normalization):
|
88 |
+
"""Convert floatingpoint CDF to integers. See README for more info.
|
89 |
+
|
90 |
+
The idea is the following:
|
91 |
+
When we get the cdf here, it is (assumed to be) between 0 and 1, i.e,
|
92 |
+
cdf \in [0, 1)
|
93 |
+
(note that 1 should not be included.)
|
94 |
+
We now want to convert this to int16 but make sure we do not get
|
95 |
+
the same value twice, as this would break the arithmetic coder
|
96 |
+
(you need a strictly monotonically increasing function).
|
97 |
+
So, if needs_normalization==True, we multiply the input CDF
|
98 |
+
with 2**16 - (Lp - 1). This means that now,
|
99 |
+
cdf \in [0, 2**16 - (Lp - 1)].
|
100 |
+
Then, in a final step, we add an arange(Lp), which is just a line with
|
101 |
+
slope one. This ensure that for sure, we will get unique, strictly
|
102 |
+
monotonically increasing CDFs, which are \in [0, 2**16)
|
103 |
+
"""
|
104 |
+
Lp = cdf_float.shape[-1]
|
105 |
+
factor = 2**PRECISION
|
106 |
+
new_max_value = factor
|
107 |
+
if needs_normalization:
|
108 |
+
new_max_value = new_max_value - (Lp - 1)
|
109 |
+
cdf_float = cdf_float*(new_max_value)
|
110 |
+
cdf_float = np.round(cdf_float)
|
111 |
+
cdf = cdf_float.astype(np.int16)
|
112 |
+
if needs_normalization:
|
113 |
+
r = np.arange(Lp)
|
114 |
+
cdf+=r
|
115 |
+
return cdf
|
116 |
+
|
117 |
+
def pdf_convert_to_cdf_and_normalize(pdf):
|
118 |
+
assert pdf.ndim==2
|
119 |
+
pdf = pdf / (np.sum(pdf,1,keepdims=True))/(1+10**(-10))
|
120 |
+
cdfF = np.cumsum( pdf, axis=1)
|
121 |
+
cdfF = np.hstack((np.zeros((pdf.shape[0],1)),cdfF))
|
122 |
+
return cdfF
|
123 |
+
|
124 |
+
class arithmeticCoding():
|
125 |
+
def __init__(self) -> None:
|
126 |
+
self.binfile = None
|
127 |
+
self.sysNum = None
|
128 |
+
self.byte_stream = None
|
129 |
+
|
130 |
+
|
131 |
+
def encode(self,pdf,sym,binfile=None):
|
132 |
+
assert pdf.shape[0]==sym.shape[0]
|
133 |
+
assert pdf.ndim==2 and sym.ndim==1
|
134 |
+
|
135 |
+
self.sysNum = sym.shape[0]
|
136 |
+
|
137 |
+
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
138 |
+
|
139 |
+
# pdf = np.diff(cdfF)
|
140 |
+
# print( -np.log2(pdf[range(0,self.sysNum),sym]).sum())
|
141 |
+
|
142 |
+
self.byte_stream = _encode_float_cdf(cdfF, sym, check_input_bounds=True)
|
143 |
+
real_bits = len(self.byte_stream) * 8
|
144 |
+
# # Write to a file.
|
145 |
+
if binfile is not None:
|
146 |
+
with open(binfile, 'wb') as fout:
|
147 |
+
fout.write(self.byte_stream)
|
148 |
+
return self.byte_stream,real_bits
|
149 |
+
|
150 |
+
class arithmeticDeCoding():
|
151 |
+
"""
|
152 |
+
Decoding class
|
153 |
+
byte_stream: the bin file stream.
|
154 |
+
sysNum: the Number of symbols that you are going to decode. This value should be
|
155 |
+
saved in other ways.
|
156 |
+
sysDim: the Number of the possible symbols.
|
157 |
+
binfile: bin file path, if it is Not None, 'byte_stream' will read from this file
|
158 |
+
and copy to Cpp backend Class 'InCacheString'
|
159 |
+
"""
|
160 |
+
def __init__(self,byte_stream,sysNum,symDim,binfile=None) -> None:
|
161 |
+
if binfile is not None:
|
162 |
+
with open(binfile, 'rb') as fin:
|
163 |
+
byte_stream = fin.read()
|
164 |
+
self.byte_stream = byte_stream
|
165 |
+
self.decoder = numpyAc_backend.decode(self.byte_stream,sysNum,symDim+1)
|
166 |
+
|
167 |
+
def decode(self,pdf):
|
168 |
+
cdfF = pdf_convert_to_cdf_and_normalize(pdf)
|
169 |
+
pro = _convert_to_int_and_normalize(cdfF,needs_normalization=True)
|
170 |
+
pro = pro.squeeze(0).astype(np.uint16).tolist()
|
171 |
+
sym_out = self.decoder.decodeAsym(pro)
|
172 |
+
return sym_out
|
test.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpyAc
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
# Generate random symbols and pdf.
|
5 |
+
dim = 128
|
6 |
+
symsNum = 2000
|
7 |
+
pdf = np.random.rand(symsNum,dim)
|
8 |
+
pdf = pdf / (np.sum(pdf,1,keepdims=True))
|
9 |
+
sym = np.random.randint(0,dim,symsNum,dtype=np.int16)
|
10 |
+
output_pdf = pdf
|
11 |
+
|
12 |
+
# Encode to bytestream.
|
13 |
+
codec = numpyAc.arithmeticCoding()
|
14 |
+
byte_stream,real_bits = codec.encode(pdf, sym,'out.b')
|
15 |
+
|
16 |
+
# Number of bits taken by the stream.
|
17 |
+
print('real_bits',real_bits)
|
18 |
+
|
19 |
+
# Theoretical bits number
|
20 |
+
print('shannon entropy',-int(np.log2(pdf[range(0,symsNum),sym]).sum()))
|
21 |
+
|
22 |
+
# Decode from bytestream.
|
23 |
+
decodec = numpyAc.arithmeticDeCoding(None,symsNum,dim,'out.b')
|
24 |
+
|
25 |
+
# Autoregressive decoding and output will be equal to the input.
|
26 |
+
for i,s in enumerate(sym):
|
27 |
+
assert decodec.decode(output_pdf[i:i+1,:]) == s
|
testTorchac.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
LastEditors: fcy
|
3 |
+
'''
|
4 |
+
import torchac
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
# Encode to bytestream.
|
8 |
+
|
9 |
+
seed=6
|
10 |
+
torch.manual_seed(seed)
|
11 |
+
np.random.seed(seed)
|
12 |
+
|
13 |
+
dim = 500
|
14 |
+
symsNum = 40000
|
15 |
+
pdf = np.random.rand(symsNum,dim)
|
16 |
+
pdf = pdf / (np.sum(pdf,1,keepdims=True))
|
17 |
+
sym = torch.ShortTensor(np.random.randint(0,dim,symsNum,dtype=np.int16))
|
18 |
+
|
19 |
+
def pdf_convert_to_cdf_and_normalize(pdf):
|
20 |
+
assert pdf.ndim==2
|
21 |
+
pdf = pdf / (np.sum(pdf,1,keepdims=True))/(1+10**(-10))
|
22 |
+
cdfF = np.cumsum( pdf, axis=1)
|
23 |
+
cdfF = np.hstack((np.zeros((pdf.shape[0],1)),cdfF))
|
24 |
+
return cdfF
|
25 |
+
|
26 |
+
|
27 |
+
output_cdf = torch.Tensor(pdf_convert_to_cdf_and_normalize(pdf)) # Get CDF from your model, shape B, C, H, W, Lp
|
28 |
+
|
29 |
+
byte_stream = torchac.encode_float_cdf(output_cdf, sym, check_input_bounds=True)
|
30 |
+
|
31 |
+
|
32 |
+
# pdf = np.diff(cdfF)
|
33 |
+
# print( -np.log2(pdf[range(0,oct_len),sym]).sum())
|
34 |
+
|
35 |
+
# Number of bits taken by the stream
|
36 |
+
real_bits = len(byte_stream) * 8
|
37 |
+
print(real_bits)
|
38 |
+
# Write to a file.
|
39 |
+
with open('outfile.b', 'wb') as fout:
|
40 |
+
fout.write(byte_stream)
|
41 |
+
|
42 |
+
# Read from a file.
|
43 |
+
with open('outfile.b', 'rb') as fin:
|
44 |
+
byte_stream = fin.read()
|
45 |
+
|
46 |
+
# Decode from bytestream.
|
47 |
+
sym_out = torchac.decode_float_cdf(output_cdf, byte_stream)
|
48 |
+
|
49 |
+
# Output will be equal to the input.
|
50 |
+
assert sym_out.equal(sym)
|