import html
import json
import os
import queue
import re
import signal
import subprocess
import sys
import threading
import time

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datasets import load_dataset

from web.utils.command import preview_eval_command

# Metrics listed first (in this order) and shown in bold in the results table.
_PRIORITY_METRICS = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc']
# Metric names that are abbreviations and are therefore displayed fully upper-cased.
_ACRONYM_METRICS = {'f1', 'mcc', 'auroc'}
# Patterns parsed from the src/eval.py output to drive the progress bar.
_SAMPLE_PATTERN = re.compile(r"Total samples: (\d+)")
_PROGRESS_PATTERN = re.compile(r"(\d+)/(\d+)")
# Default column headers of the dataset preview table.
_PREVIEW_HEADERS = ["Name", "Sequence", "Label"]
# Minimum interval (seconds) between progress-bar refreshes when no new output arrives.
_PROGRESS_REFRESH_SECONDS = 0.5
# Sleep (seconds) between polls of the evaluation subprocess.
_POLL_INTERVAL_SECONDS = 0.1


def _format_duration(seconds):
    """Format a duration in seconds as ``HH:MM:SS`` (fractions are truncated)."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02}:{minutes:02}:{secs:02}"


def _html_block(*paragraphs, details=None, error=False):
    """Wrap status text in a ``<div>`` for display in a gr.HTML component.

    Each positional argument becomes an HTML-escaped paragraph; *details*
    (e.g. process output or an exception message) is shown escaped inside a
    ``<pre>`` block. *error* colours the paragraphs red.
    """
    style = ' style="color: #d32f2f;"' if error else ''
    parts = [f'<p{style}>{html.escape(str(text))}</p>' for text in paragraphs]
    if details:
        parts.append(f'<pre style="white-space: pre-wrap;">{html.escape(str(details))}</pre>')
    return '<div class="eval-message">' + ''.join(parts) + '</div>'


def _join_choices(values):
    """Join a multiselect/checkbox value into a comma-separated string.

    ``None``/empty values give ``""``; an already-joined string is returned as-is.
    """
    if not values:
        return ""
    if isinstance(values, str):
        return values
    return ",".join(values)


def _build_cli_args(args_dict):
    """Turn ``{option: value}`` into argv tokens for src/eval.py.

    ``True`` becomes a bare ``--option`` flag, ``False``/``None`` are omitted,
    and any other value becomes ``--option str(value)``.
    """
    argv = []
    for key, value in args_dict.items():
        if value is True:
            argv.append(f"--{key}")
        elif value is not False and value is not None:
            argv.extend([f"--{key}", str(value)])
    return argv


def _drain_queue(q):
    """Remove and return every item currently waiting in *q*, without blocking."""
    items = []
    while True:
        try:
            items.append(q.get_nowait())
        except queue.Empty:
            return items


def _pump_output(process, out_queue):
    """Copy each stripped stdout line of *process* into *out_queue* until EOF.

    Runs on a background thread so the UI loop never blocks on ``readline()``.
    """
    try:
        for line in iter(process.stdout.readline, ""):
            out_queue.put(line.strip())
    finally:
        process.stdout.close()


def _update_progress(progress_info, line):
    """Record one output line in *progress_info* and update the progress fields it implies."""
    progress_info["lines"].append(line)

    if "Total samples" in line:
        match = _SAMPLE_PATTERN.search(line)
        if match:
            progress_info["total_samples"] = int(match.group(1))
            progress_info["stage"] = "Evaluating"

    # tqdm prints "it/s", or "s/it" once a single iteration takes longer than a second.
    if "it/s" in line or "s/it" in line:
        match = _PROGRESS_PATTERN.search(line)
        if match:
            current, total = int(match.group(1)), int(match.group(2))
            progress_info["current"] = current
            progress_info["total"] = total
            if total > 0:  # guard against "0/0" lines
                progress_info["progress"] = current / total * 100

    if "Evaluation completed" in line:
        progress_info["stage"] = "Completed"
        progress_info["progress"] = 100


def _format_metric_value(value):
    """Format a metric with 4 decimals; fall back to escaped text for non-numeric values."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return html.escape(str(value))


