# c2r / app.py — gdnjr5233-YOLOer, "Update app.py" (commit 7ef1cd9, verified)
import re
import time
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from bert_score import score
import tempfile
# Model loading (cached so the heavy downloads/loads happen once per session)
@st.cache_resource
def load_models():
    """Load the fine-tuned zh->ru T5 translator and the GPT-2 perplexity scorer.

    Returns:
        (finetuned_model, finetuned_tokenizer, perplexity_model, perplexity_tokenizer)
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Fine-tuned translation model (local checkpoint directory)
    model_dir = "finetuned_model_v2/best_model"
    finetuned_tokenizer = T5Tokenizer.from_pretrained(model_dir)
    finetuned_model = T5ForConditionalGeneration.from_pretrained(model_dir).to(device)

    # GPT-2 is used only to score fluency (perplexity) of generated Russian text
    perplexity_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    perplexity_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return finetuned_model, finetuned_tokenizer, perplexity_model, perplexity_tokenizer
# Seed session_state so later reads never KeyError across Streamlit reruns
st.session_state.setdefault('processed', False)
st.session_state.setdefault('translated_code', [])

# Glossary of domain terms whose translation is fixed (applied before the model)
CUSTOM_TERMS = {
    "写入 CSV": "Запись в CSV",
    "CSV 表头": "Заголовок таблицы CSV",
}

# Task prefix the fine-tuned T5 checkpoint expects on every input
prefix = 'translate to ru: '
# --- utility functions ---
def calculate_perplexity(text):
    """Return GPT-2 perplexity of *text* (lower = more fluent Russian).

    Fix: the input ids must live on the same device as the model. The
    original code forced `.to('cpu')`, which raises a device-mismatch
    error whenever the model was loaded onto CUDA in load_models().
    """
    model = st.session_state.perplexity_model
    tokens = st.session_state.perplexity_tokenizer.encode(
        text, return_tensors='pt'
    ).to(model.device)
    with torch.no_grad():
        # Feeding the same ids as labels yields the LM cross-entropy loss
        loss = model(tokens, labels=tokens).loss
    return torch.exp(loss).item()
def evaluate_translation(original, translated, scores):
    """Append a (BERTScore-F1, perplexity) pair for one translation to *scores*."""
    _, _, f1 = score([translated], [original], model_type="xlm-roberta-large", lang="ru")
    perplexity = calculate_perplexity(translated)
    scores.append((f1.item(), perplexity))
# --- core translation ---
def translate_text(text, term_dict=None):
    """Translate one Chinese comment to Russian with the fine-tuned T5 model.

    Windows-style file paths are masked with ``||PATH_i||`` placeholders so the
    model cannot mangle them, then restored afterwards. Entries of *term_dict*
    (longest key first) are substituted verbatim before translation.

    Fix: inputs are moved to the model's own device — the original hard-coded
    `.to('cpu')`, which breaks when the model was loaded onto CUDA.
    """
    # Shield paths like C:\dir\file from the model
    preserved_paths = re.findall(r'[a-zA-Z]:\\[^ \u4e00-\u9fff]+', text)
    for i, path in enumerate(preserved_paths):
        text = text.replace(path, f"||PATH_{i}||")

    # Apply the glossary, longest keys first so longer terms win over substrings
    if term_dict:
        sorted_terms = sorted(term_dict.keys(), key=len, reverse=True)
        pattern = re.compile('|'.join(map(re.escape, sorted_terms)))
        text = pattern.sub(lambda m: term_dict[m.group()], text)

    model = st.session_state.finetuned_model
    tokenizer = st.session_state.finetuned_tokenizer

    src_text = prefix + text
    inputs = tokenizer(src_text, return_tensors="pt", max_length=512, truncation=True)
    generated_tokens = model.generate(**inputs.to(model.device))
    translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

    # Restore the masked paths
    for i, path in enumerate(preserved_paths):
        translated = translated.replace(f"||PATH_{i}||", path)

    # Cosmetic cleanup: split Cyrillic words glued to ".py" and ensure a space
    # separates punctuation that follows a Cyrillic letter from the next token
    translated = re.sub(r'(\b[а-яА-ЯёЁ]+)(\.py\b)', lambda m: f"{m.group(1)} {m.group(2)}", translated)
    translated = re.sub(r'(?<=[а-яА-ЯёЁ])([.,!?])(?=\S)', r' \1', translated)
    return translated
# --- page layout ---
st.set_page_config(layout="wide", page_icon="📝", page_title="Python C2R Code Comment Translator")

# Header
st.title("Python Chinese to Russian Code Comment Translator")
st.subheader("Upload a Python file with Chinese comments", divider='rainbow')

# File upload widget
uploaded_file = st.file_uploader("Choose .py file", type=['py'], label_visibility='collapsed')

# Explicit trigger so translation does not start on upload alone
start_translation = st.button("开始翻译 / Start Translation")

# Load the models once per session and stash them in session_state
if 'models_loaded' not in st.session_state:
    with st.spinner('Loading models...'):
        (finetuned_model, finetuned_tokenizer,
         perplexity_model, perplexity_tokenizer) = load_models()
        st.session_state.update({
            'finetuned_model': finetuned_model,
            'finetuned_tokenizer': finetuned_tokenizer,
            'perplexity_model': perplexity_model,
            'perplexity_tokenizer': perplexity_tokenizer,
            'models_loaded': True,
        })
# --- main processing: runs only when a file is uploaded AND the button is pressed ---
if uploaded_file and start_translation:
    st.session_state.processed = False      # reset processing state
    st.session_state.translated_code = []   # drop any previous translation

    with st.spinner('Processing file...'):
        # Uploaded files yield bytes; utf-8-sig also strips a leading BOM
        code_lines = [line.decode('utf-8-sig') if isinstance(line, bytes) else line
                      for line in uploaded_file.readlines()]

        # Original file with 1-based line numbers for display
        numbered_original = "\n".join(f"{i+1} {line.rstrip()}" for i, line in enumerate(code_lines))

        # Two-column layout: source on the left, live translation on the right
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Original Python Code")
            st.session_state.original_content = numbered_original
            st.code(numbered_original, language='python')
        with col2:
            st.subheader("Real-time Translation")
            translated_box = st.empty()

        progress_bar = st.progress(0)
        status_text = st.empty()

        translated_lines = []
        detected_count = 0
        translated_count = 0
        scores = []
        total_lines = len(code_lines)

        # Comment shapes handled: whole-line `#`, trailing (inline) `#`,
        # and single-line triple-quoted strings — all must contain Chinese
        pure_comment_pattern = re.compile(r'^(\s*)#.*?([\u4e00-\u9fff]+.*)')
        inline_comment_pattern = re.compile(r'(\s+#)\s*([^#]*[\u4e00-\u9fff]+[^#]*)')
        multi_comment_pattern = re.compile(r'^(\s*)(["\']{3})(.*?)\2', re.DOTALL)

        for idx, line in enumerate(code_lines):
            current_line = line.rstrip('\n')

            progress_bar.progress((idx + 1) / total_lines)
            status_text.markdown(f"**Processing line {idx+1}/{total_lines}** | Content: `{current_line[:50]}...`")

            processed = False

            # Whole-line Chinese comment. Fix: the original evaluated the
            # pattern twice (search, then match) — the pattern is ^-anchored,
            # so one match() call is equivalent and half the work.
            if match := pure_comment_pattern.match(line):
                detected_count += 1
                indent, comment = match.groups()
                translated = translate_text(comment.strip(), CUSTOM_TERMS)
                evaluate_translation(comment, translated, scores)
                translated_lines.append(f"{indent}# {translated}\n")
                translated_count += 1
                processed = True

            # Inline (trailing) Chinese comment — same duplicate-search fix
            if not processed and (match := inline_comment_pattern.search(line)):
                detected_count += 1
                code_part = line[:match.start()]
                symbol, comment = match.groups()
                translated = translate_text(comment.strip(), CUSTOM_TERMS)
                evaluate_translation(comment, translated, scores)
                translated_lines.append(f"{code_part}{symbol} {translated}\n")
                translated_count += 1
                processed = True

            # Single-line triple-quoted docstring (counted as detected even
            # when it contains no Chinese, matching the original behavior)
            if not processed and (multi_match := multi_comment_pattern.match(line)):
                detected_count += 1
                if re.search(r'[\u4e00-\u9fff]', multi_match.group(3)):
                    translated = translate_text(multi_match.group(3), CUSTOM_TERMS)
                    evaluate_translation(multi_match.group(3), translated, scores)
                    translated_lines.append(f"{multi_match.group(1)}{multi_match.group(2)}{translated}{multi_match.group(2)}\n")
                    translated_count += 1
                    processed = True

            if not processed:
                translated_lines.append(line)

            # Refresh the numbered live view after every line
            numbered_translation = "\n".join(f"{i+1} {l.rstrip()}" for i, l in enumerate(translated_lines))
            translated_box.code(numbered_translation, language='python')
            time.sleep(0.1)  # throttle so the "real-time" effect is visible

        # Persist results for the statistics section below
        st.session_state.translated_code = translated_lines
        st.session_state.detected_count = detected_count
        st.session_state.translated_count = translated_count
        st.session_state.scores = scores
        st.session_state.processed = True

        # Clear transient progress widgets
        progress_bar.empty()
        status_text.empty()
# --- statistics & download ---
if st.session_state.processed:
    st.divider()

    # Right-hand statistics layout
    with st.container():
        col_right = st.columns([1, 3])[1]
        with col_right:
            # Row 1: counts
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Detected Comments", st.session_state.detected_count)
            with col2:
                st.metric("Translated Comments", st.session_state.translated_count)

            # Row 2: quality metrics (only when at least one line was scored)
            col3, col4 = st.columns(2)
            with col3:
                if st.session_state.scores:
                    avg_bert = sum(f1 for f1, _ in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average BERTScore", f"{avg_bert:.4f}", help="Higher is better (0-1)")
            with col4:
                if st.session_state.scores:
                    avg_ppl = sum(ppl for _, ppl in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average Perplexity", f"{avg_ppl:.4f}", help="Lower is better (Typical range: 1~100+, lower means better translation)")

    # Download button centred under the metrics. Fix: pass the bytes straight
    # to st.download_button — the original wrote a NamedTemporaryFile with
    # delete=False that was never removed, leaking one temp file per rerun.
    # Also guard against uploaded_file being None on a later rerun.
    cols = st.columns([1, 2, 1])
    with cols[1]:
        if uploaded_file is not None:
            st.download_button(
                label="⬇️ Download Translated File",
                data="".join(st.session_state.translated_code).encode('utf-8'),
                file_name=f"translated_{uploaded_file.name}",
                mime='text/x-python',
                use_container_width=False
            )