# NOTE(review): the original lines here ("Spaces:" / "Runtime error" x2) were
# Hugging Face Spaces status-banner text captured along with the page, not part
# of app.py — converted to comments so the file remains valid Python.
# app.py | |
import torch | |
from transformers import AutoTokenizer, EncoderDecoderModel | |
import gradio as gr | |
from spaces import GPU | |
# デバイス設定 (Spacesのハードウェア設定に依存) | |
# SpacesでGPUを利用する場合、自動的にCUDAが利用可能になります | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
print(f"Using device: {device}") # デバイス確認用ログ | |
model_name = "Shuu12121/CodeEncoderDecoderModel-Ghost-large" | |
print(f"Loading model: {model_name}") # モデル読み込み開始ログ | |
# --- Tokenizerの読み込み --- | |
try: | |
# subfolder引数を使用してサブディレクトリを指定 | |
encoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="encoder_tokenizer") | |
decoder_tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="decoder_tokenizer") | |
print("Tokenizers loaded successfully.") | |
except Exception as e: | |
print(f"Error loading tokenizers: {e}") | |
raise # ここではエラーを再発生させて、起動を停止させます | |
# decoder_tokenizerのpad_token設定 | |
if decoder_tokenizer.pad_token is None: | |
if decoder_tokenizer.eos_token is not None: | |
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token | |
print("Set decoder pad_token to eos_token.") | |
else: | |
# eos_tokenもない場合の代替処理(例: '<pad>'トークンを追加) | |
decoder_tokenizer.add_special_tokens({'pad_token': '<pad>'}) | |
print("Added '<pad>' as pad_token.") | |
# モデルのリサイズが必要になる場合がある | |
# model.resize_token_embeddings(len(decoder_tokenizer)) # 必要に応じて | |
# --- モデルの読み込み --- | |
try: | |
# モデルの読み込みは通常通りリポジトリ名を指定すればOK | |
# config.jsonが適切に設定されていれば、エンコーダー/デコーダー部分は自動的に読み込まれる | |
model = EncoderDecoderModel.from_pretrained(model_name).to(device) | |
model.eval() # 評価モードに設定 | |
print("Model loaded successfully and moved to device.") | |
except Exception as e: | |
print(f"Error loading model: {e}") | |
raise | |
# --- Docstring生成関数 --- | |
def generate_docstring(code: str) -> str: | |
print("Received code snippet for docstring generation.") # 関数呼び出しログ | |
if not code: | |
return "Please provide a code snippet." | |
try: | |
# エンコーダー入力の準備 | |
inputs = encoder_tokenizer( | |
code, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=2048 # モデルが許容する最大長に合わせる(必要なら調整) | |
).to(device) | |
print(f"Input tokens length: {inputs.input_ids.shape[1]}") | |
# 生成実行 | |
with torch.no_grad(): | |
# pad_token_idを明示的に指定 (重要: Noneでないことを確認) | |
pad_token_id = decoder_tokenizer.pad_token_id if decoder_tokenizer.pad_token_id is not None else decoder_tokenizer.eos_token_id | |
output_ids = model.generate( | |
input_ids=inputs.input_ids, | |
attention_mask=inputs.attention_mask, | |
max_length=256, | |
num_beams=10, | |
early_stopping=True, | |
eos_token_id=decoder_tokenizer.eos_token_id, | |
pad_token_id=pad_token_id, | |
no_repeat_ngram_size=3, | |
bad_words_ids=decoder_tokenizer(["sexual", "abuse", "child"], add_special_tokens=False).input_ids | |
) | |
print(f"Generated output tokens length: {output_ids.shape[1]}") | |
# デコードしてテキストに変換 | |
generated_docstring = decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True) | |
print("Docstring generated successfully.") | |
return generated_docstring | |
except Exception as e: | |
print(f"Error during generation: {e}") | |
# ユーザーにエラーを通知 | |
return f"An error occurred during generation: {e}" | |
# --- Gradio UI --- | |
iface = gr.Interface( | |
fn=generate_docstring, | |
inputs=gr.Textbox( | |
label="Code Snippet", | |
lines=10, | |
placeholder="Paste your Python function or code block here...", | |
value="""public static String readFileToString(File file, Charset encoding) throws IOException { | |
try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(file), encoding))) { | |
StringBuilder sb = new StringBuilder(); | |
String line; | |
while ((line = reader.readLine()) != null) { | |
sb.append(line).append("\\n"); | |
} | |
return sb.toString(); | |
} | |
}""" | |
), | |
outputs=gr.Textbox(label="Generated Docstring"), | |
title="Code-to-Docstring Generator (Shuu12121/CodeEncoderDecoderModel-Ghost)", | |
description="This demo uses the Shuu12121/CodeEncoderDecoderModel-Ghost model to automatically generate Python docstrings from code snippets. Paste your code below and click 'Submit'." | |
) | |
# --- アプリケーションの起動 --- | |
# Hugging Face Spacesで実行する場合、share=Trueは不要 | |
if __name__ == "__main__": | |
print("Launching Gradio interface...") | |
iface.launch() | |
print("Gradio interface launched.") |