Spaces:
Sleeping
Sleeping
/**
 * COPYRIGHT 2020 ETH Zurich
 * BASED on
 *
 * https://marknelson.us/posts/2014/10/19/data-compression-with-arithmetic-coding.html
 */
/// Integer type used for a single CDF entry.
using cdf_t = uint16_t;

/**
 * Bundles a pointer to a CDF table with its dimensions.
 *
 * The struct only stores the raw pointer; nothing in it allocates or frees
 * the memory behind `data`.
 */
struct cdf_ptr {
    /// Expected to be an N_sym x Lp matrix, stored in row-major order.
    cdf_t* data;
    /// Number of symbols (rows) stored by `data`.
    const int N_sym;
    /// == L+1, where L is the number of possible values a symbol can take.
    const int Lp;

    cdf_ptr(cdf_t* data, const int N_sym, const int Lp)
        : data{data}, N_sym{N_sym}, Lp{Lp} {}
};
/**
 * Collects output bit by bit and stores every completed byte in `out`.
 *
 * Bits are shifted into `cache` from the right, so the first bit appended
 * ends up as the most significant bit of the emitted byte.
 */
class OutCacheString {
public:
    std::string out = "";   ///< Bytes completed so far.
    uint8_t cache = 0;      ///< Bits gathered for the byte in progress.
    uint8_t count = 0;      ///< Number of bits currently held in `cache`.

    /** Push one bit; after the eighth bit the byte is appended to `out`. */
    void append(const int bit) {
        cache = static_cast<uint8_t>((cache << 1) | bit);
        ++count;
        if (count != 8) {
            return;
        }
        out.push_back(static_cast<char>(cache));
        count = 0;
    }

    /** Pad a partially filled byte with zero bits so that it gets written out. */
    void flush() {
        if (count == 0) {
            return;
        }
        const int padding = 8 - count;
        for (int k = 0; k < padding; ++k) {
            append(0);
        }
        assert(count == 0);
    }

    /**
     * Push `bit`, followed by `pending_bits` copies of its logical negation.
     * `pending_bits` is counted down to zero in the process.
     */
    void append_bit_and_pending(const int bit, uint64_t &pending_bits) {
        append(bit);
        const int opposite = !bit;
        for (; pending_bits > 0; --pending_bits) {
            append(opposite);
        }
    }
};
/**
 * Reads a byte string bit by bit, most significant bit of each byte first.
 *
 * The input string is copied on construction. Once all bytes are consumed,
 * further reads behave as if the input continued with zero bits.
 */
class InCacheString {
private:
    const std::string in_;

public:
    explicit InCacheString(const std::string& in) : in_(in) {}

    uint8_t cache = 0;        ///< Byte currently being consumed.
    uint8_t cached_bits = 0;  ///< Unread bits remaining in `cache`.
    size_t in_ptr = 0;        ///< Index of the next byte to load from the input.

    /** Shift `value` left by one and place the next input bit in its lowest position. */
    void get(uint32_t& value) {
        const bool need_refill = (cached_bits == 0);
        if (need_refill && in_ptr == in_.size()) {
            // Input exhausted: only shift, which feeds a zero bit.
            value <<= 1;
            return;
        }
        if (need_refill) {
            // Load the next byte.
            cache = static_cast<uint8_t>(in_[in_ptr]);
            ++in_ptr;
            cached_bits = 8;
        }
        --cached_bits;
        value = (value << 1) | ((cache >> cached_bits) & 1);
    }

    /** Feed 32 bits into `value` by calling `get` 32 times. */
    void initialize(uint32_t& value) {
        for (int remaining = 32; remaining > 0; --remaining) {
            get(value);
        }
    }
};
//------------------------------------------------------------------------------ | |
/**
 * Binary search over the entries `cdf[offset + 0 .. offset + max_sym]`.
 *
 * Returns the probed index whose entry equals `target` if one is hit;
 * otherwise the lower bound that remains once the search window has
 * collapsed to a single index.
 */