def _format_metrics(metrics_file):
    """Render the first row of a metrics CSV as an HTML table.

    Metrics in ``_PRIORITY_METRICS`` come first (in that order, bold); the
    rest keep their CSV column order. Rows alternate background colour.
    Returns an error string instead of raising if the file cannot be read.
    """
    try:
        metrics = pd.read_csv(metrics_file).iloc[0].to_dict()
    except Exception as e:
        return f"Error formatting metrics: {html.escape(str(e))}"

    def priority(item):
        name = item[0]
        if name in _PRIORITY_METRICS:
            return _PRIORITY_METRICS.index(name)
        return len(_PRIORITY_METRICS)

    sorted_metrics = sorted(metrics.items(), key=priority)  # stable: ties keep CSV order
    rows = []
    for i, (name, value) in enumerate(sorted_metrics):
        name = str(name)
        row_style = 'background-color: #f9f9f9;' if i % 2 == 0 else ''
        name_style = 'font-weight: bold;' if name in _PRIORITY_METRICS else ''
        # Abbreviations are upper-cased, everything else gets a leading capital.
        display_name = name.upper() if name.lower() in _ACRONYM_METRICS else name.capitalize()
        rows.append(
            f'<tr style="{row_style}">'
            f'<td style="{name_style}">{html.escape(display_name)}</td>'
            f'<td>{_format_metric_value(value)}</td>'
            '</tr>'
        )

    completed_at = time.strftime("%Y-%m-%d %H:%M:%S")
    body = ''.join(rows)
    return (
        '<div class="eval-metrics">'
        f'<p>{len(sorted_metrics)} metrics found</p>'
        '<table>'
        '<thead><tr><th>Metric</th><th>Value</th></tr></thead>'
        f'<tbody>{body}</tbody>'
        '</table>'
        f'<p>Test completed at: {completed_at}</p>'
        '</div>'
    )


def _generate_progress_bar(progress_info):
    """Render the evaluation progress dict as an HTML progress bar.

    Reads "stage", "progress" (percent, clamped to 0-100), "current", "total",
    "total_samples" and "elapsed_time"; missing keys fall back to defaults.
    """
    stage = progress_info.get("stage", "Preparing")
    progress = max(0, min(100, progress_info.get("progress", 0)))
    current = progress_info.get("current", 0)
    total = progress_info.get("total", 0)
    total_samples = progress_info.get("total_samples", 0)
    elapsed_time = progress_info.get("elapsed_time", "")

    details = []
    if total_samples > 0:
        details.append(f'<span>Total samples: {total_samples}</span>')
    if current > 0 and total > 0:
        details.append(f'<span>Progress: {current}/{total}</span>')
    if elapsed_time:
        details.append(f'<span>Time: {html.escape(str(elapsed_time))}</span>')
    details_html = ''.join(details)

    return (
        '<div class="eval-progress">'
        f'<div>Evaluation Status: {html.escape(str(stage))}</div>'
        '<div style="background-color: #e0e0e0; border-radius: 4px; overflow: hidden;">'
        f'<div style="width: {progress:.1f}%; height: 10px; background-color: #4caf50;"></div>'
        '</div>'
        f'<div>{progress:.1f}%</div>'
        f'<div style="display: flex; gap: 16px;">{details_html}</div>'
        '</div>'
    )


def _split_size(dataset, split):
    """Number of rows in *split* of *dataset*, or 0 if the split does not exist."""
    return len(dataset[split]) if split in dataset else 0


