Spaces:
Runtime error
Runtime error
import os | |
import json | |
import gradio as gr | |
import time | |
from datasets import load_dataset | |
import pandas as pd | |
from typing import Any, Dict, Union, Optional, Generator, List | |
from dataclasses import dataclass | |
from .utils.command import preview_command, save_arguments, build_command_list | |
from .utils.monitor import TrainingMonitor | |
import traceback | |
import base64 | |
import tempfile | |
import numpy as np | |
import queue | |
import subprocess | |
import sys | |
import threading | |
class TrainingArgs:
    """Structured view of the positional values submitted by the Training tab.

    ``args`` is the flat list of component values, read by fixed index (see the
    per-section comments in ``__init__``). ``plm_models`` maps the displayed
    model name to the value stored in ``plm_model``; ``dataset_configs`` maps
    the displayed dataset name to the path of a JSON config file.
    """

    # Training methods whose LoRA-style hyper-parameters are emitted by to_dict().
    LORA_METHODS = ("plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3")

    def __init__(self, args: list, plm_models: dict, dataset_configs: dict):
        """Unpack ``args`` into named attributes.

        Raises:
            KeyError: if ``args[0]`` is not a key of ``plm_models``, or (for a
                pre-defined dataset) ``args[2]`` is not a key of ``dataset_configs``.
            OSError, json.JSONDecodeError: if a pre-defined dataset config file
                cannot be opened or parsed.
        """
        # Basic parameters (args[0]-args[6])
        self.plm_model = plm_models[args[0]]
        # Either "Use Custom Dataset" or "Use Pre-defined Dataset".
        self.dataset_selection = args[1]
        if self.dataset_selection == "Use Pre-defined Dataset":
            self.dataset_config = dataset_configs[args[2]]
            self.dataset_custom = None
            # Task settings (problem type, label count, metrics) come from the config file.
            with open(self.dataset_config, 'r', encoding='utf-8') as f:
                config = json.load(f)
            self.problem_type = config.get("problem_type", "single_label_classification")
            self.num_labels = config.get("num_labels", 2)
            self.metrics = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
        else:
            self.dataset_config = None
            self.dataset_custom = args[3]  # custom dataset path, e.g. "user/dataset"
            self.problem_type = args[4]
            self.num_labels = args[5]
            self.metrics = args[6]
        # Normalize for both sources: metrics may arrive as a list (multiselect
        # dropdown or a JSON array), and gr.Number without `precision` yields floats
        # such as 2.0, which should be passed on as the integer 2.
        self.metrics = self._join_metrics(self.metrics)
        self.num_labels = self._int_if_whole(self.num_labels)

        # Training method parameters (args[7]-args[8])
        self.training_method = args[7]
        self.pooling_method = args[8]

        # Batch processing parameters (args[9]-args[11]). Only the value for the
        # active mode is read; the other attribute is None so both always exist.
        self.batch_mode = args[9]
        size_mode = self.batch_mode == "Batch Size Mode"
        self.batch_size = args[10] if size_mode else None
        self.batch_token = None if size_mode else args[11]

        # Training parameters (args[12]-args[17])
        self.learning_rate = args[12]
        self.num_epochs = args[13]
        self.max_seq_len = args[14]
        self.gradient_accumulation_steps = args[15]
        self.warmup_steps = args[16]
        self.scheduler = args[17]

        # Output parameters (args[18]-args[19])
        self.output_model_name = args[18]
        self.output_dir = args[19]

        # Wandb parameters (args[20]-args[22]); project/entity are only read when enabled.
        self.wandb_enabled = args[20]
        self.wandb_project = args[21] if self.wandb_enabled else None
        self.wandb_entity = args[22] if self.wandb_enabled else None

        # Other parameters (args[23]-args[26])
        self.patience = args[23]
        self.num_workers = args[24]
        self.max_grad_norm = args[25]
        self.structure_seq = args[26]

        # LoRA parameters (args[27]-args[30]); target modules arrive as "a, b, c".
        self.lora_r = args[27]
        self.lora_alpha = args[28]
        self.lora_dropout = args[29]
        self.lora_target_modules = self._split_modules(args[30])

    @staticmethod
    def _join_metrics(metrics):
        """Return metrics as a comma-separated string; lists/tuples are joined."""
        if isinstance(metrics, (list, tuple)):
            return ",".join(metrics)
        return metrics

    @staticmethod
    def _int_if_whole(value):
        """Convert a whole-number float (e.g. 2.0) to int; return anything else unchanged."""
        if isinstance(value, float) and value.is_integer():
            return int(value)
        return value

    @staticmethod
    def _split_modules(raw):
        """Split a comma-separated module list, trimming blanks and dropping empty entries."""
        if not raw:
            return []
        return [m.strip() for m in raw.split(",") if m.strip()]

    def to_dict(self) -> Dict[str, Any]:
        """Return the arguments as a dict, including only the options relevant to the
        selected dataset source, training method, batch mode and wandb setting."""
        args_dict = {
            "plm_model": self.plm_model,
            "training_method": self.training_method,
            "pooling_method": self.pooling_method,
            "learning_rate": self.learning_rate,
            "num_epochs": self.num_epochs,
            "max_seq_len": self.max_seq_len,
            "gradient_accumulation_steps": self.gradient_accumulation_steps,
            "warmup_steps": self.warmup_steps,
            "scheduler": self.scheduler,
            "output_model_name": self.output_model_name,
            "output_dir": self.output_dir,
            "patience": self.patience,
            "num_workers": self.num_workers,
            "max_grad_norm": self.max_grad_norm,
        }
        if self.training_method == "ses-adapter" and self.structure_seq:
            seq = self.structure_seq
            # A plain string must not be joined character by character.
            args_dict["structure_seq"] = seq if isinstance(seq, str) else ",".join(seq)

        # Dataset parameters: a pre-defined dataset is described by its config file;
        # a custom dataset needs its task settings spelled out explicitly.
        if self.dataset_selection == "Use Pre-defined Dataset":
            args_dict["dataset_config"] = self.dataset_config
        else:
            args_dict["dataset"] = self.dataset_custom
            args_dict["problem_type"] = self.problem_type
            args_dict["num_labels"] = self.num_labels
            args_dict["metrics"] = self.metrics

        # LoRA parameters
        if self.training_method in self.LORA_METHODS:
            args_dict.update({
                "lora_r": self.lora_r,
                "lora_alpha": self.lora_alpha,
                "lora_dropout": self.lora_dropout,
                "lora_target_modules": self.lora_target_modules,
            })

        # Batch processing parameters: exactly one of the two is emitted.
        if self.batch_mode == "Batch Size Mode":
            args_dict["batch_size"] = self.batch_size
        else:
            args_dict["batch_token"] = self.batch_token

        # Wandb parameters: empty project/entity values are omitted.
        if self.wandb_enabled:
            args_dict["wandb"] = True
            if self.wandb_project:
                args_dict["wandb_project"] = self.wandb_project
            if self.wandb_entity:
                args_dict["wandb_entity"] = self.wandb_entity
        return args_dict
