File size: 2,862 Bytes
aa7cb02
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import torch

def nllb():
    """
    Load and return the NLLB (No Language Left Behind) model and tokenizer.

    This function loads the NLLB-200-distilled-1.3B model and tokenizer from Hugging Face's Transformers library.
    The model is configured to use a GPU if available, otherwise it defaults to CPU.

    Returns:
        tuple: A tuple containing the loaded model and tokenizer.
            - model (transformers.AutoModelForSeq2SeqLM): The loaded NLLB model.
            - tokenizer (transformers.AutoTokenizer): The loaded tokenizer.
            
    Example usage:
        model, tokenizer = nllb()
    """
    #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    # Load the tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-1.3B")
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-1.3B").to(device)
    # write done to the file named status.txt
    with open("status.txt", 'w') as f:
        f.write("done")
    return model, tokenizer

def nllb_translate(model, tokenizer, article, language):
    """
    Translate an article using the NLLB model and tokenizer.

    Args:
        model (transformers.AutoModelForSeq2SeqLM): The NLLB model to use for translation.
            Example: model, tokenizer = nllb()
        tokenizer (transformers.AutoTokenizer): The tokenizer to use with the NLLB model.
            Example: model, tokenizer = nllb()
        article (str): The article text to be translated.
            Example: "This is a sample article."
        language (str): The target language for translation. Must be either 'spanish' or 'english'.
            Example: "spanish"

    Returns:
        str: The translated text.
            Example: "Este es un artículo de muestra."
    """
    try:
        # Tokenize the text
        inputs = tokenizer(article, return_tensors="pt")

        # Move the tokenized inputs to the same device as the model
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

        if language == "es":
            translated_tokens = model.generate(
                **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["spa_Latn"], max_length=30
            )
        elif language == "en":
            translated_tokens = model.generate(
                **inputs, forced_bos_token_id=tokenizer.lang_code_to_id["eng_Latn"], max_length=30
            )
        else:
            raise ValueError("Unsupported language. Use 'es' or 'en'.")

        # Decode the translation
        text = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
        return text
    
    except Exception as e:
        print(f"Error during translation: {e}")
        return "Translation failed"