import torch
import gradio as gr
from transformers import AutoTokenizer, EncoderDecoderModel

# Pick the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_name = "Shuu12121/CodeEncoderDecodeerModel-Ghost"
# Load the encoder and decoder tokenizers
encoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/encoder_tokenizer")
decoder_tokenizer = AutoTokenizer.from_pretrained(f"{model_name}/decoder_tokenizer")
if decoder_tokenizer.pad_token is None:
    decoder_tokenizer.pad_token = decoder_tokenizer.eos_token
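
# Load the encoder-decoder model and switch it to inference mode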
model = EncoderDecoderModel.from_pretrained(model_name).to(device)
model.eval()
def generate_docstring(code: str) -> str:
    """Generate a docstring for the given code snippet."""
    # Tokenize the snippet, truncating anything beyond the encoder's limit
    inputs = encoder_tokenizer(
        code,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048,
    ).to(device)
    # Decode with beam search; no gradients are needed at inference time
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_length=256,
            num_beams=5,
            early_stopping=True,
            decoder_start_token_id=model.config.decoder_start_token_id,
            eos_token_id=model.config.eos_token_id,
            pad_token_id=model.config.pad_token_id,
            no_repeat_ngram_size=2,
        )
    return decoder_tokenizer.decode(output_ids[0], skip_special_tokens=True)
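
# Quick local sanity check (hypothetical example, not used by the Space):
#   print(generate_docstring("def add(a, b):\n    return a + b"))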

# Gradio UI
iface = gr.Interface(
fn=generate_docstring,
inputs=gr.Textbox(label="Code Snippet", lines=10, placeholder="Paste your function here..."),
outputs=gr.Textbox(label="Generated Docstring"),
title="Code-to-Docstring Generator",
description="This demo uses a custom encoder-decoder model to generate docstrings from code."
)
iface.launch()
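
# Note: pass share=True to launch() for a public link when running outside Spaces.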