def create_train_tab(constant: Dict[str, Any]) -> Dict[str, Any]: | |
# Create training monitor | |
monitor = TrainingMonitor() | |
# Add missing variable declarations | |
is_training = False | |
current_process = None | |
stop_thread = False | |
process_aborted = False | |
plm_models = constant["plm_models"] | |
dataset_configs = constant["dataset_configs"] | |
with gr.Tab("Training"): | |
# Model and Dataset Selection | |
gr.Markdown("### Model and Dataset Configuration") | |
# Original training interface components | |
with gr.Group(): | |
with gr.Row(): | |
with gr.Column(scale=4): | |
with gr.Row(): | |
plm_model = gr.Dropdown( | |
choices=list(plm_models.keys()), | |
label="Protein Language Model", | |
value=list(plm_models.keys())[0], | |
scale=2 | |
) | |
# 新增数据集选择方式 | |
is_custom_dataset = gr.Radio( | |
choices=["Use Custom Dataset", "Use Pre-defined Dataset"], | |
label="Dataset Selection", | |
value="Use Pre-defined Dataset", | |
scale=3 | |
) | |
dataset_config = gr.Dropdown( | |
choices=list(dataset_configs.keys()), | |
label="Dataset Configuration", | |
value=list(dataset_configs.keys())[0], | |
visible=True, | |
scale=2 | |
) | |
dataset_custom = gr.Textbox( | |
label="Custom Dataset Path", | |
placeholder="Huggingface Dataset eg: user/dataset", | |
visible=False, | |
scale=2 | |
) | |
# 将预览按钮放在单独的列中,并添加样式 | |
with gr.Column(scale=1, min_width=120, elem_classes="preview-button-container"): | |
dataset_preview_button = gr.Button( | |
"Preview Dataset", | |
variant="primary", | |
size="lg", | |
elem_classes="preview-button" | |
) | |
# 自定义数据集的额外配置选项(单独一行) | |
with gr.Row(visible=True) as custom_dataset_settings: | |
problem_type = gr.Dropdown( | |
choices=["single_label_classification", "multi_label_classification", "regression"], | |
label="Problem Type", | |
value="single_label_classification", | |
scale=23, | |
interactive=False | |
) | |
num_labels = gr.Number( | |
value=2, | |
label="Number of Labels", | |
scale=11, | |
interactive=False | |
) | |
metrics = gr.Dropdown( | |
choices=["accuracy", "recall", "precision", "f1", "mcc", "auroc", "f1max", "spearman_corr", "mse"], | |
label="Metrics", | |
value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"], | |
scale=101, | |
multiselect=True, | |
interactive=False | |
) | |
with gr.Row(): | |
structure_seq = gr.Dropdown( | |
label="Structure Sequence", | |
choices=["foldseek_seq", "ss8_seq"], | |
value=["foldseek_seq", "ss8_seq"], | |
multiselect=True, | |
visible=False | |
) | |
# ! add for plm-lora, plm-qlora, plm_adalora, plm_dora, plm_ia3 | |
with gr.Row(visible=False) as lora_params_row: | |
# gr.Markdown("#### LoRA Parameters") | |
with gr.Column(): | |
lora_r = gr.Number( | |
value=8, | |
label="LoRA Rank", | |
precision=0, | |
minimum=1, | |
maximum=128, | |
) | |
with gr.Column(): | |
lora_alpha = gr.Number( | |
value=32, | |
label="LoRA Alpha", | |
precision=0, | |
minimum=1, | |
maximum=128 | |
) | |
with gr.Column(): | |
lora_dropout = gr.Number( | |
value=0.1, | |
label="LoRA Dropout", | |
minimum=0.0, | |
maximum=1.0 | |
) | |
with gr.Column(): | |
lora_target_modules = gr.Textbox( | |
value="query,key,value", | |
label="LoRA Target Modules", | |
placeholder="Comma-separated list of target modules", | |
# info="LoRA will be applied to these modules" | |
) | |
# 将数据统计和表格都放入折叠面板 | |
with gr.Row(): | |
with gr.Accordion("Dataset Preview", open=False) as preview_accordion: | |
# 数据统计区域 | |
with gr.Row(): | |
dataset_stats_md = gr.HTML("", elem_classes=["dataset-stats"]) | |
# 表格区域 | |
with gr.Row(): | |
preview_table = gr.Dataframe( | |
headers=["Name", "Sequence", "Label"], | |
value=[["No dataset selected", "-", "-"]], | |
wrap=True, | |
interactive=False, | |
row_count=3, | |
elem_classes=["preview-table"] | |
) | |
# Add CSS styles | |
gr.HTML(""" | |
<style> | |
/* 数据统计样式 */ | |
.dataset-stats { | |
margin: 0 0 15px 0; | |
padding: 0; | |
} | |
.dataset-stats table { | |
width: 100%; | |
border-collapse: collapse; | |
font-size: 0.9em; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.05); | |
border-radius: 8px; | |
overflow: hidden; | |
table-layout: fixed; | |
} | |
.dataset-stats th { | |
background-color: #e0e0e0; | |
font-weight: bold; | |
padding: 6px 10px; | |
text-align: center; | |
border: 1px solid #ddd; | |
font-size: 0.95em; | |
white-space: nowrap; | |
overflow: hidden; | |
min-width: 120px; | |
} | |
.dataset-stats td { | |
padding: 6px 10px; | |
text-align: center; | |
border: 1px solid #ddd; | |
} | |
.dataset-stats h2 { | |
font-size: 1.1em; | |
margin: 0 0 10px 0; | |
text-align: center; | |
} | |
/* 表格样式 */ | |
.preview-table table { | |
background-color: white !important; | |
font-size: 0.9em !important; | |
width: 100%; | |
table-layout: fixed !important; | |
} | |
.preview-table .gr-block.gr-box { | |
background-color: transparent !important; | |
} | |
.preview-table .gr-input-label { | |
background-color: transparent !important; | |
} | |
/* 表格外观增强 */ | |
.preview-table table { | |
margin-top: 0; | |
border-radius: 8px; | |
overflow: hidden; | |
box-shadow: 0 2px 4px rgba(0,0,0,0.05); | |
} | |
/* 表头样式 */ | |
.preview-table th { | |
background-color: #e0e0e0 !important; | |
font-weight: bold !important; | |
padding: 6px !important; | |
border-bottom: 1px solid #ccc !important; | |
font-size: 0.95em !important; | |
text-align: center !important; | |
white-space: nowrap !important; | |
min-width: 120px !important; | |
} | |
/* 单元格样式 */ | |
.preview-table td { | |
padding: 4px 6px !important; | |
max-width: 300px !important; | |
overflow: hidden; | |
text-overflow: ellipsis; | |
white-space: nowrap; | |
text-align: left !important; | |
} | |
/* 悬停效果 */ | |
.preview-table tr:hover { | |
background-color: #f0f0f0 !important; | |
} | |
/* 折叠面板样式 */ | |
.gr-accordion { | |
border: 1px solid #e0e0e0; | |
border-radius: 8px; | |
overflow: hidden; | |
margin-bottom: 15px; | |
} | |
/* 折叠面板标题样式 */ | |
.gr-accordion .label-wrap { | |
background-color: #f5f5f5; | |
padding: 8px 15px; | |
font-weight: bold; | |
} | |
.preview-button { | |
height: 86px !important; | |
} | |
/* Center Model Statistics Table */ | |
.center-table-content td, .center-table-content th { | |
text-align: center !important; | |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important; | |
padding: 10px !important; | |
} | |
.center-table-content table { | |
width: 100% !important; | |
border-collapse: collapse !important; | |
margin-bottom: 20px !important; | |
box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important; | |
border-radius: 8px !important; | |
overflow: hidden !important; | |
} | |
.center-table-content th { | |
background-color: #f0f4f8 !important; | |
color: #2c3e50 !important; | |
font-weight: 600 !important; | |
border-bottom: 2px solid #ddd !important; | |
} | |
.center-table-content tr:nth-child(even) { | |
background-color: #f9f9f9 !important; | |
} | |
.center-table-content tr:hover { | |
background-color: #f0f7ff !important; | |
} | |
/* Improve readability of progress bars */ | |
.progress-container { | |
margin-bottom: 20px !important; | |
} | |
.progress-bar { | |
transition: width 0.5s ease-in-out !important; | |
} | |
.status-message { | |
margin-bottom: 8px !important; | |
font-weight: 500 !important; | |
} | |
</style> | |
""", visible=True) | |
# Batch Processing Configuration | |
gr.Markdown("### Batch Processing Configuration") | |
with gr.Group(): | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=1): | |
batch_mode = gr.Radio( | |
choices=["Batch Size Mode", "Batch Token Mode"], | |
label="Batch Processing Mode", | |
value="Batch Size Mode" | |
) | |
with gr.Column(scale=2): | |
batch_size = gr.Slider( | |
minimum=1, | |
maximum=128, | |
value=16, | |
step=1, | |
label="Batch Size", | |
visible=True | |
) | |
batch_token = gr.Slider( | |
minimum=1000, | |
maximum=50000, | |
value=10000, | |
step=1000, | |
label="Tokens per Batch", | |
visible=False | |
) | |
def update_batch_inputs(mode):
    """Show the batch-size slider or the tokens-per-batch slider for the selected mode.

    Returns a mapping from each slider component to its visibility update.
    """
    show_size = mode == "Batch Size Mode"
    show_token = mode == "Batch Token Mode"
    size_update = gr.update(visible=show_size)
    token_update = gr.update(visible=show_token)
    return {batch_size: size_update, batch_token: token_update}
