# VenusFactory: src/web/train_tab.py
import os
import json
import gradio as gr
import time
from datasets import load_dataset
import pandas as pd
from typing import Any, Dict, Union, Optional, Generator, List
from dataclasses import dataclass
from .utils.command import preview_command, save_arguments, build_command_list
from .utils.monitor import TrainingMonitor
import traceback
import base64
import tempfile
import numpy as np
import queue
import subprocess
import sys
import threading
@dataclass
class TrainingArgs:
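    """Parse the flat list of Gradio input values into named training arguments.

    The positional layout of ``args`` must match the ``input_components`` list
    defined in ``create_train_tab`` below.
    """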
def __init__(self, args: list, plm_models: dict, dataset_configs: dict):
# Basic parameters
self.plm_model = plm_models[args[0]]
        # Handle custom vs. pre-defined dataset selection
        self.dataset_selection = args[1]  # "Use Custom Dataset" or "Use Pre-defined Dataset"
if self.dataset_selection == "Use Pre-defined Dataset":
self.dataset_config = dataset_configs[args[2]]
self.dataset_custom = None
            # Load problem type, label count, and metrics from the dataset config
with open(self.dataset_config, 'r') as f:
config = json.load(f)
self.problem_type = config.get("problem_type", "single_label_classification")
self.num_labels = config.get("num_labels", 2)
self.metrics = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
else:
self.dataset_config = None
self.dataset_custom = args[3] # Custom dataset path
self.problem_type = args[4]
self.num_labels = args[5]
self.metrics = args[6]
            # If metrics is a list, convert it to a comma-separated string
if isinstance(self.metrics, list):
self.metrics = ",".join(self.metrics)
# Training method parameters
self.training_method = args[7]
self.pooling_method = args[8]
# Batch processing parameters
self.batch_mode = args[9]
if self.batch_mode == "Batch Size Mode":
self.batch_size = args[10]
else:
self.batch_token = args[11]
# Training parameters
self.learning_rate = args[12]
self.num_epochs = args[13]
self.max_seq_len = args[14]
self.gradient_accumulation_steps = args[15]
self.warmup_steps = args[16]
self.scheduler = args[17]
# Output parameters
self.output_model_name = args[18]
self.output_dir = args[19]
# Wandb parameters
self.wandb_enabled = args[20]
if self.wandb_enabled:
self.wandb_project = args[21]
self.wandb_entity = args[22]
# Other parameters
self.patience = args[23]
self.num_workers = args[24]
self.max_grad_norm = args[25]
self.structure_seq = args[26]
# LoRA parameters
self.lora_r = args[27]
self.lora_alpha = args[28]
self.lora_dropout = args[29]
self.lora_target_modules = [m.strip() for m in args[30].split(",")] if args[30] else []
def to_dict(self) -> Dict[str, Any]:
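        """Return the arguments as a flat dict used to build, preview, and save the training command."""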
args_dict = {
"plm_model": self.plm_model,
"training_method": self.training_method,
"pooling_method": self.pooling_method,
"learning_rate": self.learning_rate,
"num_epochs": self.num_epochs,
"max_seq_len": self.max_seq_len,
"gradient_accumulation_steps": self.gradient_accumulation_steps,
"warmup_steps": self.warmup_steps,
"scheduler": self.scheduler,
"output_model_name": self.output_model_name,
"output_dir": self.output_dir,
"patience": self.patience,
"num_workers": self.num_workers,
"max_grad_norm": self.max_grad_norm,
}
if self.training_method == "ses-adapter" and self.structure_seq:
args_dict["structure_seq"] = ",".join(self.structure_seq)
        # Add dataset-related parameters
if self.dataset_selection == "Use Pre-defined Dataset":
args_dict["dataset_config"] = self.dataset_config
else:
args_dict["dataset"] = self.dataset_custom
args_dict["problem_type"] = self.problem_type
args_dict["num_labels"] = self.num_labels
args_dict["metrics"] = self.metrics
# Add LoRA parameters
if self.training_method in ["plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"]:
args_dict.update({
"lora_r": self.lora_r,
"lora_alpha": self.lora_alpha,
"lora_dropout": self.lora_dropout,
"lora_target_modules": self.lora_target_modules
})
# Add batch processing parameters
if self.batch_mode == "Batch Size Mode":
args_dict["batch_size"] = self.batch_size
else:
args_dict["batch_token"] = self.batch_token
# Add wandb parameters
if self.wandb_enabled:
args_dict["wandb"] = True
if self.wandb_project:
args_dict["wandb_project"] = self.wandb_project
if self.wandb_entity:
args_dict["wandb_entity"] = self.wandb_entity
return args_dict
def create_train_tab(constant: Dict[str, Any]) -> Dict[str, Any]:
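    """Build the Gradio "Training" tab and wire up its event handlers.

    ``constant`` must provide ``plm_models`` (display name -> model path) and
    ``dataset_configs`` (display name -> dataset config JSON path). Returns a
    dict exposing the monitor and the components that other parts of the app use.
    """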
# Create training monitor
monitor = TrainingMonitor()
    # State shared by the training event handlers below
is_training = False
current_process = None
stop_thread = False
process_aborted = False
plm_models = constant["plm_models"]
dataset_configs = constant["dataset_configs"]
with gr.Tab("Training"):
# Model and Dataset Selection
gr.Markdown("### Model and Dataset Configuration")
# Original training interface components
with gr.Group():
with gr.Row():
with gr.Column(scale=4):
with gr.Row():
plm_model = gr.Dropdown(
choices=list(plm_models.keys()),
label="Protein Language Model",
value=list(plm_models.keys())[0],
scale=2
)
                        # Dataset selection mode: custom vs. pre-defined
is_custom_dataset = gr.Radio(
choices=["Use Custom Dataset", "Use Pre-defined Dataset"],
label="Dataset Selection",
value="Use Pre-defined Dataset",
scale=3
)
dataset_config = gr.Dropdown(
choices=list(dataset_configs.keys()),
label="Dataset Configuration",
value=list(dataset_configs.keys())[0],
visible=True,
scale=2
)
dataset_custom = gr.Textbox(
label="Custom Dataset Path",
placeholder="Huggingface Dataset eg: user/dataset",
visible=False,
scale=2
)
                    # Put the preview button in its own column and add styling
with gr.Column(scale=1, min_width=120, elem_classes="preview-button-container"):
dataset_preview_button = gr.Button(
"Preview Dataset",
variant="primary",
size="lg",
elem_classes="preview-button"
)
            # Extra configuration options for custom datasets (on their own row)
with gr.Row(visible=True) as custom_dataset_settings:
problem_type = gr.Dropdown(
choices=["single_label_classification", "multi_label_classification", "regression"],
label="Problem Type",
value="single_label_classification",
scale=23,
interactive=False
)
num_labels = gr.Number(
value=2,
label="Number of Labels",
scale=11,
interactive=False
)
metrics = gr.Dropdown(
choices=["accuracy", "recall", "precision", "f1", "mcc", "auroc", "f1max", "spearman_corr", "mse"],
label="Metrics",
value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"],
scale=101,
multiselect=True,
interactive=False
)
with gr.Row():
structure_seq = gr.Dropdown(
label="Structure Sequence",
choices=["foldseek_seq", "ss8_seq"],
value=["foldseek_seq", "ss8_seq"],
multiselect=True,
visible=False
)
            # LoRA-family parameters (plm-lora, plm-qlora, plm_adalora, plm_dora, plm_ia3)
with gr.Row(visible=False) as lora_params_row:
# gr.Markdown("#### LoRA Parameters")
with gr.Column():
lora_r = gr.Number(
value=8,
label="LoRA Rank",
precision=0,
minimum=1,
maximum=128,
)
with gr.Column():
lora_alpha = gr.Number(
value=32,
label="LoRA Alpha",
precision=0,
minimum=1,
maximum=128
)
with gr.Column():
lora_dropout = gr.Number(
value=0.1,
label="LoRA Dropout",
minimum=0.0,
maximum=1.0
)
with gr.Column():
lora_target_modules = gr.Textbox(
value="query,key,value",
label="LoRA Target Modules",
placeholder="Comma-separated list of target modules",
# info="LoRA will be applied to these modules"
)
        # Put dataset statistics and the preview table inside a collapsible accordion
with gr.Row():
with gr.Accordion("Dataset Preview", open=False) as preview_accordion:
                # Dataset statistics area
with gr.Row():
dataset_stats_md = gr.HTML("", elem_classes=["dataset-stats"])
                # Preview table area
with gr.Row():
preview_table = gr.Dataframe(
headers=["Name", "Sequence", "Label"],
value=[["No dataset selected", "-", "-"]],
wrap=True,
interactive=False,
row_count=3,
elem_classes=["preview-table"]
)
# Add CSS styles
gr.HTML("""
<style>
        /* Dataset statistics styles */
.dataset-stats {
margin: 0 0 15px 0;
padding: 0;
}
.dataset-stats table {
width: 100%;
border-collapse: collapse;
font-size: 0.9em;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
border-radius: 8px;
overflow: hidden;
table-layout: fixed;
}
.dataset-stats th {
background-color: #e0e0e0;
font-weight: bold;
padding: 6px 10px;
text-align: center;
border: 1px solid #ddd;
font-size: 0.95em;
white-space: nowrap;
overflow: hidden;
min-width: 120px;
}
.dataset-stats td {
padding: 6px 10px;
text-align: center;
border: 1px solid #ddd;
}
.dataset-stats h2 {
font-size: 1.1em;
margin: 0 0 10px 0;
text-align: center;
}
        /* Preview table styles */
.preview-table table {
background-color: white !important;
font-size: 0.9em !important;
width: 100%;
table-layout: fixed !important;
}
.preview-table .gr-block.gr-box {
background-color: transparent !important;
}
.preview-table .gr-input-label {
background-color: transparent !important;
}
        /* Table appearance enhancements */
.preview-table table {
margin-top: 0;
border-radius: 8px;
overflow: hidden;
box-shadow: 0 2px 4px rgba(0,0,0,0.05);
}
        /* Table header styles */
.preview-table th {
background-color: #e0e0e0 !important;
font-weight: bold !important;
padding: 6px !important;
border-bottom: 1px solid #ccc !important;
font-size: 0.95em !important;
text-align: center !important;
white-space: nowrap !important;
min-width: 120px !important;
}
        /* Table cell styles */
.preview-table td {
padding: 4px 6px !important;
max-width: 300px !important;
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
text-align: left !important;
}
        /* Row hover effect */
.preview-table tr:hover {
background-color: #f0f0f0 !important;
}
        /* Accordion styles */
.gr-accordion {
border: 1px solid #e0e0e0;
border-radius: 8px;
overflow: hidden;
margin-bottom: 15px;
}
        /* Accordion header styles */
.gr-accordion .label-wrap {
background-color: #f5f5f5;
padding: 8px 15px;
font-weight: bold;
}
.preview-button {
height: 86px !important;
}
/* Center Model Statistics Table */
.center-table-content td, .center-table-content th {
text-align: center !important;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
padding: 10px !important;
}
.center-table-content table {
width: 100% !important;
border-collapse: collapse !important;
margin-bottom: 20px !important;
box-shadow: 0 2px 8px rgba(0,0,0,0.1) !important;
border-radius: 8px !important;
overflow: hidden !important;
}
.center-table-content th {
background-color: #f0f4f8 !important;
color: #2c3e50 !important;
font-weight: 600 !important;
border-bottom: 2px solid #ddd !important;
}
.center-table-content tr:nth-child(even) {
background-color: #f9f9f9 !important;
}
.center-table-content tr:hover {
background-color: #f0f7ff !important;
}
/* Improve readability of progress bars */
.progress-container {
margin-bottom: 20px !important;
}
.progress-bar {
transition: width 0.5s ease-in-out !important;
}
.status-message {
margin-bottom: 8px !important;
font-weight: 500 !important;
}
</style>
""", visible=True)
# Batch Processing Configuration
gr.Markdown("### Batch Processing Configuration")
with gr.Group():
with gr.Row(equal_height=True):
with gr.Column(scale=1):
batch_mode = gr.Radio(
choices=["Batch Size Mode", "Batch Token Mode"],
label="Batch Processing Mode",
value="Batch Size Mode"
)
with gr.Column(scale=2):
batch_size = gr.Slider(
minimum=1,
maximum=128,
value=16,
step=1,
label="Batch Size",
visible=True
)
batch_token = gr.Slider(
minimum=1000,
maximum=50000,
value=10000,
step=1000,
label="Tokens per Batch",
visible=False
)
def update_batch_inputs(mode):
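            # Show only the batch control that matches the selected batch processing mode.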
return {
batch_size: gr.update(visible=mode == "Batch Size Mode"),
batch_token: gr.update(visible=mode == "Batch Token Mode")
}
# Update visibility when mode changes
batch_mode.change(
fn=update_batch_inputs,
inputs=[batch_mode],
outputs=[batch_size, batch_token]
)
# Training Parameters
gr.Markdown("### Training Parameters")
with gr.Group():
# First row: Basic training parameters
with gr.Row(equal_height=True):
with gr.Column(scale=1, min_width=150):
training_method = gr.Dropdown(
choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"],
label="Training Method",
value="freeze"
)
with gr.Column(scale=1, min_width=150):
learning_rate = gr.Slider(
minimum=1e-8, maximum=1e-2, value=5e-4, step=1e-6,
label="Learning Rate"
)
with gr.Column(scale=1, min_width=150):
num_epochs = gr.Slider(
minimum=1, maximum=200, value=20, step=1,
label="Number of Epochs"
)
with gr.Column(scale=1, min_width=150):
patience = gr.Slider(
minimum=1, maximum=50, value=10, step=1,
label="Early Stopping Patience"
)
with gr.Column(scale=1, min_width=150):
max_seq_len = gr.Slider(
minimum=-1, maximum=2048, value=None, step=32,
label="Max Sequence Length (-1 for unlimited)"
)
def update_training_method(method):
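            # Toggle the structure-sequence dropdown and LoRA parameter row based on the training method.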
return {
structure_seq: gr.update(visible=method == "ses-adapter"),
lora_params_row: gr.update(visible=method in ["plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"])
}
# Add training_method change event
training_method.change(
fn=update_training_method,
inputs=[training_method],
outputs=[structure_seq, lora_params_row]
)
# Second row: Advanced training parameters
with gr.Row(equal_height=True):
with gr.Column(scale=1, min_width=150):
pooling_method = gr.Dropdown(
choices=["mean", "attention1d", "light_attention"],
label="Pooling Method",
value="mean"
)
with gr.Column(scale=1, min_width=150):
scheduler_type = gr.Dropdown(
choices=["linear", "cosine", "step", None],
label="Scheduler Type",
value=None
)
with gr.Column(scale=1, min_width=150):
warmup_steps = gr.Slider(
minimum=0, maximum=1000, value=0, step=10,
label="Warmup Steps"
)
with gr.Column(scale=1, min_width=150):
gradient_accumulation_steps = gr.Slider(
minimum=1, maximum=32, value=1, step=1,
label="Gradient Accumulation Steps"
)
with gr.Column(scale=1, min_width=150):
max_grad_norm = gr.Slider(
                        minimum=-1, maximum=10.0, value=-1, step=0.1,
label="Max Gradient Norm (-1 for no clipping)"
)
with gr.Column(scale=1, min_width=150):
num_workers = gr.Slider(
minimum=0, maximum=16, value=4, step=1,
label="Number of Workers"
)
# Output and Logging Settings
gr.Markdown("### Output and Logging Settings")
with gr.Row():
with gr.Column():
output_dir = gr.Textbox(
label="Save Directory",
value="demo",
placeholder="Path to save training results"
)
output_model_name = gr.Textbox(
label="Output Model Name",
value="demo.pt",
placeholder="Name of the output model file"
)
with gr.Column():
wandb_logging = gr.Checkbox(
label="Enable W&B Logging",
value=False
)
wandb_project = gr.Textbox(
label="W&B Project Name",
value=None,
visible=False
)
wandb_entity = gr.Textbox(
label="W&B Entity",
value=None,
visible=False
)
# Training Control and Output
gr.Markdown("### Training Control")
with gr.Row():
preview_button = gr.Button("Preview Command")
abort_button = gr.Button("Abort", variant="stop")
train_button = gr.Button("Start", variant="primary")
with gr.Row():
command_preview = gr.Code(
label="Command Preview",
language="shell",
interactive=False,
visible=False
)
# Model Statistics Section
gr.Markdown("### Model Statistics")
with gr.Row():
model_stats = gr.Dataframe(
headers=["Model Type", "Total Parameters", "Trainable Parameters", "Percentage"],
value=[
["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]
],
interactive=False,
elem_classes=["center-table-content"]
)
def update_model_stats(stats: Dict[str, str]) -> List[List[str]]:
"""Update model statistics in table format."""
if not stats:
return [
["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]
]
adapter_total = stats.get('adapter_total', '-')
adapter_trainable = stats.get('adapter_trainable', '-')
pretrain_total = stats.get('pretrain_total', '-')
pretrain_trainable = stats.get('pretrain_trainable', '-')
combined_total = stats.get('combined_total', '-')
combined_trainable = stats.get('combined_trainable', '-')
trainable_percentage = stats.get('trainable_percentage', '-')
return [
["Training Model", str(adapter_total), str(adapter_trainable), "-"],
["Pre-trained Model", str(pretrain_total), str(pretrain_trainable), "-"],
["Combined Model", str(combined_total), str(combined_trainable), str(trainable_percentage)]
]
# Training Progress
gr.Markdown("### Training Progress")
with gr.Row():
progress_status = gr.HTML(
value="""
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Click Start to train your model</span>
</div>
</div>
</div>
""",
label="Status"
)
with gr.Row():
best_model_info = gr.Textbox(
value="Best Model: None",
label="Best Performance",
interactive=False
)
# Add test results HTML display
with gr.Row():
test_results_html = gr.HTML(
value="",
label="Test Results",
visible=True
)
with gr.Row():
with gr.Column(scale=4):
pass
            with gr.Column(scale=1):  # limit the column's maximum width
download_csv_btn = gr.DownloadButton(
"Download CSV",
visible=False,
size="lg"
)
            # Add an empty column to take up the remaining space
with gr.Column(scale=4):
pass
# Training plot in a separate row for full width
with gr.Row():
with gr.Column():
loss_plot = gr.Plot(
label="Training and Validation Loss",
elem_id="loss_plot"
)
with gr.Column():
metrics_plot = gr.Plot(
label="Validation Metrics",
elem_id="metrics_plot"
)
def update_progress(progress_info):
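            # Translate the monitor's progress dict into updates for the status HTML,
            # best-model text, test-results HTML, loss/metrics plots, and CSV download button.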
# If progress_info is empty or None, use completely fresh empty state
if not progress_info or not any(progress_info.values()):
fresh_status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Click Start to train your model</span>
</div>
</div>
</div>
"""
return (
fresh_status_html,
"Best Model: None",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False)
)
# Reset values if stage is "Waiting" or "Error"
if progress_info.get('stage', '') == 'Waiting' or progress_info.get('stage', '') == 'Error':
# If this is an error stage, show error styling
if progress_info.get('stage', '') == 'Error':
error_status_html = """
<div style="background-color: #ffebee; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #c62828; font-weight: 500; font-size: 16px;">Failed</span>
</div>
</div>
</div>
"""
return (
error_status_html,
"Training failed",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False)
)
else:
return (
"""
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Waiting to start...</span>
</div>
</div>
</div>
""",
"Best Model: None",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False)
)
current = progress_info.get('current', 0)
total = progress_info.get('total', 100)
epoch = progress_info.get('epoch', 0)
stage = progress_info.get('stage', 'Waiting')
progress_detail = progress_info.get('progress_detail', '')
best_epoch = progress_info.get('best_epoch', 0)
best_metric_name = progress_info.get('best_metric_name', 'accuracy')
best_metric_value = progress_info.get('best_metric_value', 0.0)
elapsed_time = progress_info.get('elapsed_time', '')
remaining_time = progress_info.get('remaining_time', '')
it_per_sec = progress_info.get('it_per_sec', 0.0)
grad_step = progress_info.get('grad_step', 0)
loss = progress_info.get('loss', 0.0)
            total_epochs = progress_info.get('total_epochs', 0)  # total number of epochs
            test_results_html = progress_info.get('test_results_html', '')  # test results HTML from the monitor
            test_metrics = progress_info.get('test_metrics', {})  # test metrics dict
            is_completed = progress_info.get('is_completed', False)  # whether training has finished
            # The test results HTML component stays visible; show a placeholder message while there is no content
if not test_results_html and stage == 'Testing':
test_results_html = """
<div style="text-align: center; padding: 20px; color: #666;">
<p>Testing in progress, please wait for results...</p>
</div>
"""
elif not test_results_html:
test_results_html = """
<div style="text-align: center; padding: 20px; color: #666;">
<p>Test results will be displayed after testing phase completes</p>
</div>
"""
test_html_update = gr.update(value=test_results_html, visible=True)
            # Handle the CSV download button
if test_metrics and len(test_metrics) > 0:
                # Create a temporary file to hold the CSV content
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv', prefix='metrics_results_') as temp_file:
                    # Write the CSV header
temp_file.write("Metric,Value\n")
                    # Sort metrics by priority
priority_metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc']
def get_priority(item):
name = item[0]
if name in priority_metrics:
return priority_metrics.index(name)
return len(priority_metrics)
                    # Sort and write the metrics to the CSV
sorted_metrics = sorted(test_metrics.items(), key=get_priority)
for metric_name, metric_value in sorted_metrics:
# Convert metric name: uppercase for abbreviations, capitalize for others
display_name = metric_name
if metric_name.lower() in ['f1', 'mcc', 'auroc']:
display_name = metric_name.upper()
else:
display_name = metric_name.capitalize()
temp_file.write(f"{display_name},{metric_value:.6f}\n")
file_path = temp_file.name
download_btn_update = gr.update(value=file_path, visible=True)
else:
download_btn_update = gr.update(visible=False)
            # Compute the progress percentage
progress_percentage = (current / total) * 100 if total > 0 else 0
            # Build the progress bar HTML
if is_completed:
                # Training completed
status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #4caf50; font-weight: 500; font-size: 16px;">Training complete!</span>
</div>
<div>
<span style="font-weight: 600; color: #333;">100%</span>
</div>
</div>
<div style="margin-bottom: 15px; background-color: #e9ecef; height: 10px; border-radius: 5px; overflow: hidden;">
<div style="background-color: #4caf50; width: 100%; height: 100%; border-radius: 5px;"></div>
</div>
</div>
"""
else:
                # Training or validation stage
epoch_total = total_epochs if total_epochs > 0 else 100
status_html = f"""
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">{stage} (Epoch {epoch}/{epoch_total})</span>
</div>
<div>
<span style="font-weight: 600; color: #333;">{progress_percentage:.1f}%</span>
</div>
</div>
<div style="margin-bottom: 15px; background-color: #e9ecef; height: 10px; border-radius: 5px; overflow: hidden;">
<div style="background-color: #4285f4; width: {progress_percentage}%; height: 100%; border-radius: 5px; transition: width 0.3s ease;"></div>
</div>
<div style="display: flex; flex-wrap: wrap; gap: 10px; font-size: 14px; color: #555;">
<div style="background-color: #e8f5e9; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Progress:</span> {current}/{total}</div>
{f'<div style="background-color: #fff8e1; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Time:</span> {elapsed_time}<{remaining_time}, {it_per_sec:.2f}it/s></div>' if elapsed_time and remaining_time else ''}
{f'<div style="background-color: #e3f2fd; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Loss:</span> {loss:.4f}</div>' if stage == 'Training' and loss > 0 else ''}
{f'<div style="background-color: #f3e5f5; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Grad steps:</span> {grad_step}</div>' if stage == 'Training' and grad_step > 0 else ''}
</div>
</div>
"""
            # Build the best-model summary
if best_epoch >= 0 and best_metric_value > 0:
best_info = f"Best model: Epoch {best_epoch} ({best_metric_name}: {best_metric_value:.4f})"
else:
best_info = "No best model found yet"
            # Fetch the updated plots
loss_fig = monitor.get_loss_plot()
metrics_fig = monitor.get_metrics_plot()
            # Return the updated components
return status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update
def handle_train(*args) -> Generator:
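            # Generator bound to the Start button: resets the monitor, launches the
            # training process, and streams UI updates until it finishes or is aborted.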
nonlocal is_training, current_process, stop_thread, process_aborted, monitor
# If already training, return
if is_training:
yield None, None, None, None, None, None, None
return
# Force explicit state reset first thing
monitor._reset_tracking()
monitor._reset_stats()
# Explicitly ensure stats are reset
if hasattr(monitor, "stats"):
monitor.stats = {}
# Force override any cached state in monitor
monitor.current_progress = {
"current": 0,
"total": 0,
"epoch": 0,
"stage": "Waiting",
"progress_detail": "",
"best_epoch": -1,
"best_metric_name": "",
"best_metric_value": 0.0,
"elapsed_time": "",
"remaining_time": "",
"it_per_sec": 0.0,
"grad_step": 0,
"loss": 0.0,
"test_results_html": "",
"test_metrics": {},
"is_completed": False,
"lines": []
}
# Reset all monitoring data structures
monitor.train_losses = []
monitor.val_losses = []
monitor.metrics = {}
monitor.epochs = []
if hasattr(monitor, "stats"):
monitor.stats = {}
# Reset flags for new training session
process_aborted = False
stop_thread = False
# Initialize table state
initial_stats = [
["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]
]
# Initial UI state with "Initializing" message
initial_status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Initializing training environment...</span>
</div>
</div>
<div style="font-size: 14px; color: #555; margin-top: 10px;">
<p>• Parsing configuration parameters</p>
<p>• Preparing training environment</p>
<p>• This may take a few moments...</p>
</div>
</div>
"""
# First yield to update UI with "initializing" state
yield initial_stats, initial_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)
try:
# Parse training arguments
training_args = TrainingArgs(args, plm_models, dataset_configs)
if training_args.training_method != "ses-adapter":
training_args.structure_seq = None
args_dict = training_args.to_dict()
# Save total epochs to monitor for use in progress_info
total_epochs = args_dict.get('num_epochs', 100)
monitor.current_progress['total_epochs'] = total_epochs
# Update status to "Preparing dataset"
preparing_status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Preparing dataset and model...</span>
</div>
</div>
<div style="font-size: 14px; color: #555; margin-top: 10px;">
<p>• Loading dataset</p>
<p>• Initializing model architecture</p>
<p>• Setting up training environment</p>
</div>
</div>
"""
yield initial_stats, preparing_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)
# Save arguments to file
save_arguments(args_dict, args_dict.get('output_dir', 'ckpt'))
# Start training
is_training = True
process_aborted = False # Reset abort flag
monitor.start_training(args_dict)
current_process = monitor.process # Store the process reference
starting_status_html = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Starting training process...</span>
</div>
</div>
<div style="font-size: 14px; color: #555; margin-top: 10px;">
<p>• Training process launched</p>
<p>• Waiting for first statistics to appear</p>
<p>• This may take a moment for large models</p>
</div>
</div>
"""
yield initial_stats, starting_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)
# Add delay to ensure enough time for parsing initial statistics
for i in range(3):
time.sleep(1)
# Check if statistics are already available
stats = monitor.get_stats()
if stats and len(stats) > 0:
break
update_count = 0
while True:
# Check if the process still exists and hasn't been aborted
if process_aborted or not monitor.is_training or current_process is None or (current_process and current_process.poll() is not None):
break
try:
update_count += 1
time.sleep(0.5)
# Check process status
monitor.check_process_status()
# Get latest progress info
progress_info = monitor.get_progress()
# If process has ended, check if it's normal end or error
if not monitor.is_training:
# Check both monitor.process and current_process since they might be different objects
if (monitor.process and monitor.process.returncode != 0) or (current_process and current_process.poll() is not None and current_process.returncode != 0):
# Get the return code from whichever process object is available
return_code = monitor.process.returncode if monitor.process else current_process.returncode
# Get complete output log
error_output = "\n".join(progress_info.get("lines", []))
if not error_output:
error_output = "No output captured from the training process"
# Ensure we set the is_completed flag to False for errors
progress_info['is_completed'] = False
monitor.current_progress['is_completed'] = False
# Also set the stage to Error
progress_info['stage'] = 'Error'
monitor.current_progress['stage'] = 'Error'
error_status_html = f"""
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
<p style="margin: 0; color: #c62828; font-weight: bold;">Training failed with error code {return_code}:</p>
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre>
</div>
"""
yield (
initial_stats,
error_status_html,
"Training failed",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False)
)
return
else:
# Only set is_completed to True if there was a successful exit code
progress_info['is_completed'] = True
monitor.current_progress['is_completed'] = True
# Update UI
stats = monitor.get_stats()
if stats:
model_stats = update_model_stats(stats)
else:
model_stats = initial_stats
status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info)
yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update
except Exception as e:
# Get complete output log
error_output = "\n".join(progress_info.get("lines", []))
if not error_output:
error_output = "No output captured from the training process"
error_status_html = f"""
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
<p style="margin: 0; color: #c62828; font-weight: bold;">Error during training:</p>
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p>
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre>
</div>
"""
print(f"Error updating UI: {str(e)}")
traceback.print_exc()
yield initial_stats, error_status_html, "Training error", gr.update(value="", visible=False), None, None, gr.update(visible=False)
return
# Check if aborted
if process_aborted:
is_training = False
current_process = None
aborted_status_html = """
<div style="padding: 10px; background-color: #e8f5e9; border-radius: 5px;">
<p style="margin: 0; color: #2e7d32; font-weight: bold;">Training was manually terminated.</p>
</div>
"""
yield initial_stats, aborted_status_html, "Training aborted", gr.update(value="", visible=False), None, None, gr.update(visible=False)
return
# Final update after training ends (only for normal completion)
if monitor.process and monitor.process.returncode == 0:
try:
progress_info = monitor.get_progress()
progress_info['is_completed'] = True
monitor.current_progress['is_completed'] = True
stats = monitor.get_stats()
if stats:
model_stats = update_model_stats(stats)
else:
model_stats = initial_stats
status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info)
yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update
except Exception as e:
error_output = "\n".join(progress_info.get("lines", []))
if not error_output:
error_output = "No output captured from the training process"
error_status_html = f"""
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
<p style="margin: 0; color: #c62828; font-weight: bold;">Error in final update:</p>
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p>
<pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto; background-color: #f5f5f5; padding: 10px; border-radius: 4px; font-family: monospace;">{error_output}</pre>
</div>
"""
yield initial_stats, error_status_html, "Error in final update", gr.update(value="", visible=False), None, None, gr.update(visible=False)
except Exception as e:
# Initialization error, may not have output log
error_status_html = f"""
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
<p style="margin: 0; color: #c62828; font-weight: bold;">Training initialization failed:</p>
<p style="margin: 5px 0; color: #c62828;">{str(e)}</p>
</div>
"""
yield initial_stats, error_status_html, "Training failed", gr.update(value="", visible=False), None, None, gr.update(visible=False)
finally:
is_training = False
current_process = None
def handle_abort():
"""Handle abortion of the training process"""
nonlocal is_training, current_process, stop_thread, process_aborted
if not is_training or current_process is None:
return (gr.HTML("""
<div style="padding: 10px; background-color: #f5f5f5; border-radius: 5px;">
<p style="margin: 0;">No training process is currently running.</p>
</div>
"""),
[["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]],
"Best Model: None",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False))
try:
# Set the abort flag before terminating the process
process_aborted = True
stop_thread = True
# Use process.terminate() instead of os.killpg for safer termination
# This avoids accidentally killing the parent WebUI process
current_process.terminate()
# Wait for process to terminate (with timeout)
try:
current_process.wait(timeout=5)
except subprocess.TimeoutExpired:
# Only if terminate didn't work, use a stronger method
# But do NOT use killpg which might kill the parent WebUI
current_process.kill()
# Create a completely fresh state - not just resetting
monitor.is_training = False
# Explicitly create a new dictionary instead of modifying the existing one
monitor.current_progress = {
"current": 0,
"total": 0,
"epoch": 0,
"stage": "Waiting",
"progress_detail": "",
"best_epoch": -1,
"best_metric_name": "",
"best_metric_value": 0.0,
"elapsed_time": "",
"remaining_time": "",
"it_per_sec": 0.0,
"grad_step": 0,
"loss": 0.0,
"test_results_html": "",
"test_metrics": {},
"is_completed": False,
"lines": []
}
# Explicitly clear stats by creating a new dictionary
monitor.stats = {}
if hasattr(monitor, "process") and monitor.process:
monitor.process = None
# Reset state variables
is_training = False
current_process = None
# Explicitly reset tracking to clear all state
monitor._reset_tracking()
monitor._reset_stats()
# Reset all plots and statistics with new empty lists
monitor.train_losses = []
monitor.val_losses = []
monitor.metrics = {}
monitor.epochs = []
# Create entirely fresh UI components
empty_model_stats = [["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]]
success_html = """
<div style="padding: 10px; background-color: #e8f5e9; border-radius: 5px;">
<p style="margin: 0; color: #2e7d32; font-weight: bold;">Training successfully terminated!</p>
<p style="margin: 5px 0 0; color: #388e3c;">All training state has been reset. You can start a new training session.</p>
</div>
"""
# Return updates for all relevant components
return (gr.HTML(success_html),
empty_model_stats,
"Best Model: None",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False))
except Exception as e:
# Still need to reset states even if there's an error
is_training = False
current_process = None
process_aborted = False
# Reset monitor state regardless of error
monitor.is_training = False
monitor.stats = {}
if hasattr(monitor, "process") and monitor.process:
monitor.process = None
monitor._reset_tracking()
monitor._reset_stats()
# Fresh empty components
empty_model_stats = [["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]]
error_html = f"""
<div style="padding: 10px; background-color: #ffebee; border-radius: 5px;">
<p style="margin: 0; color: #c62828; font-weight: bold;">Failed to terminate training: {str(e)}</p>
<p style="margin: 5px 0 0; color: #c62828;">Training state has been reset.</p>
</div>
"""
# Return updates for all relevant components including empty model stats
return (gr.HTML(error_html),
empty_model_stats,
"Best Model: None",
gr.update(value="", visible=False),
None,
None,
gr.update(visible=False))
def update_wandb_visibility(checkbox):
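            # Show the W&B project/entity fields only when W&B logging is enabled.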
return {
wandb_project: gr.update(visible=checkbox),
wandb_entity: gr.update(visible=checkbox)
}
        # Define all input components; the positional order must match the indices read in TrainingArgs.__init__
input_components = [
plm_model, #0
is_custom_dataset, #1
dataset_config, #2
dataset_custom, #3
problem_type, #4
num_labels, #5
metrics, #6
training_method, #7
pooling_method, #8
batch_mode, #9
batch_size, #10
batch_token, #11
learning_rate, #12
num_epochs, #13
max_seq_len, #14
gradient_accumulation_steps, #15
warmup_steps, #16
scheduler_type, #17
output_model_name, #18
output_dir, #19
wandb_logging, #20
wandb_project, #21
wandb_entity, #22
patience, #23
num_workers, #24
max_grad_norm, #25
structure_seq, #26
lora_r, #27
lora_alpha, #28
lora_dropout, #29
lora_target_modules, #30
]
# bind preview and train buttons
def handle_preview(*args):
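            # Build the training command from the current inputs and show it in the
            # preview box, or hide the box if it is already visible.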
if command_preview.visible:
return gr.update(visible=False)
training_args = TrainingArgs(args, plm_models, dataset_configs)
preview_text = preview_command(training_args.to_dict())
return gr.update(value=preview_text, visible=True)
def reset_train_ui():
"""Reset the UI state before training starts"""
# Reset monitor state
monitor._reset_tracking()
monitor._reset_stats()
# Explicitly ensure stats are reset
if hasattr(monitor, "stats"):
monitor.stats = {}
# Create a completely fresh progress state
monitor.current_progress = {
"current": 0,
"total": 0,
"epoch": 0,
"stage": "Waiting",
"progress_detail": "",
"best_epoch": -1,
"best_metric_name": "",
"best_metric_value": 0.0,
"elapsed_time": "",
"remaining_time": "",
"it_per_sec": 0.0,
"grad_step": 0,
"loss": 0.0,
"test_results_html": "",
"test_metrics": {},
"is_completed": False,
"lines": []
}
# Reset all statistical data
monitor.train_losses = []
monitor.val_losses = []
monitor.metrics = {}
monitor.epochs = []
# Force UI to reset by creating completely fresh components
empty_model_stats = [["Training Model", "-", "-", "-"],
["Pre-trained Model", "-", "-", "-"],
["Combined Model", "-", "-", "-"]]
empty_progress_status = """
<div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
<div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
<div>
<span style="font-weight: 600; font-size: 16px;">Training Status: </span>
<span style="color: #1976d2; font-weight: 500; font-size: 16px;">Preparing to start training...</span>
</div>
</div>
</div>
"""
# Return exactly 7 values matching the 7 output components
return (
empty_model_stats,
empty_progress_status,
"Best Model: None",
gr.update(value="", visible=False),
None, # loss_plot must be None, not a string
None, # metrics_plot must be None, not a string
gr.update(visible=False)
)
preview_button.click(
fn=handle_preview,
inputs=input_components,
outputs=[command_preview]
)
train_button.click(
fn=reset_train_ui,
outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn]
).then(
fn=handle_train,
inputs=input_components,
outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn]
)
# bind abort button
abort_button.click(
fn=handle_abort,
outputs=[progress_status, model_stats, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn]
)
wandb_logging.change(
fn=update_wandb_visibility,
inputs=[wandb_logging],
outputs=[wandb_project, wandb_entity]
)
def update_dataset_preview(dataset_type=None, dataset_name=None, custom_dataset=None):
"""Update dataset preview content"""
# Determine which dataset to use based on selection
if dataset_type == "Use Custom Dataset" and custom_dataset:
try:
# Try to load custom dataset
dataset = load_dataset(custom_dataset)
stats_html = f"""
<div style="text-align: center; margin: 20px 0;">
<table style="width: 100%; border-collapse: collapse; margin: 0 auto;">
<tr>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Dataset</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Train Samples</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Val Samples</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Test Samples</th>
</tr>
<tr>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{custom_dataset}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["train"]) if "train" in dataset else 0}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["validation"]) if "validation" in dataset else 0}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["test"]) if "test" in dataset else 0}</td>
</tr>
</table>
</div>
"""
# Get sample data points
split = "train" if "train" in dataset else list(dataset.keys())[0]
samples = dataset[split].select(range(min(3, len(dataset[split]))))
if len(samples) == 0:
return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
# Get fields actually present in the dataset
available_fields = list(samples[0].keys())
# Build sample data
sample_data = []
for sample in samples:
sample_dict = {}
for field in available_fields:
# Keep full sequence
sample_dict[field] = str(sample[field])
sample_data.append(sample_dict)
df = pd.DataFrame(sample_data)
return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True)
except Exception as e:
error_html = f"""
<div>
<h2>Error loading dataset</h2>
<p style="color: #c62828;">{str(e)}</p>
</div>
"""
return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
# Use predefined dataset
elif dataset_type == "Use Pre-defined Dataset" and dataset_name:
try:
config_path = dataset_configs[dataset_name]
with open(config_path, 'r') as f:
config = json.load(f)
# Load dataset statistics
dataset = load_dataset(config["dataset"])
stats_html = f"""
<div style="text-align: center; margin: 20px 0;">
<table style="width: 100%; border-collapse: collapse; margin: 0 auto;">
<tr>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Dataset</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Train Samples</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Val Samples</th>
<th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Test Samples</th>
</tr>
<tr>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{config["dataset"]}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["train"]) if "train" in dataset else 0}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["validation"]) if "validation" in dataset else 0}</td>
<td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["test"]) if "test" in dataset else 0}</td>
</tr>
</table>
</div>
"""
# Get sample data points and available fields
samples = dataset["train"].select(range(min(3, len(dataset["train"]))))
if len(samples) == 0:
return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
# Get fields actually present in the dataset
available_fields = list(samples[0].keys())
# Build sample data
sample_data = []
for sample in samples:
sample_dict = {}
for field in available_fields:
# Keep full sequence
sample_dict[field] = str(sample[field])
sample_data.append(sample_dict)
df = pd.DataFrame(sample_data)
return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True)
except Exception as e:
error_html = f"""
<div>
<h2>Error loading dataset</h2>
<p style="color: #c62828;">{str(e)}</p>
</div>
"""
return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
# If no valid dataset information provided
return gr.update(value=""), gr.update(value=[["No dataset selected", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
# Preview button click event
dataset_preview_button.click(
fn=update_dataset_preview,
inputs=[is_custom_dataset, dataset_config, dataset_custom],
outputs=[dataset_stats_md, preview_table, preview_accordion]
)
        # Helper to update dataset settings when the selection mode or dataset changes
def update_dataset_settings(choice, dataset_name=None):
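            # Switch between the pre-defined dataset dropdown and the custom dataset
            # textbox, locking or unlocking problem type, label count, and metrics.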
if choice == "Use Pre-defined Dataset":
                # Load configuration from the selected dataset_config
result = {
dataset_config: gr.update(visible=True),
dataset_custom: gr.update(visible=False),
custom_dataset_settings: gr.update(visible=True)
}
                # If a specific dataset is selected, load its configuration automatically
if dataset_name and dataset_name in dataset_configs:
with open(dataset_configs[dataset_name], 'r') as f:
config = json.load(f)
                    # Convert the metrics string into a list for the multiselect component
metrics_value = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
if isinstance(metrics_value, str):
metrics_value = metrics_value.split(",")
result.update({
problem_type: gr.update(value=config.get("problem_type", "single_label_classification"), interactive=False),
num_labels: gr.update(value=config.get("num_labels", 2), interactive=False),
metrics: gr.update(value=metrics_value, interactive=False),
})
return result
else:
                # Custom dataset: reset fields to editable defaults
                # Provide a default value list for the multiselect component
default_metrics = ["accuracy", "mcc", "f1", "precision", "recall", "auroc"]
return {
dataset_config: gr.update(visible=False),
dataset_custom: gr.update(visible=True),
custom_dataset_settings: gr.update(visible=True),
problem_type: gr.update(value="single_label_classification", interactive=True),
num_labels: gr.update(value=2, interactive=True),
metrics: gr.update(value=default_metrics, interactive=True)
}
        # Bind dataset settings update events
is_custom_dataset.change(
fn=update_dataset_settings,
inputs=[is_custom_dataset, dataset_config],
outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics]
)
dataset_config.change(
fn=lambda x: update_dataset_settings("Use Pre-defined Dataset", x),
inputs=[dataset_config],
outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics]
)
# Return components that need to be accessed from outside
return {
"output_text": progress_status,
"loss_plot": loss_plot,
"metrics_plot": metrics_plot,
"train_button": train_button,
"monitor": monitor,
"test_results_html": test_results_html, # 添加测试结果HTML组件
"components": {
"plm_model": plm_model,
"dataset_config": dataset_config,
"training_method": training_method,
"pooling_method": pooling_method,
"batch_mode": batch_mode,
"batch_size": batch_size,
"batch_token": batch_token,
"learning_rate": learning_rate,
"num_epochs": num_epochs,
"max_seq_len": max_seq_len,
"gradient_accumulation_steps": gradient_accumulation_steps,
"warmup_steps": warmup_steps,
"scheduler_type": scheduler_type,
"output_model_name": output_model_name,
"output_dir": output_dir,
"wandb_logging": wandb_logging,
"wandb_project": wandb_project,
"wandb_entity": wandb_entity,
"patience": patience,
"num_workers": num_workers,
"max_grad_norm": max_grad_norm,
"structure_seq": structure_seq,
"lora_r": lora_r,
"lora_alpha": lora_alpha,
"lora_dropout": lora_dropout,
"lora_target_modules": lora_target_modules,
}
}
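

# Minimal standalone launch sketch (assumption, not part of the production entry
# point): this module is normally imported by the VenusFactory web app, but the
# tab can be mounted on its own for quick inspection. The model name and dataset
# config path below are placeholders, and the relative imports require running
# as `python -m src.web.train_tab` from the repository root.
if __name__ == "__main__":
    demo_constant = {
        # display name -> model path, as create_train_tab expects
        "plm_models": {"ESM2-8M": "facebook/esm2_t6_8M_UR50D"},
        # display name -> dataset config JSON path (placeholder)
        "dataset_configs": {"DemoTask": "data/DemoTask/DemoTask_HF.json"},
    }
    with gr.Blocks() as demo:
        create_train_tab(demo_constant)
    demo.launch()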