import os import json import gradio as gr import time from datasets import load_dataset import pandas as pd from typing import Any, Dict, Union, Optional, Generator, List from dataclasses import dataclass from .utils.command import preview_command, save_arguments, build_command_list from .utils.monitor import TrainingMonitor import traceback import base64 import tempfile import numpy as np import queue import subprocess import sys import threading @dataclass class TrainingArgs: def __init__(self, args: list, plm_models: dict, dataset_configs: dict): # Basic parameters self.plm_model = plm_models[args[0]] # 处理自定义数据集或预定义数据集 self.dataset_selection = args[1] # "Use Custom Dataset" 或 "Use Pre-defined Dataset" if self.dataset_selection == "Use Pre-defined Dataset": self.dataset_config = dataset_configs[args[2]] self.dataset_custom = None # 从配置加载问题类型等 with open(self.dataset_config, 'r') as f: config = json.load(f) self.problem_type = config.get("problem_type", "single_label_classification") self.num_labels = config.get("num_labels", 2) self.metrics = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc") else: self.dataset_config = None self.dataset_custom = args[3] # Custom dataset path self.problem_type = args[4] self.num_labels = args[5] self.metrics = args[6] # 如果metrics是列表,转换为逗号分隔的字符串 if isinstance(self.metrics, list): self.metrics = ",".join(self.metrics) # Training method parameters self.training_method = args[7] self.pooling_method = args[8] # Batch processing parameters self.batch_mode = args[9] if self.batch_mode == "Batch Size Mode": self.batch_size = args[10] else: self.batch_token = args[11] # Training parameters self.learning_rate = args[12] self.num_epochs = args[13] self.max_seq_len = args[14] self.gradient_accumulation_steps = args[15] self.warmup_steps = args[16] self.scheduler = args[17] # Output parameters self.output_model_name = args[18] self.output_dir = args[19] # Wandb parameters self.wandb_enabled = args[20] if self.wandb_enabled: self.wandb_project = args[21] self.wandb_entity = args[22] # Other parameters self.patience = args[23] self.num_workers = args[24] self.max_grad_norm = args[25] self.structure_seq = args[26] # LoRA parameters self.lora_r = args[27] self.lora_alpha = args[28] self.lora_dropout = args[29] self.lora_target_modules = [m.strip() for m in args[30].split(",")] if args[30] else [] def to_dict(self) -> Dict[str, Any]: args_dict = { "plm_model": self.plm_model, "training_method": self.training_method, "pooling_method": self.pooling_method, "learning_rate": self.learning_rate, "num_epochs": self.num_epochs, "max_seq_len": self.max_seq_len, "gradient_accumulation_steps": self.gradient_accumulation_steps, "warmup_steps": self.warmup_steps, "scheduler": self.scheduler, "output_model_name": self.output_model_name, "output_dir": self.output_dir, "patience": self.patience, "num_workers": self.num_workers, "max_grad_norm": self.max_grad_norm, } if self.training_method == "ses-adapter" and self.structure_seq: args_dict["structure_seq"] = ",".join(self.structure_seq) # 添加数据集相关参数 if self.dataset_selection == "Use Pre-defined Dataset": args_dict["dataset_config"] = self.dataset_config else: args_dict["dataset"] = self.dataset_custom args_dict["problem_type"] = self.problem_type args_dict["num_labels"] = self.num_labels args_dict["metrics"] = self.metrics # Add LoRA parameters if self.training_method in ["plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"]: args_dict.update({ "lora_r": self.lora_r, "lora_alpha": self.lora_alpha, "lora_dropout": self.lora_dropout, "lora_target_modules": self.lora_target_modules }) # Add batch processing parameters if self.batch_mode == "Batch Size Mode": args_dict["batch_size"] = self.batch_size else: args_dict["batch_token"] = self.batch_token # Add wandb parameters if self.wandb_enabled: args_dict["wandb"] = True if self.wandb_project: args_dict["wandb_project"] = self.wandb_project if self.wandb_entity: args_dict["wandb_entity"] = self.wandb_entity return args_dict def create_train_tab(constant: Dict[str, Any]) -> Dict[str, Any]: # Create training monitor monitor = TrainingMonitor() # Add missing variable declarations is_training = False current_process = None stop_thread = False process_aborted = False plm_models = constant["plm_models"] dataset_configs = constant["dataset_configs"] with gr.Tab("Training"): # Model and Dataset Selection gr.Markdown("### Model and Dataset Configuration") # Original training interface components with gr.Group(): with gr.Row(): with gr.Column(scale=4): with gr.Row(): plm_model = gr.Dropdown( choices=list(plm_models.keys()), label="Protein Language Model", value=list(plm_models.keys())[0], scale=2 ) # 新增数据集选择方式 is_custom_dataset = gr.Radio( choices=["Use Custom Dataset", "Use Pre-defined Dataset"], label="Dataset Selection", value="Use Pre-defined Dataset", scale=3 ) dataset_config = gr.Dropdown( choices=list(dataset_configs.keys()), label="Dataset Configuration", value=list(dataset_configs.keys())[0], visible=True, scale=2 ) dataset_custom = gr.Textbox( label="Custom Dataset Path", placeholder="Huggingface Dataset eg: user/dataset", visible=False, scale=2 ) # 将预览按钮放在单独的列中,并添加样式 with gr.Column(scale=1, min_width=120, elem_classes="preview-button-container"): dataset_preview_button = gr.Button( "Preview Dataset", variant="primary", size="lg", elem_classes="preview-button" ) # 自定义数据集的额外配置选项(单独一行) with gr.Row(visible=True) as custom_dataset_settings: problem_type = gr.Dropdown( choices=["single_label_classification", "multi_label_classification", "regression"], label="Problem Type", value="single_label_classification", scale=23, interactive=False ) num_labels = gr.Number( value=2, label="Number of Labels", scale=11, interactive=False ) metrics = gr.Dropdown( choices=["accuracy", "recall", "precision", "f1", "mcc", "auroc", "f1max", "spearman_corr", "mse"], label="Metrics", value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"], scale=101, multiselect=True, interactive=False ) with gr.Row(): structure_seq = gr.Dropdown( label="Structure Sequence", choices=["foldseek_seq", "ss8_seq"], value=["foldseek_seq", "ss8_seq"], multiselect=True, visible=False ) # ! add for plm-lora, plm-qlora, plm_adalora, plm_dora, plm_ia3 with gr.Row(visible=False) as lora_params_row: # gr.Markdown("#### LoRA Parameters") with gr.Column(): lora_r = gr.Number( value=8, label="LoRA Rank", precision=0, minimum=1, maximum=128, ) with gr.Column(): lora_alpha = gr.Number( value=32, label="LoRA Alpha", precision=0, minimum=1, maximum=128 ) with gr.Column(): lora_dropout = gr.Number( value=0.1, label="LoRA Dropout", minimum=0.0, maximum=1.0 ) with gr.Column(): lora_target_modules = gr.Textbox( value="query,key,value", label="LoRA Target Modules", placeholder="Comma-separated list of target modules", # info="LoRA will be applied to these modules" ) # 将数据统计和表格都放入折叠面板 with gr.Row(): with gr.Accordion("Dataset Preview", open=False) as preview_accordion: # 数据统计区域 with gr.Row(): dataset_stats_md = gr.HTML("", elem_classes=["dataset-stats"]) # 表格区域 with gr.Row(): preview_table = gr.Dataframe( headers=["Name", "Sequence", "Label"], value=[["No dataset selected", "-", "-"]], wrap=True, interactive=False, row_count=3, elem_classes=["preview-table"] ) # Add CSS styles gr.HTML(""" """, visible=True) # Batch Processing Configuration gr.Markdown("### Batch Processing Configuration") with gr.Group(): with gr.Row(equal_height=True): with gr.Column(scale=1): batch_mode = gr.Radio( choices=["Batch Size Mode", "Batch Token Mode"], label="Batch Processing Mode", value="Batch Size Mode" ) with gr.Column(scale=2): batch_size = gr.Slider( minimum=1, maximum=128, value=16, step=1, label="Batch Size", visible=True ) batch_token = gr.Slider( minimum=1000, maximum=50000, value=10000, step=1000, label="Tokens per Batch", visible=False ) def update_batch_inputs(mode): return { batch_size: gr.update(visible=mode == "Batch Size Mode"), batch_token: gr.update(visible=mode == "Batch Token Mode") } # Update visibility when mode changes batch_mode.change( fn=update_batch_inputs, inputs=[batch_mode], outputs=[batch_size, batch_token] ) # Training Parameters gr.Markdown("### Training Parameters") with gr.Group(): # First row: Basic training parameters with gr.Row(equal_height=True): with gr.Column(scale=1, min_width=150): training_method = gr.Dropdown( choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"], label="Training Method", value="freeze" ) with gr.Column(scale=1, min_width=150): learning_rate = gr.Slider( minimum=1e-8, maximum=1e-2, value=5e-4, step=1e-6, label="Learning Rate" ) with gr.Column(scale=1, min_width=150): num_epochs = gr.Slider( minimum=1, maximum=200, value=20, step=1, label="Number of Epochs" ) with gr.Column(scale=1, min_width=150): patience = gr.Slider( minimum=1, maximum=50, value=10, step=1, label="Early Stopping Patience" ) with gr.Column(scale=1, min_width=150): max_seq_len = gr.Slider( minimum=-1, maximum=2048, value=None, step=32, label="Max Sequence Length (-1 for unlimited)" ) def update_training_method(method): return { structure_seq: gr.update(visible=method == "ses-adapter"), lora_params_row: gr.update(visible=method in ["plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"]) } # Add training_method change event training_method.change( fn=update_training_method, inputs=[training_method], outputs=[structure_seq, lora_params_row] ) # Second row: Advanced training parameters with gr.Row(equal_height=True): with gr.Column(scale=1, min_width=150): pooling_method = gr.Dropdown( choices=["mean", "attention1d", "light_attention"], label="Pooling Method", value="mean" ) with gr.Column(scale=1, min_width=150): scheduler_type = gr.Dropdown( choices=["linear", "cosine", "step", None], label="Scheduler Type", value=None ) with gr.Column(scale=1, min_width=150): warmup_steps = gr.Slider( minimum=0, maximum=1000, value=0, step=10, label="Warmup Steps" ) with gr.Column(scale=1, min_width=150): gradient_accumulation_steps = gr.Slider( minimum=1, maximum=32, value=1, step=1, label="Gradient Accumulation Steps" ) with gr.Column(scale=1, min_width=150): max_grad_norm = gr.Slider( minimum=0.1, maximum=10.0, value=-1, step=0.1, label="Max Gradient Norm (-1 for no clipping)" ) with gr.Column(scale=1, min_width=150): num_workers = gr.Slider( minimum=0, maximum=16, value=4, step=1, label="Number of Workers" ) # Output and Logging Settings gr.Markdown("### Output and Logging Settings") with gr.Row(): with gr.Column(): output_dir = gr.Textbox( label="Save Directory", value="demo", placeholder="Path to save training results" ) output_model_name = gr.Textbox( label="Output Model Name", value="demo.pt", placeholder="Name of the output model file" ) with gr.Column(): wandb_logging = gr.Checkbox( label="Enable W&B Logging", value=False ) wandb_project = gr.Textbox( label="W&B Project Name", value=None, visible=False ) wandb_entity = gr.Textbox( label="W&B Entity", value=None, visible=False ) # Training Control and Output gr.Markdown("### Training Control") with gr.Row(): preview_button = gr.Button("Preview Command") abort_button = gr.Button("Abort", variant="stop") train_button = gr.Button("Start", variant="primary") with gr.Row(): command_preview = gr.Code( label="Command Preview", language="shell", interactive=False, visible=False ) # Model Statistics Section gr.Markdown("### Model Statistics") with gr.Row(): model_stats = gr.Dataframe( headers=["Model Type", "Total Parameters", "Trainable Parameters", "Percentage"], value=[ ["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"] ], interactive=False, elem_classes=["center-table-content"] ) def update_model_stats(stats: Dict[str, str]) -> List[List[str]]: """Update model statistics in table format.""" if not stats: return [ ["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"] ] adapter_total = stats.get('adapter_total', '-') adapter_trainable = stats.get('adapter_trainable', '-') pretrain_total = stats.get('pretrain_total', '-') pretrain_trainable = stats.get('pretrain_trainable', '-') combined_total = stats.get('combined_total', '-') combined_trainable = stats.get('combined_trainable', '-') trainable_percentage = stats.get('trainable_percentage', '-') return [ ["Training Model", str(adapter_total), str(adapter_trainable), "-"], ["Pre-trained Model", str(pretrain_total), str(pretrain_trainable), "-"], ["Combined Model", str(combined_total), str(combined_trainable), str(trainable_percentage)] ] # Training Progress gr.Markdown("### Training Progress") with gr.Row(): progress_status = gr.HTML( value="""
Training Status: Click Start to train your model
""", label="Status" ) with gr.Row(): best_model_info = gr.Textbox( value="Best Model: None", label="Best Performance", interactive=False ) # Add test results HTML display with gr.Row(): test_results_html = gr.HTML( value="", label="Test Results", visible=True ) with gr.Row(): with gr.Column(scale=4): pass with gr.Column(scale=1): # 限制列的最大宽度 download_csv_btn = gr.DownloadButton( "Download CSV", visible=False, size="lg" ) # 添加一个空列来占据剩余空间 with gr.Column(scale=4): pass # Training plot in a separate row for full width with gr.Row(): with gr.Column(): loss_plot = gr.Plot( label="Training and Validation Loss", elem_id="loss_plot" ) with gr.Column(): metrics_plot = gr.Plot( label="Validation Metrics", elem_id="metrics_plot" ) def update_progress(progress_info): # If progress_info is empty or None, use completely fresh empty state if not progress_info or not any(progress_info.values()): fresh_status_html = """
Training Status: Click Start to train your model
""" return ( fresh_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) ) # Reset values if stage is "Waiting" or "Error" if progress_info.get('stage', '') == 'Waiting' or progress_info.get('stage', '') == 'Error': # If this is an error stage, show error styling if progress_info.get('stage', '') == 'Error': error_status_html = """
Training Status: Failed
""" return ( error_status_html, "Training failed", gr.update(value="", visible=False), None, None, gr.update(visible=False) ) else: return ( """
Training Status: Waiting to start...
""", "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) ) current = progress_info.get('current', 0) total = progress_info.get('total', 100) epoch = progress_info.get('epoch', 0) stage = progress_info.get('stage', 'Waiting') progress_detail = progress_info.get('progress_detail', '') best_epoch = progress_info.get('best_epoch', 0) best_metric_name = progress_info.get('best_metric_name', 'accuracy') best_metric_value = progress_info.get('best_metric_value', 0.0) elapsed_time = progress_info.get('elapsed_time', '') remaining_time = progress_info.get('remaining_time', '') it_per_sec = progress_info.get('it_per_sec', 0.0) grad_step = progress_info.get('grad_step', 0) loss = progress_info.get('loss', 0.0) total_epochs = progress_info.get('total_epochs', 0) # 获取总epoch数 test_results_html = progress_info.get('test_results_html', '') # 获取测试结果HTML test_metrics = progress_info.get('test_metrics', {}) # 获取测试指标 is_completed = progress_info.get('is_completed', False) # 检查训练是否完成 # Test results HTML visibility is always True, but show message when content is empty if not test_results_html and stage == 'Testing': test_results_html = """