# Update visibility when mode changes | |
batch_mode.change( | |
fn=update_batch_inputs, | |
inputs=[batch_mode], | |
outputs=[batch_size, batch_token] | |
) | |
# Training Parameters | |
gr.Markdown("### Training Parameters") | |
with gr.Group(): | |
# First row: Basic training parameters | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=1, min_width=150): | |
training_method = gr.Dropdown( | |
choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"], | |
label="Training Method", | |
value="freeze" | |
) | |
with gr.Column(scale=1, min_width=150): | |
learning_rate = gr.Slider( | |
minimum=1e-8, maximum=1e-2, value=5e-4, step=1e-6, | |
label="Learning Rate" | |
) | |
with gr.Column(scale=1, min_width=150): | |
num_epochs = gr.Slider( | |
minimum=1, maximum=200, value=20, step=1, | |
label="Number of Epochs" | |
) | |
with gr.Column(scale=1, min_width=150): | |
patience = gr.Slider( | |
minimum=1, maximum=50, value=10, step=1, | |
label="Early Stopping Patience" | |
) | |
with gr.Column(scale=1, min_width=150): | |
max_seq_len = gr.Slider( | |
minimum=-1, maximum=2048, value=None, step=32, | |
label="Max Sequence Length (-1 for unlimited)" | |
) | |
def update_training_method(method):
    """Toggle the method-specific widgets for the chosen training method.

    The structure-sequence selector is shown only for "ses-adapter"; the LoRA
    hyper-parameter row is shown only for the LoRA-family methods listed below.
    """
    lora_family = ("plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3")
    wants_structure = method == "ses-adapter"
    wants_lora = method in lora_family
    return {
        structure_seq: gr.update(visible=wants_structure),
        lora_params_row: gr.update(visible=wants_lora),
    }
# Add training_method change event | |
training_method.change( | |
fn=update_training_method, | |
inputs=[training_method], | |
outputs=[structure_seq, lora_params_row] | |
) | |
# Second row: Advanced training parameters | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=1, min_width=150): | |
pooling_method = gr.Dropdown( | |
choices=["mean", "attention1d", "light_attention"], | |
label="Pooling Method", | |
value="mean" | |
) | |
with gr.Column(scale=1, min_width=150): | |
scheduler_type = gr.Dropdown( | |
choices=["linear", "cosine", "step", None], | |
label="Scheduler Type", | |
value=None | |
) | |
with gr.Column(scale=1, min_width=150): | |
warmup_steps = gr.Slider( | |
minimum=0, maximum=1000, value=0, step=10, | |
label="Warmup Steps" | |
) | |
with gr.Column(scale=1, min_width=150): | |
gradient_accumulation_steps = gr.Slider( | |
minimum=1, maximum=32, value=1, step=1, | |
label="Gradient Accumulation Steps" | |
) | |
with gr.Column(scale=1, min_width=150): | |
max_grad_norm = gr.Slider( | |
minimum=0.1, maximum=10.0, value=-1, step=0.1, | |
label="Max Gradient Norm (-1 for no clipping)" | |
) | |
with gr.Column(scale=1, min_width=150): | |
num_workers = gr.Slider( | |
minimum=0, maximum=16, value=4, step=1, | |
label="Number of Workers" | |
) | |
# Output and Logging Settings | |
gr.Markdown("### Output and Logging Settings") | |
with gr.Row(): | |
with gr.Column(): | |
output_dir = gr.Textbox( | |
label="Save Directory", | |
value="demo", | |
placeholder="Path to save training results" | |
) | |
output_model_name = gr.Textbox( | |
label="Output Model Name", | |
value="demo.pt", | |
placeholder="Name of the output model file" | |
) | |
with gr.Column(): | |
wandb_logging = gr.Checkbox( | |
label="Enable W&B Logging", | |
value=False | |
) | |
wandb_project = gr.Textbox( | |
label="W&B Project Name", | |
value=None, | |
visible=False | |
) | |
wandb_entity = gr.Textbox( | |
label="W&B Entity", | |
value=None, | |
visible=False | |
) | |
# Training Control and Output | |
gr.Markdown("### Training Control") | |
with gr.Row(): | |
preview_button = gr.Button("Preview Command") | |
abort_button = gr.Button("Abort", variant="stop") | |
train_button = gr.Button("Start", variant="primary") | |
with gr.Row(): | |
command_preview = gr.Code( | |
label="Command Preview", | |
language="shell", | |
interactive=False, | |
visible=False | |
) | |
# Model Statistics Section | |
gr.Markdown("### Model Statistics") | |
with gr.Row(): | |
model_stats = gr.Dataframe( | |
headers=["Model Type", "Total Parameters", "Trainable Parameters", "Percentage"], | |
value=[ | |
["Training Model", "-", "-", "-"], | |
["Pre-trained Model", "-", "-", "-"], | |
["Combined Model", "-", "-", "-"] | |
], | |
interactive=False, | |
elem_classes=["center-table-content"] | |
) | |
def update_model_stats(stats: Dict[str, str]) -> List[List[str]]:
    """Build the rows of the "Model Statistics" table from a stats mapping.

    Each row is [label, total params, trainable params, percentage]. Missing keys
    (or an empty/None mapping) render as "-"; only the combined row carries a
    trainable percentage.
    """
    if not stats:
        return [
            [row_label, "-", "-", "-"]
            for row_label in ("Training Model", "Pre-trained Model", "Combined Model")
        ]

    def cell(key):
        # Every value is rendered as text; absent entries fall back to "-".
        return str(stats.get(key, '-'))

    training_row = ["Training Model", cell('adapter_total'), cell('adapter_trainable'), "-"]
    pretrained_row = ["Pre-trained Model", cell('pretrain_total'), cell('pretrain_trainable'), "-"]
    combined_row = [
        "Combined Model",
        cell('combined_total'),
        cell('combined_trainable'),
        cell('trainable_percentage'),
    ]
    return [training_row, pretrained_row, combined_row]
