tomaszki committed
Commit 42b3383 · 1 Parent(s): e2dbcf5

Add progress bar for encoding

Files changed (1)
  1. app.py +5 -0
app.py CHANGED
@@ -28,12 +28,16 @@ encode_col, decode_col = st.columns(2, gap='medium')
 
 @st.cache_data
 def encode(text):
+    bar = st.progress(0.0)
     codec = numpyAc.arithmeticCoding()
     tokenized = tokenizer(text, return_tensors='pt').input_ids.to('cuda')
     output = list()
     past_key_values = None
 
+    # We can't run a single pass over all tokens, because
+    # we get inconsistent results that way.
     for i in range(tokenized.shape[1]):
+        bar.progress((i + 1) / tokenized.shape[1])
         with torch.no_grad():
             output_ = model(
                 input_ids=tokenized[:, i:i + 1],
@@ -52,6 +56,7 @@ def encode(text):
 
 @st.cache_data
 def decode(byte_stream):
+    # Unfortunately, a progress bar for decoding isn't possible (or is hard).
     decodec = numpyAc.arithmeticDeCoding(byte_stream, 32_000)
     input_ids = [1]
     past_key_values = None
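
The added lines use Streamlit's standard progress-bar API: st.progress() draws the bar and returns a handle, and calling .progress() on that handle with a value between 0.0 and 1.0 updates the fill. Below is a minimal standalone sketch of the same pattern; the loop body is a placeholder standing in for the app's per-token model call, and n_steps stands in for tokenized.shape[1].

import time

import streamlit as st

n_steps = 50  # placeholder for tokenized.shape[1] in the app
bar = st.progress(0.0)
for i in range(n_steps):
    time.sleep(0.05)  # stand-in for the per-token forward pass
    bar.progress((i + 1) / n_steps)  # fraction of tokens processed so far
bar.empty()  # optionally remove the bar once the work is done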