Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -13,12 +13,12 @@ print(f"Loading model: {model_name}") # モデル読み込み開始ログ
|
|
13 |
|
14 |
# --- Tokenizerの読み込み ---
|
15 |
try:
|
16 |
-
|
17 |
-
|
|
|
18 |
print("Tokenizers loaded successfully.")
|
19 |
except Exception as e:
|
20 |
print(f"Error loading tokenizers: {e}")
|
21 |
-
# エラーが発生した場合、Gradioインターフェースでエラーを表示するなどの処理を追加できます
|
22 |
raise # ここではエラーを再発生させて、起動を停止させます
|
23 |
|
24 |
# decoder_tokenizerのpad_token設定
|
@@ -35,6 +35,8 @@ if decoder_tokenizer.pad_token is None:
|
|
35 |
|
36 |
# --- モデルの読み込み ---
|
37 |
try:
|
|
|
|
|
38 |
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
|
39 |
model.eval() # 評価モードに設定
|
40 |
print("Model loaded successfully and moved to device.")
|
@@ -62,16 +64,18 @@ def generate_docstring(code: str) -> str:
|
|
62 |
|
63 |
# 生成実行
|
64 |
with torch.no_grad():
|
|
|
|
|
|
|
65 |
output_ids = model.generate(
|
66 |
input_ids=inputs.input_ids,
|
67 |
attention_mask=inputs.attention_mask,
|
68 |
max_length=256, # 生成するDocstringの最大長
|
69 |
num_beams=5, # ビームサーチのビーム数
|
70 |
early_stopping=True, # 早く停止させるか
|
71 |
-
# decoder_start_token_idは通常model.config
|
72 |
-
# decoder_start_token_id=model.config.decoder_start_token_id,
|
73 |
eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
|
74 |
-
pad_token_id=
|
75 |
no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
|
76 |
)
|
77 |
|
|
|
13 |
|
14 |
# --- Tokenizerの読み込み ---
|
15 |
try:
|
16 |
+
# subfolder引数を使用してサブディレクトリを指定
|
17 |
+
encoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="encoder_tokenizer")
|
18 |
+
decoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="decoder_tokenizer")
|
19 |
print("Tokenizers loaded successfully.")
|
20 |
except Exception as e:
|
21 |
print(f"Error loading tokenizers: {e}")
|
|
|
22 |
raise # ここではエラーを再発生させて、起動を停止させます
|
23 |
|
24 |
# decoder_tokenizerのpad_token設定
|
|
|
35 |
|
36 |
# --- モデルの読み込み ---
|
37 |
try:
|
38 |
+
# モデルの読み込みは通常通りリポジトリ名を指定すればOK
|
39 |
+
# config.jsonが適切に設定されていれば、エンコーダー/デコーダー部分は自動的に読み込まれる
|
40 |
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
|
41 |
model.eval() # 評価モードに設定
|
42 |
print("Model loaded successfully and moved to device.")
|
|
|
64 |
|
65 |
# 生成実行
|
66 |
with torch.no_grad():
|
67 |
+
# pad_token_idを明示的に指定 (重要: Noneでないことを確認)
|
68 |
+
pad_token_id = decoder_tokenizer.pad_token_id if decoder_tokenizer.pad_token_id is not None else decoder_tokenizer.eos_token_id
|
69 |
+
|
70 |
output_ids = model.generate(
|
71 |
input_ids=inputs.input_ids,
|
72 |
attention_mask=inputs.attention_mask,
|
73 |
max_length=256, # 生成するDocstringの最大長
|
74 |
num_beams=5, # ビームサーチのビーム数
|
75 |
early_stopping=True, # 早く停止させるか
|
76 |
+
# decoder_start_token_idは通常model.configから自動設定される
|
|
|
77 |
eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
|
78 |
+
pad_token_id=pad_token_id, # PADトークンID (Noneでないことを保証)
|
79 |
no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
|
80 |
)
|
81 |
|