# Training Progress | |
gr.Markdown("### Training Progress") | |
with gr.Row(): | |
progress_status = gr.HTML( | |
value=""" | |
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);"> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;"> | |
<div> | |
<span style="font-weight: 600; font-size: 16px;">Training Status: </span> | |
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Click Start to train your model</span> | |
</div> | |
</div> | |
</div> | |
""", | |
label="Status" | |
) | |
with gr.Row(): | |
best_model_info = gr.Textbox( | |
value="Best Model: None", | |
label="Best Performance", | |
interactive=False | |
) | |
# Add test results HTML display | |
with gr.Row(): | |
test_results_html = gr.HTML( | |
value="", | |
label="Test Results", | |
visible=True | |
) | |
with gr.Row(): | |
with gr.Column(scale=4): | |
pass | |
with gr.Column(scale=1): # 限制列的最大宽度 | |
download_csv_btn = gr.DownloadButton( | |
"Download CSV", | |
visible=False, | |
size="lg" | |
) | |
# 添加一个空列来占据剩余空间 | |
with gr.Column(scale=4): | |
pass | |
# Training plot in a separate row for full width | |
with gr.Row(): | |
with gr.Column(): | |
loss_plot = gr.Plot( | |
label="Training and Validation Loss", | |
elem_id="loss_plot" | |
) | |
with gr.Column(): | |
metrics_plot = gr.Plot( | |
label="Validation Metrics", | |
elem_id="metrics_plot" | |
) | |
def update_progress(progress_info):
    """Convert a progress snapshot into updates for the training-progress widgets.

    Args:
        progress_info: mapping of progress fields read with defaults (e.g.
            'stage', 'current', 'total', 'epoch', 'best_epoch', 'best_metric_value',
            'test_results_html', 'test_metrics', 'is_completed'). Presumably the
            monitor's current_progress dict — confirm against the caller.

    Returns:
        6-tuple: (status HTML, best-model text, update for the test-results HTML,
        loss figure, metrics figure, update for the CSV download button).
    """
    import csv
    import html

    def _status_card(message, color="#1976d2", background="#f8f9fa"):
        """Return the compact status card that shows a single status message."""
        return f"""
<div style="background-color: {background}; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
  <div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
    <div>
      <span style="font-weight: 600; font-size: 16px;">Training Status: </span>
      <span style="color: {color}; font-weight: 500; font-size: 16px;">{message}</span>
    </div>
  </div>
</div>
"""

    def _idle_outputs(status_html, best_text):
        """Outputs for non-running states: hide test results and download, clear plots."""
        return (
            status_html,
            best_text,
            gr.update(value="", visible=False),
            None,
            None,
            gr.update(visible=False),
        )

    def _format_metric(value):
        """Format numbers with six decimals; fall back to str() for anything else."""
        try:
            return f"{value:.6f}"
        except (TypeError, ValueError):
            return str(value)

    def _write_metrics_csv(metrics_dict):
        """Write metrics to a new temporary CSV (headline metrics first) and return its path."""
        priority_metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc']

        def get_priority(item):
            name = item[0]
            return priority_metrics.index(name) if name in priority_metrics else len(priority_metrics)

        # NOTE(review): a new file is created on every call that carries test metrics
        # and is never deleted; if this runs on a polling timer, consider reusing a path.
        with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv',
                                         prefix='metrics_results_', newline='',
                                         encoding='utf-8') as temp_file:
            # csv.writer quotes names containing commas/quotes; keep '\n' line endings.
            writer = csv.writer(temp_file, lineterminator="\n")
            writer.writerow(["Metric", "Value"])
            for metric_name, metric_value in sorted(metrics_dict.items(), key=get_priority):
                # Abbreviations are upper-cased, other names capitalized.
                if metric_name.lower() in ['f1', 'mcc', 'auroc']:
                    display_name = metric_name.upper()
                else:
                    display_name = metric_name.capitalize()
                writer.writerow([display_name, _format_metric(metric_value)])
            return temp_file.name

    # Nothing reported yet (empty or all-falsy snapshot): show the fresh idle card.
    if not progress_info or not any(progress_info.values()):
        return _idle_outputs(_status_card("Click Start to train your model"), "Best Model: None")

    # A missing 'stage' key is deliberately not treated as 'Waiting' here.
    raw_stage = progress_info.get('stage', '')
    if raw_stage == 'Error':
        return _idle_outputs(
            _status_card("Failed", color="#c62828", background="#ffebee"),
            "Training failed",
        )
    if raw_stage == 'Waiting':
        return _idle_outputs(_status_card("Waiting to start..."), "Best Model: None")

    current = progress_info.get('current', 0)
    total = progress_info.get('total', 100)
    epoch = progress_info.get('epoch', 0)
    stage = progress_info.get('stage', 'Waiting')
    best_epoch = progress_info.get('best_epoch', 0)
    best_metric_name = progress_info.get('best_metric_name', 'accuracy')
    best_metric_value = progress_info.get('best_metric_value', 0.0)
    elapsed_time = progress_info.get('elapsed_time', '')
    remaining_time = progress_info.get('remaining_time', '')
    it_per_sec = progress_info.get('it_per_sec', 0.0)
    grad_step = progress_info.get('grad_step', 0)
    loss = progress_info.get('loss', 0.0)
    total_epochs = progress_info.get('total_epochs', 0)
    test_results_html = progress_info.get('test_results_html', '')
    test_metrics = progress_info.get('test_metrics', {})
    is_completed = progress_info.get('is_completed', False)

    # The test-results panel is always visible; show a placeholder until results exist.
    if not test_results_html and stage == 'Testing':
        test_results_html = """
<div style="text-align: center; padding: 20px; color: #666;">
  <p>Testing in progress, please wait for results...</p>
</div>
"""
    elif not test_results_html:
        test_results_html = """
<div style="text-align: center; padding: 20px; color: #666;">
  <p>Test results will be displayed after testing phase completes</p>
</div>
"""
    test_html_update = gr.update(value=test_results_html, visible=True)

    # Offer a CSV download once test metrics are available.
    if test_metrics:
        download_btn_update = gr.update(value=_write_metrics_csv(test_metrics), visible=True)
    else:
        download_btn_update = gr.update(visible=False)

    progress_percentage = (current / total) * 100 if total > 0 else 0

    if is_completed:
        status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
  <div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
    <div>
      <span style="font-weight: 600; font-size: 16px;">Training Status: </span>
      <span style="color: #4caf50; font-weight: 500; font-size: 16px;">Training complete!</span>
    </div>
    <div>
      <span style="font-weight: 600; color: #333;">100%</span>
    </div>
  </div>
  <div style="margin-bottom: 15px; background-color: #e9ecef; height: 10px; border-radius: 5px; overflow: hidden;">
    <div style="background-color: #4caf50; width: 100%; height: 100%; border-radius: 5px;"></div>
  </div>
