from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch
def nllb():
    """
    Load and return the NLLB (No Language Left Behind) model and tokenizer.

    This function loads the NLLB-200-distilled-1.3B model and tokenizer from
    Hugging Face's Transformers library. The model is loaded on CPU; swap in
    the commented-out device line below to use a GPU when one is available.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
            - model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
            - tokenizer (transformers.AutoTokenizer): The loaded tokenizer.

    Example usage:
        model, tokenizer = nllb()
    """
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")

    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)

    # Signal that loading finished by writing "done" to status.txt
    with open("status.txt", "w") as f:
        f.write("done")

    return model, tokenizer
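

# Hedged variant (my addition, not part of the original Space): the NLLB
# tokenizer defaults src_lang to "eng_Latn", so text translated *from*
# Spanish would be mislabeled on the source side. src_lang is a real
# NllbTokenizer argument; this helper's name and signature are mine.
def nllb_with_src(src_lang="eng_Latn"):
    """Load NLLB on CPU with an explicit source-language code, e.g. "spa_Latn"."""
    device = torch.device("cpu")
    tokenizer = AutoTokenizer.from_pretrained(
        "facebook/nllb-200-distilled-1.3B", src_lang=src_lang
    )
    model = AutoModelForSeq2SeqLM.from_pretrained(
        "facebook/nllb-200-distilled-1.3B"
    ).to(device)
    return model, tokenizer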
def nllb_translate(model, tokenizer, article, language):
    """
    Translate an article using the NLLB model and tokenizer.

    Args:
        model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for translation.
            Example: model, tokenizer = nllb()
        tokenizer (transformers.AutoTokenizer): The tokenizer to use with the NLLB model.
            Example: model, tokenizer = nllb()
        article (str): The article text to be translated.
            Example: "This is a sample article."
        language (str): The target language for translation. Must be either 'es' or 'en'.
            Example: "es"

    Returns:
        str: The translated text, or "Translation failed" on error.
            Example: "Este es un artículo de muestra."
    """
    # Map the supported language arguments to NLLB's language-code tokens
    lang_codes = {"es": "spa_Latn", "en": "eng_Latn"}

    try:
        if language not in lang_codes:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # Tokenize the text
        inputs = tokenizer(article, return_tensors="pt")

        # Move the tokenized inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        # Force the decoder to start with the target-language token.
        # convert_tokens_to_ids is used here because tokenizer.lang_code_to_id
        # was removed from the NLLB tokenizer in recent transformers releases.
        translated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(lang_codes[language]),
            max_length=30,
        )

        # Decode the translation
        text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return text
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation failed"