File size: 2,221 Bytes
cd1309d
 
 
 
 
c72d839
cd1309d
 
5f94a8b
 
cd1309d
 
 
 
 
 
 
 
c72d839
cd1309d
c72d839
5f94a8b
c72d839
5f94a8b
 
 
 
c72d839
2477bc4
c72d839
 
 
 
 
 
 
 
2477bc4
5f94a8b
 
c72d839
5f94a8b
 
 
c72d839
 
 
5f94a8b
c72d839
 
77b7581
c72d839
 
5f94a8b
c72d839
 
2477bc4
c72d839
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
Text Translation Module using NLLB-3.3B model
Handles text segmentation and batch translation
"""

import logging
from functools import lru_cache

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

logger = logging.getLogger(__name__)

@lru_cache(maxsize=1)
def _load_translation_model():
    """Load and cache the NLLB-200-3.3B tokenizer/model pair.

    Loading the 3.3B-parameter checkpoint is very expensive; previously it
    happened on every translate_text() call. lru_cache guarantees at most
    one load per process.

    Returns:
        Tuple of (tokenizer, model) ready for English->Chinese translation.
    """
    logger.info("Loading NLLB model")
    tokenizer = AutoTokenizer.from_pretrained(
        "facebook/nllb-200-3.3B",
        src_lang="eng_Latn",  # source language for tokenization
    )
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B")
    logger.info("Translation model loaded")
    return tokenizer, model


def translate_text(text):
    """
    Translate English text to Simplified Chinese using NLLB-200-3.3B.

    Args:
        text: Input English text (empty input yields an empty result).

    Returns:
        Translated Simplified Chinese text, chunks concatenated in order.

    Raises:
        Exception: any failure from model loading, tokenization, or
            generation is logged with a traceback and re-raised.
    """
    logger.info("Starting translation for text length: %d", len(text))

    try:
        # Fetch the (cached) model; first call pays the load cost once.
        tokenizer, model = _load_translation_model()

        # Split into fixed-size character chunks so each piece stays within
        # the model's context window. NOTE(review): slicing by raw character
        # count can cut a word or sentence mid-way; a sentence-aware splitter
        # would translate better but would change output — left as-is.
        max_chunk_length = 1000
        text_chunks = [
            text[i:i + max_chunk_length]
            for i in range(0, len(text), max_chunk_length)
        ]
        logger.info("Split text into %d chunks", len(text_chunks))

        translated_chunks = []
        for i, chunk in enumerate(text_chunks):
            logger.info("Processing chunk %d/%d", i + 1, len(text_chunks))

            # Tokenize; max_length caps input tokens (longer chunks truncate).
            inputs = tokenizer(
                chunk,
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            )

            # forced_bos_token_id steers NLLB to emit Simplified Chinese.
            outputs = model.generate(
                **inputs,
                forced_bos_token_id=tokenizer.convert_tokens_to_ids("zho_Hans"),
                max_new_tokens=1024,
            )

            translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            translated_chunks.append(translated)
            logger.info("Chunk %d translated successfully", i + 1)

        result = "".join(translated_chunks)
        logger.info("Translation completed. Total length: %d", len(result))
        return result

    except Exception as e:
        logger.error("Translation failed: %s", e, exc_info=True)
        raise