</div>
"""
    else:
        # Training / validation / testing: progress bar plus optional detail chips.
        epoch_total = total_epochs if total_epochs > 0 else 100
        time_chip = ''
        if elapsed_time and remaining_time:
            # Escape the tqdm-style "elapsed<remaining, x it/s>" text: a raw "<?" would
            # otherwise start an HTML bogus comment and hide the rest of the chip.
            time_text = html.escape(f"{elapsed_time}<{remaining_time}, {it_per_sec:.2f}it/s>")
            time_chip = ('<div style="background-color: #fff8e1; padding: 5px 10px; border-radius: 4px;">'
                         f'<span style="font-weight: 500;">Time:</span> {time_text}</div>')
        loss_chip = ''
        if stage == 'Training' and loss > 0:
            loss_chip = ('<div style="background-color: #e3f2fd; padding: 5px 10px; border-radius: 4px;">'
                         f'<span style="font-weight: 500;">Loss:</span> {loss:.4f}</div>')
        grad_chip = ''
        if stage == 'Training' and grad_step > 0:
            grad_chip = ('<div style="background-color: #f3e5f5; padding: 5px 10px; border-radius: 4px;">'
                         f'<span style="font-weight: 500;">Grad steps:</span> {grad_step}</div>')
        safe_stage = html.escape(str(stage))
        status_html = f"""
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
  <div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
    <div>
      <span style="font-weight: 600; font-size: 16px;">Training Status: </span>
      <span style="color: #1976d2; font-weight: 500; font-size: 16px;">{safe_stage} (Epoch {epoch}/{epoch_total})</span>
    </div>
    <div>
      <span style="font-weight: 600; color: #333;">{progress_percentage:.1f}%</span>
    </div>
  </div>
  <div style="margin-bottom: 15px; background-color: #e9ecef; height: 10px; border-radius: 5px; overflow: hidden;">
    <div style="background-color: #4285f4; width: {progress_percentage}%; height: 100%; border-radius: 5px; transition: width 0.3s ease;"></div>
  </div>
  <div style="display: flex; flex-wrap: wrap; gap: 10px; font-size: 14px; color: #555;">
    <div style="background-color: #e8f5e9; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Progress:</span> {current}/{total}</div>
    {time_chip}
    {loss_chip}
    {grad_chip}
  </div>
