# Streamlit app: translates Chinese comments in an uploaded Python file
# into Russian using a fine-tuned T5 model, with BERTScore/perplexity metrics.
import re | |
import time | |
import torch | |
import streamlit as st | |
from transformers import T5ForConditionalGeneration, T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer | |
from bert_score import score | |
import tempfile | |
# Model loading. The original comment promised caching ("使用缓存加速") but no
# cache decorator was applied, so every Streamlit rerun reloaded all weights;
# @st.cache_resource makes the models load once per server process.
@st.cache_resource
def load_models():
    """Load the fine-tuned T5 translator and the GPT-2 perplexity scorer.

    Returns:
        tuple: (finetuned_model, finetuned_tokenizer,
                perplexity_model, perplexity_tokenizer); both models are
        moved to CUDA when available, otherwise to CPU.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Fine-tuned zh->ru translation model.
    finetuned_model_path = "finetuned_model_v2/best_model"
    finetuned_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    finetuned_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path).to(device)
    # GPT-2 is used only to score translation fluency via perplexity.
    perplexity_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    perplexity_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return finetuned_model, finetuned_tokenizer, perplexity_model, perplexity_tokenizer
# --- One-time session-state defaults -------------------------------------
for _key, _default in (('processed', False), ('translated_code', [])):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Glossary: project phrases that must always map to a fixed Russian term.
CUSTOM_TERMS = {
    "写入 CSV": "Запись в CSV",
    "CSV 表头": "Заголовок таблицы CSV",
}

# Task prefix expected by the fine-tuned T5 checkpoint.
prefix = 'translate to ru: '
# Utility: language-model fluency score.
def calculate_perplexity(text):
    """Return GPT-2 perplexity of *text* (lower = more fluent).

    Args:
        text: Translated string to score.

    Returns:
        float: exp(cross-entropy loss) of GPT-2 on *text*.
    """
    model = st.session_state.perplexity_model
    # Put the tokens on the device the model actually lives on. The original
    # forced '.to("cpu")', which crashes with a device-mismatch error when
    # load_models() placed the model on CUDA.
    device = next(model.parameters()).device
    tokens = st.session_state.perplexity_tokenizer.encode(text, return_tensors='pt').to(device)
    with torch.no_grad():
        loss = model(tokens, labels=tokens).loss
    return torch.exp(loss).item()
def evaluate_translation(original, translated, scores):
    """Score *translated* against *original* and append (BERTScore-F1, perplexity) to *scores* in place."""
    precision, recall, f1 = score([translated], [original], model_type="xlm-roberta-large", lang="ru")
    perplexity = calculate_perplexity(translated)
    scores.append((f1.item(), perplexity))
# Core translation routine.
def translate_text(text, term_dict=None):
    """Translate a Chinese comment *text* into Russian with the fine-tuned T5.

    Windows-style paths are masked with placeholders so generation cannot
    mangle them, glossary terms from *term_dict* are substituted first, and
    light punctuation/spacing repairs are applied to the output.

    Args:
        text: Source comment text (Chinese).
        term_dict: Optional {source phrase: fixed Russian phrase} glossary.

    Returns:
        str: The Russian translation.
    """
    # Mask filesystem paths (e.g. C:\data\x.csv) so they survive generation.
    preserved_paths = re.findall(r'[a-zA-Z]:\\[^ \u4e00-\u9fff]+', text)
    for i, path in enumerate(preserved_paths):
        text = text.replace(path, f"||PATH_{i}||")
    if term_dict:
        # Longest keys first so e.g. "CSV 表头" wins over a shorter overlap.
        sorted_terms = sorted(term_dict.keys(), key=len, reverse=True)
        pattern = re.compile('|'.join(map(re.escape, sorted_terms)))
        text = pattern.sub(lambda x: term_dict[x.group()], text)
    src_text = prefix + text
    model = st.session_state.finetuned_model
    # Keep inputs on the model's own device. The original forced '.to("cpu")',
    # which crashes when load_models() placed the model on CUDA.
    device = next(model.parameters()).device
    input_ids = st.session_state.finetuned_tokenizer(src_text, return_tensors="pt", max_length=512, truncation=True)
    generated_tokens = model.generate(**input_ids.to(device))
    result = st.session_state.finetuned_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    translated = result[0]
    # Restore the masked paths.
    for i, path in enumerate(preserved_paths):
        translated = translated.replace(f"||PATH_{i}||", path)
    # Separate a Russian word fused onto a ".py" filename token.
    translated = re.sub(r'(\b[а-яА-ЯёЁ]+)(\.py\b)', lambda m: f"{m.group(1)} {m.group(2)}", translated)
    # Ensure punctuation after Cyrillic letters is followed by a space.
    translated = re.sub(r'(?<=[а-яА-ЯёЁ])([.,!?])(?=\S)', r' \1', translated)
    return translated
# --- Page layout ---
st.set_page_config(layout="wide", page_icon="📝", page_title="Python C2R Code Comment Translator")
# Title section.
st.title("Python Chinese to Russian Code Comment Translator")
st.subheader("Upload a Python file with Chinese comments", divider='rainbow')
# File upload widget.
uploaded_file = st.file_uploader("Choose .py file", type=['py'], label_visibility='collapsed')
# Start-translation button (change #1): processing runs only on click.
start_translation = st.button("开始翻译 / Start Translation")
# Initialise the models once per session and stash them in session_state.
if 'models_loaded' not in st.session_state:
    with st.spinner('Loading models...'):
        loaded = load_models()
        slot_names = ('finetuned_model', 'finetuned_tokenizer',
                      'perplexity_model', 'perplexity_tokenizer')
        for slot, obj in zip(slot_names, loaded):
            st.session_state[slot] = obj
        st.session_state['models_loaded'] = True
# Handle the uploaded file (change #2: runs only when the button is clicked).
if uploaded_file and start_translation:
    st.session_state.processed = False  # reset the processed flag
    st.session_state.translated_code = []  # drop any previous translation
    with st.spinner('Processing file...'):
        # Uploaded lines arrive as bytes; utf-8-sig decoding also strips a BOM.
        code_lines = [line.decode('utf-8-sig') if isinstance(line, bytes) else line
                      for line in uploaded_file.readlines()]
        # Prefix each line with its 1-based number (no colon).
        numbered_original = "\n".join([f"{i+1} {line.rstrip()}" for i, line in enumerate(code_lines)])
        numbered_translated = []  # NOTE(review): appears unused below — confirm before removing
        # Two-column layout: original left, live translation right.
        col1, col2 = st.columns(2)
        # Original code pane.
        with col1:
            st.subheader("Original Python Code")
            original_content = st.session_state.original_content = numbered_original
            st.code(original_content, language='python')
        # Translation pane: placeholders filled in incrementally below.
        with col2:
            st.subheader("Real-time Translation")
            translated_box = st.empty()
            progress_bar = st.progress(0)
            status_text = st.empty()
        # Working state for the pass over the file.
        translated_lines = []
        detected_count = 0
        translated_count = 0
        scores = []
        total_lines = len(code_lines)
        # Comment-detection patterns:
        # whole-line "#" comments that contain Chinese characters;
        pure_comment_pattern = re.compile(r'^(\s*)#.*?([\u4e00-\u9fff]+.*)')
        # trailing "#" comments after code, containing Chinese characters;
        inline_comment_pattern = re.compile(r'(\s+#)\s*([^#]*[\u4e00-\u9fff]+[^#]*)')
        # triple-quoted strings opened and closed on the same line.
        multi_comment_pattern = re.compile(r'^(\s*)(["\']{3})(.*?)\2', re.DOTALL)
        # Process line by line; the first matching pattern wins.
        for idx, line in enumerate(code_lines):
            current_line = line.rstrip('\n')
            # Progress display.
            progress = (idx + 1) / total_lines
            progress_bar.progress(progress)
            status_text.markdown(f"**Processing line {idx+1}/{total_lines}** | Content: `{current_line[:50]}...`")
            # Comment-handling logic.
            processed = False
            if pure_comment_pattern.search(line):
                detected_count += 1
                if match := pure_comment_pattern.match(line):
                    indent, comment = match.groups()
                    translated = translate_text(comment.strip(), CUSTOM_TERMS)
                    evaluate_translation(comment, translated, scores)
                    translated_lines.append(f"{indent}# {translated}\n")
                    translated_count += 1
                    processed = True
            if not processed and inline_comment_pattern.search(line):
                detected_count += 1
                if match := inline_comment_pattern.search(line):
                    # Keep the code part untouched; translate only the comment.
                    code_part = line[:match.start()]
                    symbol, comment = match.groups()
                    translated = translate_text(comment.strip(), CUSTOM_TERMS)
                    evaluate_translation(comment, translated, scores)
                    translated_lines.append(f"{code_part}{symbol} {translated}\n")
                    translated_count += 1
                    processed = True
            if not processed and (multi_match := multi_comment_pattern.match(line)):
                detected_count += 1
                # Only translate the docstring body if it actually has Chinese.
                if re.search(r'[\u4e00-\u9fff]', multi_match.group(3)):
                    translated = translate_text(multi_match.group(3), CUSTOM_TERMS)
                    evaluate_translation(multi_match.group(3), translated, scores)
                    translated_lines.append(f"{multi_match.group(1)}{multi_match.group(2)}{translated}{multi_match.group(2)}\n")
                    translated_count += 1
                    processed = True
            if not processed:
                # Non-comment (or non-Chinese) line: pass through unchanged.
                translated_lines.append(line)
            # Refresh the numbered translation shown so far (no colon).
            numbered_translation = "\n".join([f"{i+1} {line.rstrip()}" for i, line in enumerate(translated_lines)])
            translated_box.code(numbered_translation, language='python')
            time.sleep(0.1)
        # Processing finished: persist results for the statistics section.
        st.session_state.translated_code = translated_lines
        st.session_state.detected_count = detected_count
        st.session_state.translated_count = translated_count
        st.session_state.scores = scores
        st.session_state.processed = True
        # Clear the progress widgets.
        progress_bar.empty()
        status_text.empty()
# Show statistics and the download button once processing has finished.
if st.session_state.processed:
    st.divider()
    # Right-aligned statistics layout.
    with st.container():
        col_right = st.columns([1, 3])[1]
        with col_right:
            # First metrics row: comment counts.
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Detected Comments", st.session_state.detected_count)
            with col2:
                st.metric("Translated Comments", st.session_state.translated_count)
            # Second metrics row: quality averages (guarded for empty scores).
            col3, col4 = st.columns(2)
            with col3:
                if st.session_state.scores:
                    avg_bert = sum(f1 for f1, _ in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average BERTScore", f"{avg_bert:.4f}", help="Higher is better (0-1)")
            with col4:
                if st.session_state.scores:
                    avg_ppl = sum(ppl for _, ppl in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average Perplexity", f"{avg_ppl:.4f}", help="Lower is better (Typical range: 1~100+, lower means better translation)")
    # Download button (change #3: placed below the metrics).
    cols = st.columns([1, 2, 1])
    with cols[1]:
        # Serve bytes directly: the original round-tripped through a
        # NamedTemporaryFile(delete=False) that was never removed, leaking a
        # temp file on every render; st.download_button accepts bytes.
        payload = "".join(st.session_state.translated_code).encode('utf-8')
        # uploaded_file can be None on a later rerun while results are still
        # in the session; fall back to a generic name instead of crashing.
        download_name = f"translated_{uploaded_file.name}" if uploaded_file else "translated.py"
        st.download_button(
            label="⬇️ Download Translated File",
            data=payload,
            file_name=download_name,
            mime='text/x-python',
            use_container_width=False
        )