# app.py
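"""Gradio Space app: generate docstrings from code snippets.

Loads Shuu12121/CodeEncoderDecoderModel-Ghost-large (an EncoderDecoderModel with
separate encoder/decoder tokenizers) and serves it through a simple Gradio
text-in/text-out interface.
"""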
import torch
from transformers import AutoTokenizer, EncoderDecoderModel
import gradio as gr
from spaces import GPU

# Device setup (depends on the Space's hardware configuration).
# When the Space is assigned GPU hardware, CUDA becomes available automatically.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")  # log which device was selected

model_name = "Shuu12121/CodeEncoderDecoderModel-Ghost-large"
print(f"Loading model: {model_name}") # モデル読み込み開始ログ

# --- Load the tokenizers ---
try:
    # The subfolder argument selects the subdirectory inside the model repo
    encoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="encoder_tokenizer")
    decoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="decoder_tokenizer")
    print("Tokenizers loaded successfully.")
except Exception as e:
    print(f"Error loading tokenizers: {e}")
    raise  # re-raise so the app stops at startup instead of running in a broken state

# Make sure the decoder tokenizer has a pad_token
if decoder_tokenizer.pad_token is None:
    if decoder_tokenizer.eos_token is not None:
        decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
        print("Set decoder pad_token to eos_token.")
    else:
        # Fallback when there is no eos_token either: add a '<pad>' token
        decoder_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        print("Added '<pad>' as pad_token.")
        # Adding a token can grow the vocabulary, in which case the model's
        # embeddings need resizing; see the check after the model is loaded.

# --- Load the model ---
try:
    # Loading by repository name is sufficient: if config.json is set up
    # correctly, the encoder and decoder weights are loaded automatically.
    model = EncoderDecoderModel.from_pretrained(model_name).to(device)
    model.eval()  # inference only, so switch to evaluation mode
    print("Model loaded successfully and moved to device.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise
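
# The pad_token fallback above may have grown the decoder vocabulary. This is a
# minimal sketch of the resize step the comment there alludes to, assuming the
# decoder exposes the standard embedding APIs; it is a no-op when no token was added.
if len(decoder_tokenizer) > model.decoder.get_input_embeddings().num_embeddings:
    model.decoder.resize_token_embeddings(len(decoder_tokenizer))
    print("Resized decoder token embeddings to match the tokenizer.")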

# --- Docstring generation ---
@GPU
def generate_docstring(code: str) -> str:
    print("Received code snippet for docstring generation.")  # log each call
    if not code:
        return "Please provide a code snippet."

    try:
        # Prepare the encoder input
        inputs = encoder_tokenizer(
            code,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=2048  # match the model's maximum input length (adjust if needed)
        ).to(device)

        print(f"Input tokens length: {inputs.input_ids.shape[1]}")

        # Run generation
        with torch.no_grad():
            # Pass pad_token_id explicitly (important: make sure it is not None)
            pad_token_id = decoder_tokenizer.pad_token_id if decoder_tokenizer.pad_token_id is not None else decoder_tokenizer.eos_token_id

            output_ids = model.generate(
                input_ids=inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_length=256,
                num_beams=10,
                early_stopping=True,
                eos_token_id=decoder_tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
                no_repeat_ngram_size=3,
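                # bad_words_ids expects a list of token-id lists; tokenizing the
                # blocked words without special tokens yields exactly that shape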
                bad_words_ids=decoder_tokenizer(["sexual", "abuse", "child"], add_special_tokens=False).input_ids
            )

        print(f"Generated output tokens length: {output_ids.shape[1]}")

        # Decode the generated ids back to text
        generated_docstring = decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        print("Docstring generated successfully.")
        return generated_docstring

    except Exception as e:
        print(f"Error during generation: {e}")
        # Surface the error to the user instead of failing silently
        return f"An error occurred during generation: {e}"

# --- Gradio UI ---
iface = gr.Interface(
    fn=generate_docstring,
    inputs=gr.Textbox(
        label="Code Snippet",
        lines=10,
        placeholder="Paste your Python function or code block here...",
        value="""public static String readFileToString(File file, Charset encoding) throws IOException {
    try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file), encoding))) {
        StringBuilder sb = new StringBuilder();
        String line;
        while ((line = reader.readLine()) != null) {
            sb.append(line).append("\\n");
        }
        return sb.toString();
    }
}"""
    ),
    outputs=gr.Textbox(label="Generated Docstring"),
    title="Code-to-Docstring Generator (Shuu12121/CodeEncoderDecoderModel-Ghost)",
    description="This demo uses the Shuu12121/CodeEncoderDecoderModel-Ghost model to automatically generate Python docstrings from code snippets. Paste your code below and click 'Submit'."
)
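
# A request queue is often enabled on Spaces so concurrent users do not hit the
# GPU at the same time; an optional, minimal addition using Gradio's built-in queue:
# iface.queue()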


# --- Launch the application ---
# share=True is unnecessary when running on Hugging Face Spaces
if __name__ == "__main__":
    print("Launching Gradio interface...")
    iface.launch()  # blocks here; anything after this line runs only on shutdown