</div>
"""

    # Best-model summary; a non-positive metric value is treated as "not found yet".
    if best_epoch >= 0 and best_metric_value > 0:
        best_info = f"Best model: Epoch {best_epoch} ({best_metric_name}: {best_metric_value:.4f})"
    else:
        best_info = "No best model found yet"

    loss_fig = monitor.get_loss_plot()
    metrics_fig = monitor.get_metrics_plot()
    return status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update
def handle_train(*args) -> Generator: | |
nonlocal is_training, current_process, stop_thread, process_aborted, monitor | |
# If already training, return | |
if is_training: | |
yield None, None, None, None, None, None, None | |
return | |
# Force explicit state reset first thing | |
monitor._reset_tracking() | |
monitor._reset_stats() | |
# Explicitly ensure stats are reset | |
if hasattr(monitor, "stats"): | |
monitor.stats = {} | |
# Force override any cached state in monitor | |
monitor.current_progress = { | |
"current": 0, | |
"total": 0, | |
"epoch": 0, | |
"stage": "Waiting", | |
"progress_detail": "", | |
"best_epoch": -1, | |
"best_metric_name": "", | |
"best_metric_value": 0.0, | |
"elapsed_time": "", | |
"remaining_time": "", | |
"it_per_sec": 0.0, | |
"grad_step": 0, | |
"loss": 0.0, | |
"test_results_html": "", | |
"test_metrics": {}, | |
"is_completed": False, | |
"lines": [] | |
} | |
# Reset all monitoring data structures | |
monitor.train_losses = [] | |
monitor.val_losses = [] | |
monitor.metrics = {} | |
monitor.epochs = [] | |
if hasattr(monitor, "stats"): | |
monitor.stats = {} | |
# Reset flags for new training session | |
process_aborted = False | |
stop_thread = False | |
# Initialize table state | |
initial_stats = [ | |
["Training Model", "-", "-", "-"], | |
["Pre-trained Model", "-", "-", "-"], | |
["Combined Model", "-", "-", "-"] | |
] | |
# Initial UI state with "Initializing" message | |
initial_status_html = """ | |
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);"> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;"> | |
<div> | |
<span style="font-weight: 600; font-size: 16px;">Training Status: </span> | |
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Initializing training environment...</span> | |
</div> | |
</div> | |
<div style="font-size: 14px; color: #555; margin-top: 10px;"> | |
<p>• Parsing configuration parameters</p> | |
<p>• Preparing training environment</p> | |
<p>• This may take a few moments...</p> | |
</div> | |
</div> | |
""" | |
# First yield to update UI with "initializing" state | |
yield initial_stats, initial_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
try: | |
# Parse training arguments | |
training_args = TrainingArgs(args, plm_models, dataset_configs) | |
if training_args.training_method != "ses-adapter": | |
training_args.structure_seq = None | |
args_dict = training_args.to_dict() | |
# Save total epochs to monitor for use in progress_info | |
total_epochs = args_dict.get('num_epochs', 100) | |
monitor.current_progress['total_epochs'] = total_epochs | |
# Update status to "Preparing dataset" | |
preparing_status_html = """ | |
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);"> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;"> | |
<div> | |
<span style="font-weight: 600; font-size: 16px;">Training Status: </span> | |
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Preparing dataset and model...</span> | |
</div> | |
</div> | |
<div style="font-size: 14px; color: #555; margin-top: 10px;"> | |
<p>• Loading dataset</p> | |
<p>• Initializing model architecture</p> | |
<p>• Setting up training environment</p> | |
</div> | |
</div> | |
""" | |
yield initial_stats, preparing_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
# Save arguments to file | |
save_arguments(args_dict, args_dict.get('output_dir', 'ckpt')) | |
# Start training | |
is_training = True | |
process_aborted = False # Reset abort flag | |
monitor.start_training(args_dict) | |
current_process = monitor.process # Store the process reference | |
starting_status_html = """ | |
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);"> | |
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;"> | |
<div> | |
<span style="font-weight: 600; font-size: 16px;">Training Status: </span> | |
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Starting training process...</span> | |
</div> | |
</div> | |
<div style="font-size: 14px; color: #555; margin-top: 10px;"> | |
<p>• Training process launched</p> | |
<p>• Waiting for first statistics to appear</p> | |
<p>• This may take a moment for large models</p> | |
</div> | |
</div> | |
""" | |
yield initial_stats, starting_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
# Add delay to ensure enough time for parsing initial statistics | |
for i in range(3): | |
time.sleep(1) | |
# Check if statistics are already available | |
stats = monitor.get_stats() | |
if stats and len(stats) > 0: | |
break | |
update_count = 0 | |
while True: | |
# Check if the process still exists and hasn't been aborted | |
if process_aborted or not monitor.is_training or current_process is None or (current_process and current_process.poll() is not None): | |
break | |
try: | |
update_count += 1 | |
time.sleep(0.5) | |
# Check process status | |
monitor.check_process_status() | |
# Get latest progress info | |
progress_info = monitor.get_progress() | |
# If process has ended, check if it's normal end or error | |
if not monitor.is_training: | |
# Check both monitor.process and current_process since they might be different objects | |
if (monitor.process and monitor.process.returncode != 0) or (current_process and current_process.poll() is not None and current_process.returncode != 0): | |
# Get the return code from whichever process object is available | |
return_code = monitor.process.returncode if monitor.process else current_process.returncode | |
# Get complete output log | |
error_output = "\n".join(progress_info.get("lines", [])) | |
if not error_output: | |
error_output = "No output captured from the training process" | |
# Ensure we set the is_completed flag to False for errors | |
progress_info['is_completed'] = False | |
monitor.current_progress['is_completed'] = False | |
# Also set the stage to Error | |
progress_info['stage'] = 'Error' | |
monitor.current_progress['stage'] = 'Error' | |
error_status_html = f""" | |
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;"> | |
<p style="margin: 0; color: #c62828; font-weight: bold;">Training failed with error code {return_code}:</p> | |
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre> | |
</div> | |
""" | |
yield ( | |
initial_stats, | |
error_status_html, | |
"Training failed", | |
gr.update(value="", visible=False), | |
None, | |
None, | |
gr.update(visible=False) | |
) | |
return | |
else: | |
# Only set is_completed to True if there was a successful exit code | |
progress_info['is_completed'] = True | |
monitor.current_progress['is_completed'] = True | |
# Update UI | |
stats = monitor.get_stats() | |
if stats: | |
model_stats = update_model_stats(stats) | |
else: | |
model_stats = initial_stats | |
status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info) | |
yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update | |
except Exception as e: | |
# Get complete output log | |
error_output = "\n".join(progress_info.get("lines", [])) | |
if not error_output: | |
error_output = "No output captured from the training process" | |
error_status_html = f""" | |
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;"> | |
<p style="margin: 0; color: #c62828; font-weight: bold;">Error during training:</p> | |
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p> | |
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre> | |
</div> | |
""" | |
print(f"Error updating UI: {str(e)}") | |
traceback.print_exc() | |
yield initial_stats, error_status_html, "Training error", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
return | |
# Check if aborted | |
if process_aborted: | |
is_training = False | |
current_process = None | |
aborted_status_html = """ | |
<div style="padding: 10px; background-color: #e8f5e9; border-radius: 5px;"> | |
<p style="margin: 0; color: #2e7d32; font-weight: bold;">Training was manually terminated.</p> | |
</div> | |
""" | |
yield initial_stats, aborted_status_html, "Training aborted", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
return | |
# Final update after training ends (only for normal completion) | |
if monitor.process and monitor.process.returncode == 0: | |
try: | |
progress_info = monitor.get_progress() | |
progress_info['is_completed'] = True | |
monitor.current_progress['is_completed'] = True | |
stats = monitor.get_stats() | |
if stats: | |
model_stats = update_model_stats(stats) | |
else: | |
model_stats = initial_stats | |
status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info) | |
yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update | |
except Exception as e: | |
error_output = "\n".join(progress_info.get("lines", [])) | |
if not error_output: | |
error_output = "No output captured from the training process" | |
error_status_html = f""" | |
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;"> | |
<p style="margin: 0; color: #c62828; font-weight: bold;">Error in final update:</p> | |
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p> | |
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre> | |
</div> | |
""" | |
yield initial_stats, error_status_html, "Error in final update", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
except Exception as e: | |
# Initialization error, may not have output log | |
error_status_html = f""" | |
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;"> | |
<p style="margin: 0; color: #c62828; font-weight: bold;">Training initialization failed:</p> | |
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p> | |
</div> | |
""" | |
yield initial_stats, error_status_html, "Training failed", gr.update(value="", visible=False), None, None, gr.update(visible=False) | |
finally: | |
is_training = False | |
current_process = None | |
def handle_abort(): | |
"""Handle abortion of the training process""" | |
nonlocal is_training, current_process, stop_thread, process_aborted | |
if not is_training or current_process is None: | |
return (gr.HTML(""" | |
<div style="padding: 10px; background-color: #f5f5f5; border-radius: 5px;"> | |
<p style="margin: 0;">No training process is currently running.</p> | |
</div> | |
"""), | |
[["Training Model", "-", "-", "-"], | |
["Pre-trained Model", "-", "-", "-"], | |
["Combined Model", "-", "-", "-"]], | |
"Best Model: None", | |
gr.update(value="", visible=False), | |
None, | |
None, | |
gr.update(visible=False)) | |
try: | |
# Set the abort flag before terminating the process | |
process_aborted = True | |
stop_thread = True | |
# Use process.terminate() instead of os.killpg for safer termination | |
# This avoids accidentally killing the parent WebUI process | |
current_process.terminate() | |
# Wait for process to terminate (with timeout) | |
try: | |
current_process.wait(timeout=5) | |
except subprocess.TimeoutExpired: | |
# Only if terminate didn't work, use a stronger method | |
# But do NOT use killpg which might kill the parent WebUI | |
current_process.kill() | |
# Create a completely fresh state - not just resetting | |
monitor.is_training = False | |
# Explicitly create a new dictionary instead of modifying the existing one | |
monitor.current_progress = { | |
"current": 0, | |
"total": 0, | |
"epoch": 0, | |
"stage": "Waiting", | |
"progress_detail": "", | |
"best_epoch": -1, | |
"best_metric_name": "", | |
"best_metric_value": 0.0, | |
"elapsed_time": "", | |
"remaining_time": "", | |
"it_per_sec": 0.0, | |
"grad_step": 0, | |
"loss": 0.0, | |
"test_results_html": "", | |
"test_metrics": {}, | |
"is_completed": False, | |
"lines": [] | |
} | |
# Explicitly clear stats by creating a new dictionary | |
monitor.stats = {} | |
if hasattr(monitor, "process") and monitor.process: | |
monitor.process = None | |
# Reset state variables | |
is_training = False | |
current_process = None | |
# Explicitly reset tracking to clear all state | |
monitor._reset_tracking() | |
monitor._reset_stats() | |
# Reset all plots and statistics with new empty lists | |
monitor.train_losses = [] | |
monitor.val_losses = [] | |
monitor.metrics = {} | |
monitor.epochs = [] | |
# Create entirely fresh UI components | |
empty_model_stats = [["Training Model", "-", "-", "-"], | |
["Pre-trained Model", "-", "-", "-"], | |
["Combined Model", "-", "-", "-"]] | |
success_html = """ | |
<div style="padding: 10px; background-color: #e8f5e9; border-radius: 5px;"> | |
<p style="margin: 0; color: #2e7d32; font-weight: bold;">Training successfully terminated!</p> | |
<p style="margin: 5px 0 0; color: #388e3c;">All training state has been reset. You can start a new training session.</p> | |
</div> | |
""" | |
# Return updates for all relevant components | |
return (gr.HTML(success_html), | |
empty_model_stats, | |
"Best Model: None", | |
gr.update(value="", visible=False), | |
None, | |
None, | |
gr.update(visible=False)) | |
except Exception as e: | |
# Still need to reset states even if there's an error | |
is_training = False | |
current_process = None | |
process_aborted = False | |
# Reset monitor state regardless of error | |
monitor.is_training = False | |
monitor.stats = {} | |
if hasattr(monitor, "process") and monitor.process: | |
monitor.process = None | |
monitor._reset_tracking() | |
monitor._reset_stats() | |
# Fresh empty components | |
empty_model_stats = [["Training Model", "-", "-", "-"], | |
["Pre-trained Model", "-", "-", "-"], | |
["Combined Model", "-", "-", "-"]] | |
error_html = f""" | |
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px;"> | |
<p style="margin: 0; color: #c62828; font-weight: bold;">Failed to terminate training: {str(e)}</p> | |
<p style="margin: 5px 0 0; color: #c62828;">Training state has been reset.</p> | |
</div> | |
""" | |
# Return updates for all relevant components including empty model stats | |
return (gr.HTML(error_html), | |
empty_model_stats, | |
"Best Model: None", | |
gr.update(value="", visible=False), | |
None, | |
None, | |
gr.update(visible=False)) | |
def update_wandb_visibility(checkbox):
    """Show the W&B project and entity inputs only while logging is enabled.

    Returns a mapping of component -> visibility update, with one separate
    update object for each of the two components.
    """
    return {field: gr.update(visible=checkbox) for field in (wandb_project, wandb_entity)}
# define all input components | |
input_components = [ | |
plm_model, #0 | |
is_custom_dataset, #1 | |
dataset_config, #2 | |
dataset_custom, #3 | |
problem_type, #4 | |
num_labels, #5 | |
metrics, #6 | |
training_method, #7 | |
pooling_method, #8 | |
batch_mode, #9 | |
batch_size, #10 | |
batch_token, #11 | |
learning_rate, #12 | |
num_epochs, #13 | |
max_seq_len, #14 | |
gradient_accumulation_steps, #15 | |
warmup_steps, #16 | |
scheduler_type, #17 | |
output_model_name, #18 | |
output_dir, #19 | |
wandb_logging, #20 | |
wandb_project, #21 | |
wandb_entity, #22 | |
patience, #23 | |
num_workers, #24 | |
max_grad_norm, #25 | |
structure_seq, #26 | |
lora_r, #27 | |
lora_alpha, #28 | |
lora_dropout, #29 | |
lora_target_modules, #30 | |
] | |
# bind preview and train buttons | |
def handle_preview(*args):
    """Toggle the command preview showing the CLI the Train button would run.

    ``args`` are the ``input_components`` values, parsed exactly as in handle_train.

    NOTE(review): ``command_preview.visible`` is read from the component object;
    in Gradio that attribute may reflect construction-time state rather than what
    the browser currently shows. Confirm; a ``gr.State`` flag would track it
    reliably.
    """
    if command_preview.visible:
        return gr.update(visible=False)
    training_args = TrainingArgs(args, plm_models, dataset_configs)
    # Mirror handle_train so the preview matches the command that actually runs:
    # structure sequences are only passed for ses-adapter.
    if training_args.training_method != "ses-adapter":
        training_args.structure_seq = None
    preview_text = preview_command(training_args.to_dict())
    return gr.update(value=preview_text, visible=True)
def reset_train_ui():
    """Clear all monitor state and return blank values for every training output.

    Bound to the Train button ahead of handle_train. The 7-tuple is ordered like
    that button's outputs: (model_stats, progress_status, best_model_info,
    test_results_html, loss_plot, metrics_plot, download_csv_btn).
    """
    # The monitor's own reset hooks run first; the fresh containers below then
    # overwrite whatever they leave behind.
    monitor._reset_tracking()
    monitor._reset_stats()
    if hasattr(monitor, "stats"):
        monitor.stats = {}

    monitor.current_progress = dict(
        current=0,
        total=0,
        epoch=0,
        stage="Waiting",
        progress_detail="",
        best_epoch=-1,
        best_metric_name="",
        best_metric_value=0.0,
        elapsed_time="",
        remaining_time="",
        it_per_sec=0.0,
        grad_step=0,
        loss=0.0,
        test_results_html="",
        test_metrics={},
        is_completed=False,
        lines=[],
    )

    # History used by the loss/metric plots.
    monitor.train_losses, monitor.val_losses, monitor.metrics, monitor.epochs = [], [], {}, []

    blank_stats = [[row_label, "-", "-", "-"]
                   for row_label in ("Training Model", "Pre-trained Model", "Combined Model")]
    waiting_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Preparing to start training...</span>
</div>
</div>
</div>
"""
    hidden_results = gr.update(value="", visible=False)
    hidden_download = gr.update(visible=False)
    # Plot outputs take None (not a string) to clear them.
    return blank_stats, waiting_html, "Best Model: None", hidden_results, None, None, hidden_download
