pentarosarium committed on
Commit
2e78fc6
·
1 Parent(s): fa80eae

progress %

Browse files
Files changed (1) hide show
  1. app.py +26 -13
app.py CHANGED
@@ -41,19 +41,32 @@ def translate(text):
41
  # Tokenize the input text
42
  inputs = translation_tokenizer(text, return_tensors="pt", truncation=True)
43
 
44
- # Set up a simple spinner
45
- with tqdm(total=0, bar_format='{desc}', desc="Translating...") as pbar:
46
- # Generate translation
47
- translated_tokens = translation_model.generate(
48
- **inputs,
49
- num_beams=5,
50
- max_length=len(text.split()) * 2, # Adjust as needed
51
- no_repeat_ngram_size=2,
52
- early_stopping=True
53
- )
54
-
55
- # Update the spinner description to show completion
56
- pbar.set_description_str("Translation completed")
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Decode the translated tokens
59
  translated_text = translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
 
41
  # Tokenize the input text
42
  inputs = translation_tokenizer(text, return_tensors="pt", truncation=True)
43
 
44
+ # Calculate max_length based on input length (you may need to adjust this ratio)
45
+ max_length = min(512, int(inputs.input_ids.shape[1] * 1.5))
46
+
47
+ # Calculate max_new_tokens
48
+ max_new_tokens = max_length - inputs.input_ids.shape[1]
49
+
50
+ # Set up the progress bar
51
+ pbar = tqdm(total=max_new_tokens, desc="Translating", unit="token")
52
+
53
+ # Custom callback to update the progress bar
54
+ def update_progress_bar(beam_idx, token_idx, token):
55
+ pbar.update(1)
56
+
57
+ # Generate translation with progress updates
58
+ translated_tokens = translation_model.generate(
59
+ **inputs,
60
+ max_length=max_length,
61
+ num_beams=5,
62
+ no_repeat_ngram_size=2,
63
+ early_stopping=True,
64
+ callback=update_progress_bar,
65
+ callback_steps=1
66
+ )
67
+
68
+ # Close the progress bar
69
+ pbar.close()
70
 
71
  # Decode the translated tokens
72
  translated_text = translation_tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]