Spaces:
Running
Running
File size: 4,695 Bytes
6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 439fcd4 6897705 c7a8772 6897705 439fcd4 6897705 439fcd4 6897705 c7a8772 6897705 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
# app.py
import torch
from transformers import AutoTokenizer, EncoderDecoderModel
import gradio as gr
# デバイス設定 (Spacesのハードウェア設定に依存)
# SpacesでGPUを利用する場合、自動的にCUDAが利用可能になります
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}") # デバイス確認用ログ
model_name = "Shuu12121/CodeEncoderDecodeerModel-Ghost"
print(f"Loading model: {model_name}") # モデル読み込み開始ログ
# --- Tokenizerの読み込み ---
try:
encoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/encoder_tokenizer")
decoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/decoder_tokenizer")
print("Tokenizers loaded successfully.")
except Exception as e:
print(f"Error loading tokenizers: {e}")
# エラーが発生した場合、Gradioインターフェースでエラーを表示するなどの処理を追加できます
raise # ここではエラーを再発生させて、起動を停止させます
# decoder_tokenizerのpad_token設定
if decoder_tokenizer.pad_token is None:
if decoder_tokenizer.eos_token is not None:
decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
print("Set decoder pad_token to eos_token.")
else:
# eos_tokenもない場合の代替処理(例: '<pad>'トークンを追加)
decoder_tokenizer.add_special_tokens({'pad_token': '<pad>'})
print("Added '<pad>' as pad_token.")
# モデルのリサイズが必要になる場合がある
# model.resize_token_embeddings(len(decoder_tokenizer)) # 必要に応じて
# --- モデルの読み込み ---
try:
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
model.eval() # 評価モードに設定
print("Model loaded successfully and moved to device.")
except Exception as e:
print(f"Error loading model: {e}")
raise
# --- Docstring生成関数 ---
def generate_docstring(code: str) -> str:
print("Received code snippet for docstring generation.") # 関数呼び出しログ
if not code:
return "Please provide a code snippet."
try:
# エンコーダー入力の準備
inputs = encoder_tokenizer(
code,
return_tensors="pt",
padding=True,
truncation=True,
max_length=2048 # モデルが許容する最大長に合わせる(必要なら調整)
).to(device)
print(f"Input tokens length: {inputs.input_ids.shape[1]}")
# 生成実行
with torch.no_grad():
output_ids = model.generate(
input_ids=inputs.input_ids,
attention_mask=inputs.attention_mask,
max_length=256, # 生成するDocstringの最大長
num_beams=5, # ビームサーチのビーム数
early_stopping=True, # 早く停止させるか
# decoder_start_token_idは通常model.configから自動設定されるが、明示的に指定も可能
# decoder_start_token_id=model.config.decoder_start_token_id,
eos_token_id=decoder_tokenizer.eos_token_id, # EOSトークンID
pad_token_id=decoder_tokenizer.pad_token_id, # PADトークンID
no_repeat_ngram_size=2 # 繰り返さないN-gramサイズ
)
print(f"Generated output tokens length: {output_ids.shape[1]}")
# デコードしてテキストに変換
generated_docstring = decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
print("Docstring generated successfully.")
return generated_docstring
except Exception as e:
print(f"Error during generation: {e}")
# ユーザーにエラーを通知
return f"An error occurred during generation: {e}"
# --- Gradio UI ---
iface = gr.Interface(
fn=generate_docstring,
inputs=gr.Textbox(label="Code Snippet", lines=10, placeholder="Paste your Python function or code block here..."),
outputs=gr.Textbox(label="Generated Docstring"),
title="Code-to-Docstring Generator (Shuu12121/CodeEncoderDecodeerModel-Ghost)",
description="This demo uses the Shuu12121/CodeEncoderDecodeerModel-Ghost model to automatically generate Python docstrings from code snippets. Paste your code below and click 'Submit'."
)
# --- アプリケーションの起動 ---
# Hugging Face Spacesで実行する場合、share=Trueは不要
if __name__ == "__main__":
print("Launching Gradio interface...")
iface.launch()
print("Gradio interface launched.") |