preview_button.click( | |
fn=handle_preview, | |
inputs=input_components, | |
outputs=[command_preview] | |
) | |
train_button.click( | |
fn=reset_train_ui, | |
outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] | |
).then( | |
fn=handle_train, | |
inputs=input_components, | |
outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] | |
) | |
# bind abort button | |
abort_button.click( | |
fn=handle_abort, | |
outputs=[progress_status, model_stats, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] | |
) | |
wandb_logging.change( | |
fn=update_wandb_visibility, | |
inputs=[wandb_logging], | |
outputs=[wandb_project, wandb_entity] | |
) | |
def update_dataset_preview(dataset_type=None, dataset_name=None, custom_dataset=None):
    """Load the selected dataset and render its split sizes plus up to 3 sample rows.

    Args:
        dataset_type: "Use Custom Dataset" or "Use Pre-defined Dataset".
        dataset_name: key into ``dataset_configs``; the JSON config's "dataset"
            entry is what gets passed to ``load_dataset``.
        custom_dataset: identifier passed straight to ``load_dataset``.

    Returns:
        Updates for (dataset_stats_md, preview_table, preview_accordion). Load
        errors are shown in the stats panel and table instead of being raised.
    """
    import html  # local import: dataset names / error text are escaped before HTML embedding

    th_style = ("padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; "
                "font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;")
    td_style = "padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;"
    placeholder_headers = ["Name", "Sequence", "Label"]

    def _split_size(dataset, split):
        """Number of rows in *split*, or 0 when the dataset has no such split."""
        return len(dataset[split]) if split in dataset else 0

    def _stats_html(label, dataset):
        """One-row table: dataset name and train/validation/test sizes."""
        headers = ("Dataset", "Train Samples", "Val Samples", "Test Samples")
        values = [html.escape(str(label))] + [
            str(_split_size(dataset, split)) for split in ("train", "validation", "test")
        ]
        header_cells = "".join(f'<th style="{th_style}">{h}</th>' for h in headers)
        value_cells = "".join(f'<td style="{td_style}">{v}</td>' for v in values)
        return (
            '<div style="text-align: center; margin: 20px 0;">'
            '<table style="width: 100%; border-collapse: collapse; margin: 0 auto;">'
            f"<tr>{header_cells}</tr><tr>{value_cells}</tr>"
            "</table></div>"
        )

    def _sample_table(dataset):
        """(rows, headers) for up to 3 samples from "train", else the first split."""
        split = "train" if "train" in dataset else next(iter(dataset), None)
        if split is None:
            return [], []
        samples = dataset[split].select(range(min(3, len(dataset[split]))))
        if len(samples) == 0:
            return [], []
        headers = list(samples[0].keys())
        # Keep full sequences: every field is stringified, not truncated.
        rows = [[str(sample[field]) for field in headers] for sample in samples]
        return rows, headers

    use_custom = dataset_type == "Use Custom Dataset" and custom_dataset
    use_predefined = dataset_type == "Use Pre-defined Dataset" and dataset_name
    if not (use_custom or use_predefined):
        return (gr.update(value=""),
                gr.update(value=[["No dataset selected", "-", "-"]], headers=placeholder_headers),
                gr.update(open=True))

    try:
        if use_custom:
            dataset_id = custom_dataset
        else:
            with open(dataset_configs[dataset_name], 'r') as f:
                dataset_id = json.load(f)["dataset"]
        dataset = load_dataset(dataset_id)
        stats_html = _stats_html(dataset_id, dataset)
        rows, headers = _sample_table(dataset)
        if not rows:
            return (gr.update(value=stats_html),
                    gr.update(value=[["No data available", "-", "-"]], headers=placeholder_headers),
                    gr.update(open=True))
        return gr.update(value=stats_html), gr.update(value=rows, headers=headers), gr.update(open=True)
    except Exception as e:
        error_html = f"""
<div>
<h2>Error loading dataset</h2>
<p style="color: #c62828;">{html.escape(str(e))}</p>
</div>
"""
        return (gr.update(value=error_html),
                gr.update(value=[["Error", str(e), "-"]], headers=placeholder_headers),
                gr.update(open=True))