def _dataset_preview(name, max_rows=3):
    """Load dataset *name* with ``load_dataset`` and summarise it for the preview panel.

    Returns ``(stats_html, headers, rows)``: split sizes as an HTML table, and
    up to *max_rows* samples (all values stringified) from the "train" split,
    or from the first split when there is no "train" split.
    """
    dataset = load_dataset(name)
    stats_html = (
        '<table class="dataset-stats-table">'
        '<tr><th>Dataset</th><th>Train Samples</th><th>Val Samples</th><th>Test Samples</th></tr>'
        f'<tr><td>{html.escape(str(name))}</td>'
        f'<td>{_split_size(dataset, "train")}</td>'
        f'<td>{_split_size(dataset, "validation")}</td>'
        f'<td>{_split_size(dataset, "test")}</td></tr>'
        '</table>'
    )

    if "train" in dataset:
        split = "train"
    else:
        split_names = list(dataset.keys())
        if not split_names:
            raise ValueError(f"Dataset {name!r} has no splits")
        split = split_names[0]

    samples = dataset[split].select(range(min(max_rows, len(dataset[split]))))
    if len(samples) == 0:
        return stats_html, list(_PREVIEW_HEADERS), [["No data available", "-", "-"]]

    # Use the fields actually present in the dataset; keep full sequences.
    headers = list(samples[0].keys())
    rows = [[str(sample[field]) for field in headers] for sample in samples]
    return stats_html, headers, rows


