Update app.py
Browse files
app.py
CHANGED
@@ -89,13 +89,19 @@ def translate_text(text, source_lang, target_lang):
|
|
89 |
src_code = LANGUAGE_CODES.get(source_lang, "eng_Latn")
|
90 |
tgt_code = LANGUAGE_CODES.get(target_lang, "ara_Arab")
|
91 |
|
|
|
|
|
|
|
92 |
# Tokenize
|
93 |
inputs = translator_tokenizer(text, return_tensors="pt", padding=True)
|
94 |
|
95 |
-
#
|
|
|
|
|
|
|
96 |
translated_tokens = translator_model.generate(
|
97 |
**inputs,
|
98 |
-
forced_bos_token_id=
|
99 |
max_length=128
|
100 |
)
|
101 |
|
|
|
89 |
src_code = LANGUAGE_CODES.get(source_lang, "eng_Latn")
|
90 |
tgt_code = LANGUAGE_CODES.get(target_lang, "ara_Arab")
|
91 |
|
92 |
+
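# Wrap the target language code in double underscores (fairseq-style "__xxx__" token). NOTE(review): Hugging Face NLLB tokenizers appear to register language codes without underscores (e.g. "ara_Arab") - confirm this token exists in the vocabulary, otherwise the ID lookup below yields the unknown-token id.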
|
93 |
+
tgt_token = f"__{tgt_code}__"
|
94 |
+
|
95 |
# Tokenize
|
96 |
inputs = translator_tokenizer(text, return_tensors="pt", padding=True)
|
97 |
|
98 |
+
# Get the token ID for the target language
|
99 |
+
forced_bos_token_id = translator_tokenizer.convert_tokens_to_ids(tgt_token)
|
100 |
+
|
101 |
+
# Generate translation with the target language token
|
102 |
translated_tokens = translator_model.generate(
|
103 |
**inputs,
|
104 |
+
forced_bos_token_id=forced_bos_token_id,
|
105 |
max_length=128
|
106 |
)
|
107 |
|