# Preview button click event | |
dataset_preview_button.click( | |
fn=update_dataset_preview, | |
inputs=[is_custom_dataset, dataset_config, dataset_custom], | |
outputs=[dataset_stats_md, preview_table, preview_accordion] | |
) | |
# 添加自定义数据集设置的函数 | |
def update_dataset_settings(choice, dataset_name=None):
    """Switch the dataset inputs between pre-defined and custom mode.

    Pre-defined: show the config dropdown and hide the custom path. When
    *dataset_name* is a known config, fill problem type, label count and metrics
    from its JSON file and lock them. Any other choice (custom): show the path
    input and reset those fields to editable defaults.

    Returns a dict mapping component -> gr.update.
    """
    if choice != "Use Pre-defined Dataset":
        # Custom dataset: editable settings reset to defaults; metrics is a list
        # because the target is a multi-select component.
        return {
            dataset_config: gr.update(visible=False),
            dataset_custom: gr.update(visible=True),
            custom_dataset_settings: gr.update(visible=True),
            problem_type: gr.update(value="single_label_classification", interactive=True),
            num_labels: gr.update(value=2, interactive=True),
            metrics: gr.update(value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"], interactive=True),
        }

    updates = {
        dataset_config: gr.update(visible=True),
        dataset_custom: gr.update(visible=False),
        custom_dataset_settings: gr.update(visible=True),
    }
    if not dataset_name or dataset_name not in dataset_configs:
        return updates

    # A specific dataset is selected: load its configuration.
    with open(dataset_configs[dataset_name], 'r') as f:
        config = json.load(f)
    # The config may store metrics as a comma-separated string; the multi-select needs a list.
    selected_metrics = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
    if isinstance(selected_metrics, str):
        selected_metrics = selected_metrics.split(",")
    updates[problem_type] = gr.update(value=config.get("problem_type", "single_label_classification"), interactive=False)
    updates[num_labels] = gr.update(value=config.get("num_labels", 2), interactive=False)
    updates[metrics] = gr.update(value=selected_metrics, interactive=False)
    return updates
# 绑定数据集设置更新事件 | |
is_custom_dataset.change( | |
fn=update_dataset_settings, | |
inputs=[is_custom_dataset, dataset_config], | |
outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics] | |
) | |
dataset_config.change( | |
fn=lambda x: update_dataset_settings("Use Pre-defined Dataset", x), | |
inputs=[dataset_config], | |
outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics] | |
) | |
# Return components that need to be accessed from outside | |
return { | |
"output_text": progress_status, | |
"loss_plot": loss_plot, | |
"metrics_plot": metrics_plot, | |
"train_button": train_button, | |
"monitor": monitor, | |
"test_results_html": test_results_html, # 添加测试结果HTML组件 | |
"components": { | |
"plm_model": plm_model, | |
"dataset_config": dataset_config, | |
"training_method": training_method, | |
"pooling_method": pooling_method, | |
"batch_mode": batch_mode, | |
"batch_size": batch_size, | |
"batch_token": batch_token, | |
"learning_rate": learning_rate, | |
"num_epochs": num_epochs, | |
"max_seq_len": max_seq_len, | |
"gradient_accumulation_steps": gradient_accumulation_steps, | |
"warmup_steps": warmup_steps, | |
"scheduler_type": scheduler_type, | |
"output_model_name": output_model_name, | |
"output_dir": output_dir, | |
"wandb_logging": wandb_logging, | |
"wandb_project": wandb_project, | |
"wandb_entity": wandb_entity, | |
"patience": patience, | |
"num_workers": num_workers, | |
"max_grad_norm": max_grad_norm, | |
"structure_seq": structure_seq, | |
"lora_r": lora_r, | |
"lora_alpha": lora_alpha, | |
"lora_dropout": lora_dropout, | |
"lora_target_modules": lora_target_modules, | |
} | |
} |