def create_eval_tab(constant):
    """Build the "Evaluation" tab and wire up its event handlers.

    Args:
        constant: dict with
            - "plm_models": display name -> value passed to eval.py as ``--plm_model``;
            - "dataset_configs": dataset name -> path of a JSON config whose "dataset"
              entry names the dataset to load and whose optional "problem_type",
              "num_labels" and "metrics" entries prefill the form.

    Returns:
        dict with the "eval_button" and "eval_output" components.

    NOTE(review): uses ``gr.Tab``, so it presumably must be called inside an
    enclosing ``gr.Blocks``/``gr.Tabs`` context — confirm against the caller.
    """
    plm_models = constant["plm_models"]
    dataset_configs = constant["dataset_configs"]

    # State shared by the handlers below; only one evaluation runs at a time.
    is_evaluating = False
    current_process = None
    process_aborted = False  # set when the user terminates the process manually

    def evaluate_model(plm_model, model_path, eval_method, is_custom_dataset, dataset_defined,
                       dataset_custom, problem_type, num_labels, metrics, batch_mode,
                       batch_size, batch_token, eval_structure_seq, pooling_method):
        """Run ``src/eval.py`` in a subprocess and stream its progress to the UI.

        Gradio generator handler: every yield is ``(eval_output HTML, download
        button update)``. The download button becomes visible only once the
        metrics CSV has been written.
        """
        nonlocal is_evaluating, current_process, process_aborted

        if is_evaluating:
            # Must be yielded: a generator's `return value` never reaches Gradio.
            yield "Evaluation is already in progress. Please wait...", gr.update(visible=False)
            return

        is_evaluating = True
        process_aborted = False
        process = None  # this run's subprocess; handle_abort may clear the shared reference
        out_queue = queue.Queue()  # per run, so an old reader thread cannot leak lines into it

        start_time = time.time()
        progress_info = {
            "stage": "Preparing",
            "progress": 0,
            "total_samples": 0,
            "current": 0,
            "total": 0,
            "elapsed_time": "00:00:00",
            "lines": [],
        }

        try:
            yield _generate_progress_bar(progress_info), gr.update(visible=False)

            # dirname() is "" for a bare file name, which means the current directory.
            if not model_path or not os.path.exists(os.path.dirname(model_path) or "."):
                yield _html_block("Error: Invalid model path", error=True), gr.update(visible=False)
                return

            if is_custom_dataset == "Use Custom Dataset":
                dataset = (dataset_custom or "").strip()
                test_file = dataset
            elif dataset_defined in dataset_configs:
                dataset = dataset_defined
                with open(dataset_configs[dataset], 'r') as f:
                    test_file = json.load(f)["dataset"]
            else:
                dataset = test_file = ""
            if not dataset:
                yield _html_block("Error: Invalid dataset selection", error=True), gr.update(visible=False)
                return

            if plm_model not in plm_models:
                yield (_html_block("Error: Please select a protein language model", error=True),
                       gr.update(visible=False))
                return

            dataset_display_name = dataset.split('/')[-1]
            test_result_dir = os.path.join(
                os.path.dirname(model_path),
                f"test_results_{os.path.basename(model_path)}_{dataset_display_name}",
            )

            args_dict = {
                "eval_method": eval_method,
                "model_path": model_path,
                "test_file": test_file,
                "problem_type": problem_type,
                "num_labels": num_labels,
                "metrics": _join_choices(metrics),
                "plm_model": plm_models[plm_model],
                "test_result_dir": test_result_dir,
                "dataset": dataset_display_name,
                "pooling_method": pooling_method,
            }
            if batch_mode == "Batch Size Mode":
                args_dict["batch_size"] = batch_size
            else:
                args_dict["batch_token"] = batch_token

            if eval_method == "ses-adapter":
                structure_seq = list(eval_structure_seq or [])
                args_dict["structure_seq"] = ",".join(structure_seq) or None
                # Flags telling eval.py to use foldseek / ss8 structure tokens.
                if "foldseek_seq" in structure_seq:
                    args_dict["use_foldseek"] = True
                if "ss8_seq" in structure_seq:
                    args_dict["use_ss8"] = True
            else:
                args_dict["structure_seq"] = ""

            cmd = [sys.executable, "src/eval.py", *_build_cli_args(args_dict)]

            process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                errors="replace",  # undecodable bytes must not kill the reader thread
                bufsize=1,
                # Unbuffered child stdout so progress lines arrive while eval.py runs.
                env={**os.environ, "PYTHONUNBUFFERED": "1"},
                # Thread-safe equivalent of preexec_fn=os.setsid.
                start_new_session=True,
            )
            current_process = process

            reader = threading.Thread(target=_pump_output, args=(process, out_queue), daemon=True)
            reader.start()

            last_update_time = time.time()
            while not process_aborted and process.poll() is None:
                try:
                    new_lines = _drain_queue(out_queue)
                    for line in new_lines:
                        _update_progress(progress_info, line)

                    # Refresh the elapsed time even when no new lines arrived.
                    now = time.time()
                    progress_info["elapsed_time"] = _format_duration(now - start_time)
                    if new_lines or now - last_update_time >= _PROGRESS_REFRESH_SECONDS:
                        yield _generate_progress_bar(progress_info), gr.update(visible=False)
                        last_update_time = now
                except Exception as e:
                    yield _html_block(f"Error reading output: {e}", error=True), gr.update(visible=False)
                time.sleep(_POLL_INTERVAL_SECONDS)

            if process_aborted:
                # handle_abort has already reported the termination; don't overwrite it.
                return

            # Collect output printed just before exit (e.g. a traceback).
            reader.join(timeout=5)
            for line in _drain_queue(out_queue):
                _update_progress(progress_info, line)

            if process.returncode == 0:
                result_file = os.path.join(test_result_dir, "evaluation_metrics.csv")
                if os.path.exists(result_file):
                    metrics_html = _format_metrics(result_file)
                    time_str = _format_duration(time.time() - start_time)
                    total_samples = progress_info.get("total_samples", 0)
                    summary_html = (
                        '<div class="eval-summary">'
                        '<p>Evaluation completed successfully!</p>'
                        f'<p>Total evaluation time: {time_str}</p>'
                        f'<p>Evaluation dataset: {html.escape(dataset_display_name)}</p>'
                        f'<p>Total samples: {total_samples}</p>'
                        '<h3>Evaluation Results</h3>'
                        f'{metrics_html}'
                        '</div>'
                    )
                    # Make the download button visible and point it at the results file.
                    yield summary_html, gr.update(value=result_file, visible=True)
                else:
                    yield (_html_block(f"Evaluation completed, but metrics file not found at: {result_file}",
                                       error=True),
                           gr.update(visible=False))
            else:
                error_output = "\n".join(progress_info["lines"])
                if not error_output:
                    error_output = "No output captured from the evaluation process"
                yield (_html_block("Evaluation failed:", details=error_output, error=True),
                       gr.update(visible=False))
        except Exception as e:
            yield (_html_block("Error during evaluation process:", details=str(e), error=True),
                   gr.update(visible=False))
        finally:
            # Reset unconditionally for this run (also when it failed before launching),
            # unless handle_abort already handed the shared state over.
            if current_process is process:
                current_process = None
                is_evaluating = False

    def handle_abort():
        """Terminate the running evaluation (if any) and reset the shared state."""
        nonlocal is_evaluating, current_process, process_aborted

        process = current_process
        if process is None:
            return _html_block("No evaluation in progress to terminate."), gr.update(visible=False)

        try:
            # Flag first so the evaluate_model loop stops and leaves this message in place.
            process_aborted = True
            # Using terminate instead of killpg for safety: only the eval process itself
            # is signalled, not other members of its session.
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait(timeout=5)  # reap it so no zombie is left behind
            return (
                _html_block("Evaluation successfully terminated!", "All evaluation state has been reset."),
                gr.update(visible=False),
            )
        except Exception as e:
            process_aborted = False
            return (
                _html_block(f"Failed to terminate evaluation: {e}", "Evaluation state has been reset.",
                            error=True),
                gr.update(visible=False),
            )
        finally:
            # Reset even when termination failed, so a new evaluation can be started.
            current_process = None
            is_evaluating = False

    with gr.Tab("Evaluation"):
        gr.Markdown("### Model and Dataset Configuration")

        with gr.Group():
            with gr.Row():
                eval_model_path = gr.Textbox(
                    label="Model Path",
                    value="ckpt/demo/demo_provided.pt",
                    placeholder="Path to the trained model"
                )
                eval_plm_model = gr.Dropdown(
                    choices=list(plm_models.keys()),
                    label="Protein Language Model"
                )

            with gr.Row():
                eval_method = gr.Dropdown(
                    choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora",
                             "plm_adalora", "plm_dora", "plm_ia3"],
                    label="Evaluation Method",
                    value="freeze"
                )
                eval_pooling_method = gr.Dropdown(
                    choices=["mean", "attention1d", "light_attention"],
                    label="Pooling Method",
                    value="mean"
                )

            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Row():
                        is_custom_dataset = gr.Radio(
                            choices=["Use Custom Dataset", "Use Pre-defined Dataset"],
                            label="Dataset Selection",
                            value="Use Pre-defined Dataset"
                        )
                        eval_dataset_defined = gr.Dropdown(
                            choices=list(dataset_configs.keys()),
                            label="Evaluation Dataset",
                            visible=True
                        )
                        eval_dataset_custom = gr.Textbox(
                            label="Custom Dataset Path",
                            placeholder="Huggingface Dataset eg: user/dataset",
                            visible=False
                        )
                with gr.Column(scale=1, min_width=120, elem_classes="preview-button-container"):
                    dataset_preview_button = gr.Button(
                        "Preview Dataset",
                        variant="primary",
                        size="lg",
                        elem_classes="preview-button"
                    )

            # Dataset statistics and sample table live in a collapsible panel.
            with gr.Row():
                with gr.Accordion("Dataset Preview", open=False) as preview_accordion:
                    with gr.Row():
                        dataset_stats_md = gr.HTML("", elem_classes=["dataset-stats"])
                    with gr.Row():
                        preview_table = gr.Dataframe(
                            headers=list(_PREVIEW_HEADERS),
                            value=[["No dataset selected", "-", "-"]],
                            wrap=True,
                            interactive=False,
                            row_count=3,
                            elem_classes=["preview-table"]
                        )

            # NOTE(review): this CSS block is empty in the current source; the styles for the
            # preview-button / preview-table / dataset-stats classes appear to be missing.
            gr.HTML(""" """, visible=True)

            # Settings for custom datasets (read-only when a pre-defined dataset is selected).
            with gr.Row(visible=True) as custom_dataset_settings:
                problem_type = gr.Dropdown(
                    choices=["single_label_classification", "multi_label_classification", "regression"],
                    label="Problem Type",
                    value="single_label_classification",
                    scale=23,
                    interactive=False
                )
                num_labels = gr.Number(
                    value=2,
                    label="Number of Labels",
                    precision=0,  # deliver an int, not a float, to --num_labels
                    scale=11,
                    interactive=False
                )
                metrics = gr.Dropdown(
                    choices=["accuracy", "recall", "precision", "f1", "mcc", "auroc",
                             "f1_max", "spearman_corr", "mse"],
                    label="Metrics",
                    value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"],
                    scale=101,
                    multiselect=True,
                    interactive=False
                )

            def update_dataset_preview(dataset_type=None, defined_dataset=None, custom_dataset=None):
                """Return (stats HTML, sample table, accordion) updates for the chosen dataset."""
                if dataset_type == "Use Custom Dataset" and custom_dataset:
                    source = custom_dataset
                elif dataset_type == "Use Pre-defined Dataset" and defined_dataset:
                    source = None  # resolved from the dataset config inside the try below
                else:
                    return (gr.update(value=""),
                            gr.update(value=[["No dataset selected", "-", "-"]], headers=list(_PREVIEW_HEADERS)),
                            gr.update(open=True))

                try:
                    if source is None:
                        with open(dataset_configs[defined_dataset], 'r') as f:
                            source = json.load(f)["dataset"]
                    stats_html, headers, rows = _dataset_preview(source)
                except Exception as e:
                    return (gr.update(value=_html_block("Error loading dataset", details=str(e), error=True)),
                            gr.update(value=[["Error", str(e), "-"]], headers=list(_PREVIEW_HEADERS)),
                            gr.update(open=True))
                return gr.update(value=stats_html), gr.update(value=rows, headers=headers), gr.update(open=True)

            dataset_preview_button.click(
                fn=update_dataset_preview,
                inputs=[is_custom_dataset, eval_dataset_defined, eval_dataset_custom],
                outputs=[dataset_stats_md, preview_table, preview_accordion]
            )

            def update_dataset_settings(choice, dataset_name=None):
                """Return updates for [defined dropdown, custom textbox, problem_type, num_labels, metrics]."""
                if choice != "Use Pre-defined Dataset":
                    # Custom dataset: the user fills in the settings.
                    return [
                        gr.update(visible=False),
                        gr.update(visible=True),
                        gr.update(value="", interactive=True),
                        gr.update(value=2, interactive=True),
                        gr.update(value=[], interactive=True),
                    ]

                if not dataset_name or dataset_name not in dataset_configs:
                    # Pre-defined mode without a dataset picked yet: keep the dropdown visible
                    # so one can be chosen; leave the (read-only) settings untouched.
                    return [
                        gr.update(visible=True),
                        gr.update(visible=False),
                        gr.update(interactive=False),
                        gr.update(interactive=False),
                        gr.update(interactive=False),
                    ]

                with open(dataset_configs[dataset_name], 'r') as f:
                    config = json.load(f)
                # The multiselect needs a list; configs may store metrics as "a,b,c".
                metrics_value = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
                if isinstance(metrics_value, str):
                    metrics_value = metrics_value.split(",")
                return [
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(value=config.get("problem_type", ""), interactive=False),
                    gr.update(value=config.get("num_labels", 1), interactive=False),
                    gr.update(value=metrics_value, interactive=False),
                ]

            dataset_settings_outputs = [eval_dataset_defined, eval_dataset_custom,
                                        problem_type, num_labels, metrics]
            is_custom_dataset.change(
                fn=update_dataset_settings,
                inputs=[is_custom_dataset, eval_dataset_defined],
                outputs=dataset_settings_outputs
            )
            eval_dataset_defined.change(
                fn=lambda name: update_dataset_settings("Use Pre-defined Dataset", name),
                inputs=[eval_dataset_defined],
                outputs=dataset_settings_outputs
            )

            # Settings specific to the ses-adapter evaluation method.
            with gr.Row(visible=False) as structure_seq_row:
                eval_structure_seq = gr.CheckboxGroup(
                    label="Structure Sequence",
                    choices=["foldseek_seq", "ss8_seq"],
                    value=["foldseek_seq", "ss8_seq"]
                )

            def update_training_method(method):
                """Show the structure-sequence options only for ses-adapter."""
                return gr.update(visible=method == "ses-adapter")

            eval_method.change(
                fn=update_training_method,
                inputs=[eval_method],
                outputs=[structure_seq_row]
            )

        gr.Markdown("### Batch Processing Configuration")
        with gr.Group():
            with gr.Row(equal_height=True):
                with gr.Column(scale=1):
                    batch_mode = gr.Radio(
                        choices=["Batch Size Mode", "Batch Token Mode"],
                        label="Batch Processing Mode",
                        value="Batch Size Mode"
                    )
                with gr.Column(scale=2):
                    batch_size = gr.Slider(
                        minimum=1,
                        maximum=128,
                        value=16,
                        step=1,
                        label="Batch Size",
                        visible=True
                    )
                    batch_token = gr.Slider(
                        minimum=1000,
                        maximum=50000,
                        value=10000,
                        step=1000,
                        label="Tokens per Batch",
                        visible=False
                    )

            def update_batch_inputs(mode):
                """Show only the slider that matches the selected batching mode."""
                return (
                    gr.update(visible=mode == "Batch Size Mode"),
                    gr.update(visible=mode == "Batch Token Mode"),
                )

            batch_mode.change(
                fn=update_batch_inputs,
                inputs=[batch_mode],
                outputs=[batch_size, batch_token]
            )

        with gr.Row():
            command_preview_button = gr.Button("Preview Command")
            abort_button = gr.Button("Abort", variant="stop")
            eval_button = gr.Button("Start Evaluation", variant="primary")

        with gr.Row():
            command_preview = gr.Code(
                label="Command Preview",
                language="shell",
                interactive=False,
                visible=False
            )
        # `command_preview.visible` keeps its construction-time value (gr.update only
        # changes the browser), so the toggle state is tracked explicitly.
        command_preview_visible = gr.State(False)

        def handle_preview(plm_model, model_path, eval_method, is_custom_dataset, dataset_defined,
                           dataset_custom, problem_type, num_labels, metrics, batch_mode,
                           batch_size, batch_token, eval_structure_seq, eval_pooling_method,
                           preview_visible=False):
            """Toggle the command preview; when showing it, render the eval command."""
            if preview_visible:
                return gr.update(visible=False), False

            args = {
                "plm_model": plm_models[plm_model],
                "model_path": model_path,
                "eval_method": eval_method,
                "pooling_method": eval_pooling_method,
            }

            if is_custom_dataset == "Use Custom Dataset":
                args["dataset"] = dataset_custom
                args["problem_type"] = problem_type
                args["num_labels"] = num_labels
                args["metrics"] = _join_choices(metrics)
            else:
                args["dataset_config"] = dataset_configs[dataset_defined]

            if batch_mode == "Batch Size Mode":
                args["batch_size"] = batch_size
            else:
                args["batch_token"] = batch_token

            if eval_method == "ses-adapter" and eval_structure_seq:
                args["structure_seq"] = ",".join(eval_structure_seq)
                if "foldseek_seq" in eval_structure_seq:
                    args["use_foldseek"] = True
                if "ss8_seq" in eval_structure_seq:
                    args["use_ss8"] = True

            return gr.update(value=preview_eval_command(args), visible=True), True

        eval_inputs = [
            eval_plm_model, eval_model_path, eval_method, is_custom_dataset,
            eval_dataset_defined, eval_dataset_custom, problem_type, num_labels,
            metrics, batch_mode, batch_size, batch_token, eval_structure_seq,
            eval_pooling_method,
        ]

        command_preview_button.click(
            fn=handle_preview,
            inputs=eval_inputs + [command_preview_visible],
            outputs=[command_preview, command_preview_visible]
        )

        eval_output = gr.HTML(
            value=_html_block("Click the 「Start Evaluation」 button to begin model evaluation"),
            label="Evaluation Status & Results"
        )

        with gr.Row():
            with gr.Column(scale=4):
                pass
            with gr.Column(scale=1):
                download_csv_btn = gr.DownloadButton(
                    "Download CSV",
                    visible=False,
                    size="lg"
                )
            with gr.Column(scale=4):
                pass

        eval_button.click(
            fn=evaluate_model,
            inputs=eval_inputs,
            outputs=[eval_output, download_csv_btn]
        )

        abort_button.click(
            fn=handle_abort,
            inputs=[],
            outputs=[eval_output, download_csv_btn]
        )

    return {
        "eval_button": eval_button,
        "eval_output": eval_output
    }