Michael Hu committed
Commit 5f94a8b · 1 Parent(s): c10f1ac

fix lang_code_to_id error

Files changed (1):
  utils/translation.py  +14 -7
utils/translation.py CHANGED
@@ -4,10 +4,10 @@ Handles text segmentation and batch translation
 """
 
 import logging
-logger = logging.getLogger(__name__)
-
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
+logger = logging.getLogger(__name__)
+
 def translate_text(text):
     """
     Translate English text to Simplified Chinese
@@ -19,9 +19,12 @@ def translate_text(text):
     logger.info(f"Starting translation for text length: {len(text)}")
 
     try:
-        # Model initialization
+        # Model initialization with explicit language codes
        logger.info("Loading NLLB model")
-        tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-3.3B")
+        tokenizer = AutoTokenizer.from_pretrained(
+            "facebook/nllb-200-3.3B",
+            src_lang="eng_Latn"  # Specify source language
+        )
         model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-3.3B")
         logger.info("Translation model loaded")
 
@@ -33,18 +36,22 @@ def translate_text(text):
         translated_chunks = []
         for i, chunk in enumerate(text_chunks):
             logger.info(f"Processing chunk {i+1}/{len(text_chunks)}")
+
+            # Tokenize with source language specification
             inputs = tokenizer(
-                chunk,
-                return_tensors="pt",
-                max_length=1024,
+                chunk,
+                return_tensors="pt",
+                max_length=1024,
                 truncation=True
             )
 
+            # Generate translation with target language specification
             outputs = model.generate(
                 **inputs,
                 forced_bos_token_id=tokenizer.lang_code_to_id["zho_Hans"],
                 max_new_tokens=1024
             )
+
             translated = tokenizer.decode(outputs[0], skip_special_tokens=True)
             translated_chunks.append(translated)
             logger.info(f"Chunk {i+1} translated successfully")