Shuu12121 commited on
Commit
c50ba8d
·
verified ·
1 Parent(s): 6897705

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -13,12 +13,12 @@ print(f"Loading model: {model_name}") # モデル読み込み開始ログ
13
 
14
  # --- Tokenizerの読み込み ---
15
  try:
16
- encoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/encoder_tokenizer")
17
- decoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/decoder_tokenizer")
 
18
  print("Tokenizers loaded successfully.")
19
  except Exception as e:
20
  print(f"Error loading tokenizers: {e}")
21
- # エラーが発生した場合、Gradioインターフェースでエラーを表示するなどの処理を追加できます
22
  raise # ここではエラーを再発生させて、起動を停止させます
23
 
24
  # decoder_tokenizerのpad_token設定
@@ -35,6 +35,8 @@ if decoder_tokenizer.pad_token is None:
35
 
36
  # --- モデルの読み込み ---
37
  try:
 
 
38
  model = EncoderDecoderModel.from_pretrained(model_name).to(device)
39
  model.eval() # 評価モードに設定
40
  print("Model loaded successfully and moved to device.")
@@ -62,16 +64,18 @@ def generate_docstring(code: str) -> str:
62
 
63
  # 生成実行
64
  with torch.no_grad():
 
 
 
65
  output_ids = model.generate(
66
  input_ids=inputs.input_ids,
67
  attention_mask=inputs.attention_mask,
68
  max_length=256, # 生成するDocstringの最大長
69
  num_beams=5, # ビームサーチのビーム数
70
  early_stopping=True, # 早く停止させるか
71
- # decoder_start_token_idは通常model.configから自動設定されるが、明示的に指定も可能
72
- # decoder_start_token_id=model.config.decoder_start_token_id,
73
  eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
74
- pad_token_id=decoder_tokenizer.pad_token_id, # PADトークンID
75
  no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
76
  )
77
 
 
13
 
14
  # --- Tokenizerの読み込み ---
15
  try:
16
+ # subfolder引数を使用してサブディレクトリを指定
17
+ encoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="encoder_tokenizer")
18
+ decoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="decoder_tokenizer")
19
  print("Tokenizers loaded successfully.")
20
  except Exception as e:
21
  print(f"Error loading tokenizers: {e}")
 
22
  raise # ここではエラーを再発生させて、起動を停止させます
23
 
24
  # decoder_tokenizerのpad_token設定
 
35
 
36
  # --- モデルの読み込み ---
37
  try:
38
+ # モデルの読み込みは通常通りリポジトリ名を指定すればOK
39
+ # config.jsonが適切に設定されていれば、エンコーダー/デコーダー部分は自動的に読み込まれる
40
  model = EncoderDecoderModel.from_pretrained(model_name).to(device)
41
  model.eval() # 評価モードに設定
42
  print("Model loaded successfully and moved to device.")
 
64
 
65
  # 生成実行
66
  with torch.no_grad():
67
+ # pad_token_idを明示的に指定 (重要: Noneでないことを確認)
68
+ pad_token_id = decoder_tokenizer.pad_token_id if decoder_tokenizer.pad_token_id is not None else decoder_tokenizer.eos_token_id
69
+
70
  output_ids = model.generate(
71
  input_ids=inputs.input_ids,
72
  attention_mask=inputs.attention_mask,
73
  max_length=256, # 生成するDocstringの最大長
74
  num_beams=5, # ビームサーチのビーム数
75
  early_stopping=True, # 早く停止させるか
76
+ # decoder_start_token_idは通常model.configから自動設定される
 
77
  eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
78
+ pad_token_id=pad_token_id, # PADトークンID (Noneでないことを保証)
79
  no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
80
  )
81