zb12138 commited on
Commit
a965cad
·
1 Parent(s): e389f7b
Files changed (1) hide show
  1. testTorchac.py +0 -50
testTorchac.py DELETED
@@ -1,50 +0,0 @@
1
- '''
2
- LastEditors: fcy
3
- '''
4
- import torchac
5
- import torch
6
- import numpy as np
7
- # Encode to bytestream.
8
-
9
- seed=6
10
- torch.manual_seed(seed)
11
- np.random.seed(seed)
12
-
13
- dim = 500
14
- symsNum = 40000
15
- pdf = np.random.rand(symsNum,dim)
16
- pdf = pdf / (np.sum(pdf,1,keepdims=True))
17
- sym = torch.ShortTensor(np.random.randint(0,dim,symsNum,dtype=np.int16))
18
-
19
- def pdf_convert_to_cdf_and_normalize(pdf):
20
- assert pdf.ndim==2
21
- pdf = pdf / (np.sum(pdf,1,keepdims=True))/(1+10**(-10))
22
- cdfF = np.cumsum( pdf, axis=1)
23
- cdfF = np.hstack((np.zeros((pdf.shape[0],1)),cdfF))
24
- return cdfF
25
-
26
-
27
- output_cdf = torch.Tensor(pdf_convert_to_cdf_and_normalize(pdf)) # Get CDF from your model, shape B, C, H, W, Lp
28
-
29
- byte_stream = torchac.encode_float_cdf(output_cdf, sym, check_input_bounds=True)
30
-
31
-
32
- # pdf = np.diff(cdfF)
33
- # print( -np.log2(pdf[range(0,oct_len),sym]).sum())
34
-
35
- # Number of bits taken by the stream
36
- real_bits = len(byte_stream) * 8
37
- print(real_bits)
38
- # Write to a file.
39
- with open('outfile.b', 'wb') as fout:
40
- fout.write(byte_stream)
41
-
42
- # Read from a file.
43
- with open('outfile.b', 'rb') as fin:
44
- byte_stream = fin.read()
45
-
46
- # Decode from bytestream.
47
- sym_out = torchac.decode_float_cdf(output_cdf, byte_stream)
48
-
49
- # Output will be equal to the input.
50
- assert sym_out.equal(sym)