import re
import time
import torch
import streamlit as st
from transformers import T5ForConditionalGeneration, T5Tokenizer, GPT2LMHeadModel, GPT2Tokenizer
from bert_score import score
import tempfile
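
# Overview (summary of what this script does): a Streamlit app that translates Chinese
# comments in an uploaded Python file into Russian with a fine-tuned T5 model, shows the
# translation line by line next to the original code, reports BERTScore and GPT-2
# perplexity for the translated comments, and offers the result as a download.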

# Model loading (cached with st.cache_resource so it runs only once)
@st.cache_resource
def load_models():
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    
    # Load the fine-tuned T5 translation model from a local checkpoint directory
    finetuned_model_path = "finetuned_model_v2/best_model"
    finetuned_tokenizer = T5Tokenizer.from_pretrained(finetuned_model_path)
    finetuned_model = T5ForConditionalGeneration.from_pretrained(finetuned_model_path).to(device)
    
    # Load GPT-2, used to compute the perplexity of the translated text
    perplexity_model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)
    perplexity_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    
    return finetuned_model, finetuned_tokenizer, perplexity_model, perplexity_tokenizer

# Initialize session_state
if 'processed' not in st.session_state:
    st.session_state.processed = False
if 'translated_code' not in st.session_state:
    st.session_state.translated_code = []

# Constants
CUSTOM_TERMS = {
    "写入 CSV": "Запись в CSV", 
    "CSV 表头": "Заголовок таблицы CSV",
}
prefix = 'translate to ru: '
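# CUSTOM_TERMS pins fixed translations for domain terms before the model runs;
# prefix is the task prefix prepended to every source sentence for the T5 model.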

# Utility functions
def calculate_perplexity(text):
    # Keep the input on the same device as the perplexity model (it may be on GPU).
    device = next(st.session_state.perplexity_model.parameters()).device
    tokens = st.session_state.perplexity_tokenizer.encode(text, return_tensors='pt').to(device)
    with torch.no_grad():
        loss = st.session_state.perplexity_model(tokens, labels=tokens).loss
    # Perplexity is exp of the mean token-level cross-entropy; lower means more fluent text.
    return torch.exp(loss).item()

def evaluate_translation(original, translated, scores):
    P, R, F1 = score([translated], [original], model_type="xlm-roberta-large", lang="ru")
    ppl = calculate_perplexity(translated)
    scores.append((F1.item(), ppl))

# Core translation function
def translate_text(text, term_dict=None):
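    # Protect Windows-style file paths with placeholders so the model does not alter them.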
    preserved_paths = re.findall(r'[a-zA-Z]:\\[^ \u4e00-\u9fff]+', text)
    for i, path in enumerate(preserved_paths):
        text = text.replace(path, f"||PATH_{i}||")

    if term_dict:
        sorted_terms = sorted(term_dict.keys(), key=lambda x: len(x), reverse=True)
        pattern = re.compile('|'.join(map(re.escape, sorted_terms)))
        text = pattern.sub(lambda x: term_dict[x.group()], text)

    src_text = prefix + text
    # Move the tokenized inputs to the same device as the model (it may be on GPU).
    device = next(st.session_state.finetuned_model.parameters()).device
    input_ids = st.session_state.finetuned_tokenizer(src_text, return_tensors="pt", max_length=512, truncation=True).to(device)
    generated_tokens = st.session_state.finetuned_model.generate(**input_ids)
    result = st.session_state.finetuned_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

    translated = result[0]
    for i, path in enumerate(preserved_paths):
        translated = translated.replace(f"||PATH_{i}||", path)

    translated = re.sub(r'(\b[а-яА-ЯёЁ]+)(\.py\b)', lambda m: f"{m.group(1)} {m.group(2)}", translated)
    translated = re.sub(r'(?<=[а-яА-ЯёЁ])([.,!?])(?=\S)', r' \1', translated)
    return translated
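
# Illustrative walk-through of translate_text (hypothetical input; real model output will vary):
#   translate_text('写入 CSV 到 D:\\out\\data.csv', CUSTOM_TERMS)
#   1. 'D:\out\data.csv' is swapped for the placeholder '||PATH_0||'
#   2. '写入 CSV' is forced to 'Запись в CSV' by the term dictionary
#   3. the T5 model translates the remaining Chinese, the placeholder is restored,
#      and the Cyrillic/punctuation spacing fix-ups are applied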

# Page layout
st.set_page_config(layout="wide", page_icon="📝", page_title="Python C2R Code Comment Translator")

# Title section
st.title("Python Chinese to Russian Code Comment Translator")
st.subheader("Upload a Python file with Chinese comments", divider='rainbow')

# File upload
uploaded_file = st.file_uploader("Choose .py file", type=['py'], label_visibility='collapsed')

# Start Translation button (change #1)
start_translation = st.button("开始翻译 / Start Translation")

# Initialize the models
if 'models_loaded' not in st.session_state:
    with st.spinner('Loading models...'):
        (finetuned_model, finetuned_tokenizer, 
         perplexity_model, perplexity_tokenizer) = load_models()
        st.session_state.update({
            'finetuned_model': finetuned_model,
            'finetuned_tokenizer': finetuned_tokenizer,
            'perplexity_model': perplexity_model,
            'perplexity_tokenizer': perplexity_tokenizer,
            'models_loaded': True
        })
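
# The models live in session_state so calculate_perplexity() and translate_text()
# can reach them on every Streamlit rerun.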

# Process the uploaded file (change #2: translation starts only when the button is pressed)
if uploaded_file and start_translation:
    st.session_state.processed = False  # reset processing state
    st.session_state.translated_code = []  # clear previously translated content
    
    with st.spinner('Processing file...'):
        code_lines = [line.decode('utf-8-sig') if isinstance(line, bytes) else line 
                      for line in uploaded_file.readlines()]
        
        # Add line numbers (without colons)
        numbered_original = "\n".join([f"{i+1} {line.rstrip()}" for i, line in enumerate(code_lines)])
        numbered_translated = []

        # Two-column layout: original code on the left, live translation on the right
        col1, col2 = st.columns(2)
        
        # Original code panel
        with col1:
            st.subheader("Original Python Code")
            st.session_state.original_content = numbered_original
            st.code(numbered_original, language='python')
        
        # Translation panel
        with col2:
            st.subheader("Real-time Translation")
            translated_box = st.empty()
            progress_bar = st.progress(0)
            status_text = st.empty()
            
            # Counters and buffers for the line-by-line pass
            translated_lines = []
            detected_count = 0
            translated_count = 0
            scores = []
            total_lines = len(code_lines)
            
            # Regex patterns for detecting Chinese comments
            pure_comment_pattern = re.compile(r'^(\s*)#.*?([\u4e00-\u9fff]+.*)')
            inline_comment_pattern = re.compile(r'(\s+#)\s*([^#]*[\u4e00-\u9fff]+[^#]*)')
            multi_comment_pattern = re.compile(r'^(\s*)(["\']{3})(.*?)\2', re.DOTALL)
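            # Illustrative matches for the patterns above (hypothetical lines):
            #   pure_comment_pattern:   '    # 读取配置文件'
            #   inline_comment_pattern: 'x = 1  # 计数器'
            #   multi_comment_pattern:  '"""中文说明"""' (single-line docstrings only,
            #                           since the file is processed line by line)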
            
            # Process the file line by line
            for idx, line in enumerate(code_lines):
                current_line = line.rstrip('\n')
                
                # Update progress indicators
                progress = (idx + 1) / total_lines
                progress_bar.progress(progress)
                status_text.markdown(f"**Processing line {idx+1}/{total_lines}** | Content: `{current_line[:50]}...`")
                
                # Comment handling: try pure comment, then inline comment, then one-line docstring
                processed = False
                if pure_comment_pattern.search(line):
                    detected_count += 1
                    if match := pure_comment_pattern.match(line):
                        indent, comment = match.groups()
                        translated = translate_text(comment.strip(), CUSTOM_TERMS)
                        evaluate_translation(comment, translated, scores)
                        translated_lines.append(f"{indent}# {translated}\n")
                        translated_count += 1
                        processed = True
                
                if not processed and inline_comment_pattern.search(line):
                    detected_count += 1
                    if match := inline_comment_pattern.search(line):
                        code_part = line[:match.start()]
                        symbol, comment = match.groups()
                        translated = translate_text(comment.strip(), CUSTOM_TERMS)
                        evaluate_translation(comment, translated, scores)
                        translated_lines.append(f"{code_part}{symbol} {translated}\n")
                        translated_count += 1
                        processed = True
                
                if not processed and (multi_match := multi_comment_pattern.match(line)):
                    detected_count += 1
                    if re.search(r'[\u4e00-\u9fff]', multi_match.group(3)):
                        translated = translate_text(multi_match.group(3), CUSTOM_TERMS)
                        evaluate_translation(multi_match.group(3), translated, scores)
                        translated_lines.append(f"{multi_match.group(1)}{multi_match.group(2)}{translated}{multi_match.group(2)}\n")
                        translated_count += 1
                        processed = True
                
                if not processed:
                    translated_lines.append(line)
                
                # Update the numbered translation display (without colons)
                numbered_translation = "\n".join([f"{i+1} {line.rstrip()}" for i, line in enumerate(translated_lines)])
                translated_box.code(numbered_translation, language='python')
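                # Short pause so the right-hand panel visibly fills in line by line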
                time.sleep(0.1)
            
            # Processing finished: persist the results in session_state
            st.session_state.translated_code = translated_lines
            st.session_state.detected_count = detected_count
            st.session_state.translated_count = translated_count
            st.session_state.scores = scores
            st.session_state.processed = True
            
            # Clear the progress widgets
            progress_bar.empty()
            status_text.empty()

# Display statistics
if st.session_state.processed:
    st.divider()
    
    # Right-aligned statistics layout
    with st.container():
        col_right = st.columns([1, 3])[1]

        with col_right:
            # First row of metrics
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Detected Comments", st.session_state.detected_count)
            with col2:
                st.metric("Translated Comments", st.session_state.translated_count)
            
            # Second row of metrics
            col3, col4 = st.columns(2)
            with col3:
                if st.session_state.scores:
                    avg_bert = sum(f1 for f1, _ in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average BERTScore", f"{avg_bert:.4f}", help="Higher is better (0-1)")
            with col4:
                if st.session_state.scores:
                    avg_ppl = sum(ppl for _, ppl in st.session_state.scores) / len(st.session_state.scores)
                    st.metric("Average Perplexity", f"{avg_ppl:.4f}", help="Lower is better (Typical range: 1~100+, lower means better translation)")
            
            # Download button (change #3: moved below the metrics)
            cols = st.columns([1, 2, 1])
            with cols[1]:
                with tempfile.NamedTemporaryFile(suffix='.py', delete=False) as tmp:
                    tmp.write("".join(st.session_state.translated_code).encode('utf-8'))
                
                with open(tmp.name, 'rb') as f:
                    st.download_button(
                        label="⬇️ Download Translated File",
                        data=f,
                        file_name=f"translated_{uploaded_file.name}",
                        mime='text/x-python',
                        use_container_width=False
                    )