cdf_t binsearch(py::list &cdf, cdf_t target, cdf_t max_sym,
                const int offset)  /* i * Lp */
{
    cdf_t lo = 0;
    cdf_t hi = max_sym + 1;  // len(cdf) == max_sym + 2

    while (lo + 1 < hi) {
        // The sum is evaluated after integer promotion, so it cannot wrap
        // at 16 bits before the division.
        const auto mid = static_cast<cdf_t>((lo + hi) / 2);
        const auto entry = cdf[offset + mid].cast<cdf_t>();

        if (entry == target) {
            return mid;
        }
        if (entry < target) {
            lo = mid;
        } else {
            hi = mid;
        }
    }
    return lo;
}
class decode | |
{ | |
private: | |
public: | |
int dataID=0; | |
const int Lp;// To calculate offset | |
const int max_symbol; | |
uint32_t low = 0; | |
uint32_t high = 0xFFFFFFFFU; | |
const uint32_t c_count = 0x10000U; | |
const int precision = 16; | |
cdf_t sym_i = 0; | |
uint32_t value = 0; | |
InCacheString in_cache; | |
decode(const std::string &in, const int&sysNumDim_):in_cache(in),Lp(sysNumDim_),max_symbol(sysNumDim_-2){ | |
in_cache.initialize(value); | |
}; | |
int16_t decodeAsym(py::list cdf) { | |
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1; | |
// always < 0x10000 ??? | |
const uint16_t count = ((static_cast<uint64_t>(value) - static_cast<uint64_t>(low) + 1) * c_count - 1) / span; | |
int offset = 0; | |
sym_i = binsearch(cdf, count, (cdf_t)max_symbol, offset); | |
const uint32_t c_low = cdf[offset + sym_i].cast<cdf_t>(); | |
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1].cast<cdf_t>(); | |
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision); | |
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision); | |
while (true) { | |
if (low >= 0x80000000U || high < 0x80000000U) { | |
low <<= 1; | |
high <<= 1; | |
high |= 1; | |
in_cache.get(value); | |
} else if (low >= 0x40000000U && high < 0xC0000000U) { | |
/** | |
* 0100 0000 ... <= value < 1100 0000 ... | |
* <=> | |
* 0100 0000 ... <= value <= 1011 1111 ... | |
* <=> | |
* value starts with 01 or 10. | |
* 01 - 01 == 00 | 10 - 01 == 01 | |
* i.e., with shifts | |
* 01A -> 0A or 10A -> 1A, i.e., discard 2SB as it's all the same while we are in | |
* near convergence | |
*/ | |
low <<= 1; | |
low &= 0x7FFFFFFFU; // make MSB 0 | |
high <<= 1; | |
high |= 0x80000001U; // add 1 at the end, retain MSB = 1 | |
value -= 0x40000000U; | |
in_cache.get(value); | |
} else { | |
break; | |
} | |
} | |
return (int16_t)sym_i; | |
} | |
}; | |
/** Fail with a TORCH_CHECK error unless `sym` has exactly one dimension. */
const void check_sym(const torch::Tensor& sym) {
    const auto n_dims = sym.sizes().size();
    TORCH_CHECK(n_dims == 1,
                "Invalid size for sym. Expected just 1 dim.");
}
/** Get an instance of the `cdf_ptr` struct. */ | |
const struct cdf_ptr get_cdf_ptr(const torch::Tensor& cdf) | |
{ | |
TORCH_CHECK(!cdf.is_cuda(), "cdf must be on CPU!") | |
const auto s = cdf.sizes(); | |
TORCH_CHECK(s.size() == 2, "Invalid size for cdf! Expected (N, Lp)") | |
const int N_sym = s[0]; | |
const int Lp = s[1]; | |
const auto cdf_acc = cdf.accessor<int16_t, 2>(); | |
cdf_t* cdf_ptr = (uint16_t*)cdf_acc.data(); | |
const struct cdf_ptr res(cdf_ptr, N_sym, Lp); | |
return res; | |
} | |
// ----------------------------------------------------------------------------- | |
/** Encode symbols `sym` with CDF represented by `cdf_ptr`. NOTE: this is not exposted to python. */ | |
py::bytes encode( | |
const cdf_ptr& cdf_ptr, | |
const torch::Tensor& sym){ | |
OutCacheString out_cache; | |
uint32_t low = 0; | |
uint32_t high = 0xFFFFFFFFU; | |
uint64_t pending_bits = 0; | |
const int precision = 16; | |
const cdf_t* cdf = cdf_ptr.data; | |
const int N_sym = cdf_ptr.N_sym; | |
const int Lp = cdf_ptr.Lp; | |
const int max_symbol = Lp - 2; | |
auto sym_ = sym.accessor<int16_t, 1>(); | |
for (int i = 0; i < N_sym; ++i) { | |
const int16_t sym_i = sym_[i]; | |
const uint64_t span = static_cast<uint64_t>(high) - static_cast<uint64_t>(low) + 1; | |
const int offset = i * Lp; | |
// Left boundary is at offset + sym_i | |
const uint32_t c_low = cdf[offset + sym_i]; | |
// Right boundary is at offset + sym_i + 1, except for the `max_symbol` | |
// For which we hardcode the maxvalue. So if e.g. | |
// L == 4, it means that Lp == 5, and the allowed symbols are | |
// {0, 1, 2, 3}. The max symbol is thus Lp - 2 == 3. It's probability | |
// is then given by c_max - cdf[-2]. | |
const uint32_t c_high = sym_i == max_symbol ? 0x10000U : cdf[offset + sym_i + 1]; | |
high = (low - 1) + ((span * static_cast<uint64_t>(c_high)) >> precision); | |
low = (low) + ((span * static_cast<uint64_t>(c_low)) >> precision); | |
while (true) { | |
if (high < 0x80000000U) { | |
out_cache.append_bit_and_pending(0, pending_bits); | |
low <<= 1; | |
high <<= 1; | |
high |= 1; | |
} else if (low >= 0x80000000U) { | |
out_cache.append_bit_and_pending(1, pending_bits); | |
low <<= 1; | |
high <<= 1; | |
high |= 1; | |
} else if (low >= 0x40000000U && high < 0xC0000000U) { | |
pending_bits++; | |
low <<= 1; | |
low &= 0x7FFFFFFF; | |
high <<= 1; | |
high |= 0x80000001; | |
} else { | |
break; | |
} | |
} | |
} | |
pending_bits += 1; | |
if (pending_bits) { | |
if (low < 0x40000000U) { | |
out_cache.append_bit_and_pending(0, pending_bits); | |
} else { | |
out_cache.append_bit_and_pending(1, pending_bits); | |
} | |
} | |
out_cache.flush(); | |
std::chrono::steady_clock::time_point end= std::chrono::steady_clock::now(); | |
std::cout << "Time difference (sec) = " << (std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()) /1000000.0 <<std::endl; | |
return py::bytes(out_cache.out); | |
} | |
/** See torchac.py */ | |
py::bytes encode_cdf( | |
const torch::Tensor& cdf, /* NHWLp, must be on CPU! */ | |
const torch::Tensor& sym) | |
{ | |
check_sym(sym); | |
const auto cdf_ptr = get_cdf_ptr(cdf); | |
return encode(cdf_ptr, sym); | |
} | |
// Python bindings: the `encode_cdf` function and the `decode` class with its
// constructor and `decodeAsym` method.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("encode_cdf", &encode_cdf, "Encode from CDF");

    py::class_<decode> decoder(m, "decode");
    decoder.def(py::init<const std::string&, const int&>());
    decoder.def("decodeAsym", &decode::decodeAsym);
}