Testing in progress, please wait for results...

""" elif not test_results_html: test_results_html = """

Test results will be displayed after testing phase completes

""" test_html_update = gr.update(value=test_results_html, visible=True) # 处理CSV下载按钮 if test_metrics and len(test_metrics) > 0: # 创建临时文件保存CSV内容 with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.csv', prefix='metrics_results_') as temp_file: # 写入CSV头部 temp_file.write("Metric,Value\n") # 按照优先级排序指标 priority_metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc'] def get_priority(item): name = item[0] if name in priority_metrics: return priority_metrics.index(name) return len(priority_metrics) # 排序并添加到CSV sorted_metrics = sorted(test_metrics.items(), key=get_priority) for metric_name, metric_value in sorted_metrics: # Convert metric name: uppercase for abbreviations, capitalize for others display_name = metric_name if metric_name.lower() in ['f1', 'mcc', 'auroc']: display_name = metric_name.upper() else: display_name = metric_name.capitalize() temp_file.write(f"{display_name},{metric_value:.6f}\n") file_path = temp_file.name download_btn_update = gr.update(value=file_path, visible=True) else: download_btn_update = gr.update(visible=False) # 计算进度百分比 progress_percentage = (current / total) * 100 if total > 0 else 0 # 创建现代化的进度条HTML if is_completed: # 训练完成状态 status_html = """
Training Status: Training complete!
100%
""" else: # 训练或验证阶段 epoch_total = total_epochs if total_epochs > 0 else 100 status_html = f"""
Training Status: {stage} (Epoch {epoch}/{epoch_total})
{progress_percentage:.1f}%
Progress: {current}/{total}
{f'
Time: {elapsed_time}<{remaining_time}, {it_per_sec:.2f}it/s>
' if elapsed_time and remaining_time else ''} {f'
Loss: {loss:.4f}
' if stage == 'Training' and loss > 0 else ''} {f'
Grad steps: {grad_step}
' if stage == 'Training' and grad_step > 0 else ''}
""" # 构建最佳模型信息 if best_epoch >= 0 and best_metric_value > 0: best_info = f"Best model: Epoch {best_epoch} ({best_metric_name}: {best_metric_value:.4f})" else: best_info = "No best model found yet" # 获取并更新图表 loss_fig = monitor.get_loss_plot() metrics_fig = monitor.get_metrics_plot() # 返回更新的组件 return status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update def handle_train(*args) -> Generator: nonlocal is_training, current_process, stop_thread, process_aborted, monitor # If already training, return if is_training: yield None, None, None, None, None, None, None return # Force explicit state reset first thing monitor._reset_tracking() monitor._reset_stats() # Explicitly ensure stats are reset if hasattr(monitor, "stats"): monitor.stats = {} # Force override any cached state in monitor monitor.current_progress = { "current": 0, "total": 0, "epoch": 0, "stage": "Waiting", "progress_detail": "", "best_epoch": -1, "best_metric_name": "", "best_metric_value": 0.0, "elapsed_time": "", "remaining_time": "", "it_per_sec": 0.0, "grad_step": 0, "loss": 0.0, "test_results_html": "", "test_metrics": {}, "is_completed": False, "lines": [] } # Reset all monitoring data structures monitor.train_losses = [] monitor.val_losses = [] monitor.metrics = {} monitor.epochs = [] if hasattr(monitor, "stats"): monitor.stats = {} # Reset flags for new training session process_aborted = False stop_thread = False # Initialize table state initial_stats = [ ["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"] ] # Initial UI state with "Initializing" message initial_status_html = """
Training Status: Initializing training environment...

• Parsing configuration parameters

• Preparing training environment

• This may take a few moments...

""" # First yield to update UI with "initializing" state yield initial_stats, initial_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) try: # Parse training arguments training_args = TrainingArgs(args, plm_models, dataset_configs) if training_args.training_method != "ses-adapter": training_args.structure_seq = None args_dict = training_args.to_dict() # Save total epochs to monitor for use in progress_info total_epochs = args_dict.get('num_epochs', 100) monitor.current_progress['total_epochs'] = total_epochs # Update status to "Preparing dataset" preparing_status_html = """
Training Status: Preparing dataset and model...

• Loading dataset

• Initializing model architecture

• Setting up training environment

""" yield initial_stats, preparing_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) # Save arguments to file save_arguments(args_dict, args_dict.get('output_dir', 'ckpt')) # Start training is_training = True process_aborted = False # Reset abort flag monitor.start_training(args_dict) current_process = monitor.process # Store the process reference starting_status_html = """
Training Status: Starting training process...

• Training process launched

• Waiting for first statistics to appear

• This may take a moment for large models

""" yield initial_stats, starting_status_html, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False) # Add delay to ensure enough time for parsing initial statistics for i in range(3): time.sleep(1) # Check if statistics are already available stats = monitor.get_stats() if stats and len(stats) > 0: break update_count = 0 while True: # Check if the process still exists and hasn't been aborted if process_aborted or not monitor.is_training or current_process is None or (current_process and current_process.poll() is not None): break try: update_count += 1 time.sleep(0.5) # Check process status monitor.check_process_status() # Get latest progress info progress_info = monitor.get_progress() # If process has ended, check if it's normal end or error if not monitor.is_training: # Check both monitor.process and current_process since they might be different objects if (monitor.process and monitor.process.returncode != 0) or (current_process and current_process.poll() is not None and current_process.returncode != 0): # Get the return code from whichever process object is available return_code = monitor.process.returncode if monitor.process else current_process.returncode # Get complete output log error_output = "\n".join(progress_info.get("lines", [])) if not error_output: error_output = "No output captured from the training process" # Ensure we set the is_completed flag to False for errors progress_info['is_completed'] = False monitor.current_progress['is_completed'] = False # Also set the stage to Error progress_info['stage'] = 'Error' monitor.current_progress['stage'] = 'Error' error_status_html = f"""

Training failed with error code {return_code}:

{error_output}
""" yield ( initial_stats, error_status_html, "Training failed", gr.update(value="", visible=False), None, None, gr.update(visible=False) ) return else: # Only set is_completed to True if there was a successful exit code progress_info['is_completed'] = True monitor.current_progress['is_completed'] = True # Update UI stats = monitor.get_stats() if stats: model_stats = update_model_stats(stats) else: model_stats = initial_stats status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info) yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update except Exception as e: # Get complete output log error_output = "\n".join(progress_info.get("lines", [])) if not error_output: error_output = "No output captured from the training process" error_status_html = f"""

Error during training:

{str(e)}

{error_output}
""" print(f"Error updating UI: {str(e)}") traceback.print_exc() yield initial_stats, error_status_html, "Training error", gr.update(value="", visible=False), None, None, gr.update(visible=False) return # Check if aborted if process_aborted: is_training = False current_process = None aborted_status_html = """

Training was manually terminated.

""" yield initial_stats, aborted_status_html, "Training aborted", gr.update(value="", visible=False), None, None, gr.update(visible=False) return # Final update after training ends (only for normal completion) if monitor.process and monitor.process.returncode == 0: try: progress_info = monitor.get_progress() progress_info['is_completed'] = True monitor.current_progress['is_completed'] = True stats = monitor.get_stats() if stats: model_stats = update_model_stats(stats) else: model_stats = initial_stats status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update = update_progress(progress_info) yield model_stats, status_html, best_info, test_html_update, loss_fig, metrics_fig, download_btn_update except Exception as e: error_output = "\n".join(progress_info.get("lines", [])) if not error_output: error_output = "No output captured from the training process" error_status_html = f"""

Error in final update:

{str(e)}

{error_output}
""" yield initial_stats, error_status_html, "Error in final update", gr.update(value="", visible=False), None, None, gr.update(visible=False) except Exception as e: # Initialization error, may not have output log error_status_html = f"""

Training initialization failed:

{str(e)}

""" yield initial_stats, error_status_html, "Training failed", gr.update(value="", visible=False), None, None, gr.update(visible=False) finally: is_training = False current_process = None def handle_abort(): """Handle abortion of the training process""" nonlocal is_training, current_process, stop_thread, process_aborted if not is_training or current_process is None: return (gr.HTML("""

No training process is currently running.

"""), [["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"]], "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)) try: # Set the abort flag before terminating the process process_aborted = True stop_thread = True # Use process.terminate() instead of os.killpg for safer termination # This avoids accidentally killing the parent WebUI process current_process.terminate() # Wait for process to terminate (with timeout) try: current_process.wait(timeout=5) except subprocess.TimeoutExpired: # Only if terminate didn't work, use a stronger method # But do NOT use killpg which might kill the parent WebUI current_process.kill() # Create a completely fresh state - not just resetting monitor.is_training = False # Explicitly create a new dictionary instead of modifying the existing one monitor.current_progress = { "current": 0, "total": 0, "epoch": 0, "stage": "Waiting", "progress_detail": "", "best_epoch": -1, "best_metric_name": "", "best_metric_value": 0.0, "elapsed_time": "", "remaining_time": "", "it_per_sec": 0.0, "grad_step": 0, "loss": 0.0, "test_results_html": "", "test_metrics": {}, "is_completed": False, "lines": [] } # Explicitly clear stats by creating a new dictionary monitor.stats = {} if hasattr(monitor, "process") and monitor.process: monitor.process = None # Reset state variables is_training = False current_process = None # Explicitly reset tracking to clear all state monitor._reset_tracking() monitor._reset_stats() # Reset all plots and statistics with new empty lists monitor.train_losses = [] monitor.val_losses = [] monitor.metrics = {} monitor.epochs = [] # Create entirely fresh UI components empty_model_stats = [["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"]] success_html = """

Training successfully terminated!

All training state has been reset. You can start a new training session.

""" # Return updates for all relevant components return (gr.HTML(success_html), empty_model_stats, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)) except Exception as e: # Still need to reset states even if there's an error is_training = False current_process = None process_aborted = False # Reset monitor state regardless of error monitor.is_training = False monitor.stats = {} if hasattr(monitor, "process") and monitor.process: monitor.process = None monitor._reset_tracking() monitor._reset_stats() # Fresh empty components empty_model_stats = [["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"]] error_html = f"""

Failed to terminate training: {str(e)}

Training state has been reset.

""" # Return updates for all relevant components including empty model stats return (gr.HTML(error_html), empty_model_stats, "Best Model: None", gr.update(value="", visible=False), None, None, gr.update(visible=False)) def update_wandb_visibility(checkbox): return { wandb_project: gr.update(visible=checkbox), wandb_entity: gr.update(visible=checkbox) } # define all input components input_components = [ plm_model, #0 is_custom_dataset, #1 dataset_config, #2 dataset_custom, #3 problem_type, #4 num_labels, #5 metrics, #6 training_method, #7 pooling_method, #8 batch_mode, #9 batch_size, #10 batch_token, #11 learning_rate, #12 num_epochs, #13 max_seq_len, #14 gradient_accumulation_steps, #15 warmup_steps, #16 scheduler_type, #17 output_model_name, #18 output_dir, #19 wandb_logging, #20 wandb_project, #21 wandb_entity, #22 patience, #23 num_workers, #24 max_grad_norm, #25 structure_seq, #26 lora_r, #27 lora_alpha, #28 lora_dropout, #29 lora_target_modules, #30 ] # bind preview and train buttons def handle_preview(*args): if command_preview.visible: return gr.update(visible=False) training_args = TrainingArgs(args, plm_models, dataset_configs) preview_text = preview_command(training_args.to_dict()) return gr.update(value=preview_text, visible=True) def reset_train_ui(): """Reset the UI state before training starts""" # Reset monitor state monitor._reset_tracking() monitor._reset_stats() # Explicitly ensure stats are reset if hasattr(monitor, "stats"): monitor.stats = {} # Create a completely fresh progress state monitor.current_progress = { "current": 0, "total": 0, "epoch": 0, "stage": "Waiting", "progress_detail": "", "best_epoch": -1, "best_metric_name": "", "best_metric_value": 0.0, "elapsed_time": "", "remaining_time": "", "it_per_sec": 0.0, "grad_step": 0, "loss": 0.0, "test_results_html": "", "test_metrics": {}, "is_completed": False, "lines": [] } # Reset all statistical data monitor.train_losses = [] monitor.val_losses = [] monitor.metrics = {} monitor.epochs = [] # Force UI to reset by creating completely fresh components empty_model_stats = [["Training Model", "-", "-", "-"], ["Pre-trained Model", "-", "-", "-"], ["Combined Model", "-", "-", "-"]] empty_progress_status = """
Training Status: Preparing to start training...
""" # Return exactly 7 values matching the 7 output components return ( empty_model_stats, empty_progress_status, "Best Model: None", gr.update(value="", visible=False), None, # loss_plot must be None, not a string None, # metrics_plot must be None, not a string gr.update(visible=False) ) preview_button.click( fn=handle_preview, inputs=input_components, outputs=[command_preview] ) train_button.click( fn=reset_train_ui, outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] ).then( fn=handle_train, inputs=input_components, outputs=[model_stats, progress_status, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] ) # bind abort button abort_button.click( fn=handle_abort, outputs=[progress_status, model_stats, best_model_info, test_results_html, loss_plot, metrics_plot, download_csv_btn] ) wandb_logging.change( fn=update_wandb_visibility, inputs=[wandb_logging], outputs=[wandb_project, wandb_entity] ) def update_dataset_preview(dataset_type=None, dataset_name=None, custom_dataset=None): """Update dataset preview content""" # Determine which dataset to use based on selection if dataset_type == "Use Custom Dataset" and custom_dataset: try: # Try to load custom dataset dataset = load_dataset(custom_dataset) stats_html = f"""
Dataset Train Samples Val Samples Test Samples
{custom_dataset} {len(dataset["train"]) if "train" in dataset else 0} {len(dataset["validation"]) if "validation" in dataset else 0} {len(dataset["test"]) if "test" in dataset else 0}
""" # Get sample data points split = "train" if "train" in dataset else list(dataset.keys())[0] samples = dataset[split].select(range(min(3, len(dataset[split])))) if len(samples) == 0: return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True) # Get fields actually present in the dataset available_fields = list(samples[0].keys()) # Build sample data sample_data = [] for sample in samples: sample_dict = {} for field in available_fields: # Keep full sequence sample_dict[field] = str(sample[field]) sample_data.append(sample_dict) df = pd.DataFrame(sample_data) return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True) except Exception as e: error_html = f"""

Error loading dataset

{str(e)}

""" return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True) # Use predefined dataset elif dataset_type == "Use Pre-defined Dataset" and dataset_name: try: config_path = dataset_configs[dataset_name] with open(config_path, 'r') as f: config = json.load(f) # Load dataset statistics dataset = load_dataset(config["dataset"]) stats_html = f"""
Dataset Train Samples Val Samples Test Samples
{config["dataset"]} {len(dataset["train"]) if "train" in dataset else 0} {len(dataset["validation"]) if "validation" in dataset else 0} {len(dataset["test"]) if "test" in dataset else 0}
""" # Get sample data points and available fields samples = dataset["train"].select(range(min(3, len(dataset["train"])))) if len(samples) == 0: return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True) # Get fields actually present in the dataset available_fields = list(samples[0].keys()) # Build sample data sample_data = [] for sample in samples: sample_dict = {} for field in available_fields: # Keep full sequence sample_dict[field] = str(sample[field]) sample_data.append(sample_dict) df = pd.DataFrame(sample_data) return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True) except Exception as e: error_html = f"""

Error loading dataset

{str(e)}

""" return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True) # If no valid dataset information provided return gr.update(value=""), gr.update(value=[["No dataset selected", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True) # Preview button click event dataset_preview_button.click( fn=update_dataset_preview, inputs=[is_custom_dataset, dataset_config, dataset_custom], outputs=[dataset_stats_md, preview_table, preview_accordion] ) # 添加自定义数据集设置的函数 def update_dataset_settings(choice, dataset_name=None): if choice == "Use Pre-defined Dataset": # 从dataset_config加载配置 result = { dataset_config: gr.update(visible=True), dataset_custom: gr.update(visible=False), custom_dataset_settings: gr.update(visible=True) } # 如果有选择特定数据集,自动加载配置 if dataset_name and dataset_name in dataset_configs: with open(dataset_configs[dataset_name], 'r') as f: config = json.load(f) # 处理metrics,将字符串转换为列表以适应多选组件 metrics_value = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc") if isinstance(metrics_value, str): metrics_value = metrics_value.split(",") result.update({ problem_type: gr.update(value=config.get("problem_type", "single_label_classification"), interactive=False), num_labels: gr.update(value=config.get("num_labels", 2), interactive=False), metrics: gr.update(value=metrics_value, interactive=False), }) return result else: # 自定义数据集设置,清零/设为默认值并可编辑 # 为多选组件提供默认值列表 default_metrics = ["accuracy", "mcc", "f1", "precision", "recall", "auroc"] return { dataset_config: gr.update(visible=False), dataset_custom: gr.update(visible=True), custom_dataset_settings: gr.update(visible=True), problem_type: gr.update(value="single_label_classification", interactive=True), num_labels: gr.update(value=2, interactive=True), metrics: gr.update(value=default_metrics, interactive=True) } # 绑定数据集设置更新事件 is_custom_dataset.change( fn=update_dataset_settings, inputs=[is_custom_dataset, dataset_config], outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics] ) dataset_config.change( fn=lambda x: update_dataset_settings("Use Pre-defined Dataset", x), inputs=[dataset_config], outputs=[dataset_config, dataset_custom, custom_dataset_settings, problem_type, num_labels, metrics] ) # Return components that need to be accessed from outside return { "output_text": progress_status, "loss_plot": loss_plot, "metrics_plot": metrics_plot, "train_button": train_button, "monitor": monitor, "test_results_html": test_results_html, # 添加测试结果HTML组件 "components": { "plm_model": plm_model, "dataset_config": dataset_config, "training_method": training_method, "pooling_method": pooling_method, "batch_mode": batch_mode, "batch_size": batch_size, "batch_token": batch_token, "learning_rate": learning_rate, "num_epochs": num_epochs, "max_seq_len": max_seq_len, "gradient_accumulation_steps": gradient_accumulation_steps, "warmup_steps": warmup_steps, "scheduler_type": scheduler_type, "output_model_name": output_model_name, "output_dir": output_dir, "wandb_logging": wandb_logging, "wandb_project": wandb_project, "wandb_entity": wandb_entity, "patience": patience, "num_workers": num_workers, "max_grad_norm": max_grad_norm, "structure_seq": structure_seq, "lora_r": lora_r, "lora_alpha": lora_alpha, "lora_dropout": lora_dropout, "lora_target_modules": lora_target_modules, } }