# NOTE: the three lines previously here ("Spaces:" / "Runtime error" x2) were
# page-scrape residue from the hosting site's status banner, not Python source.
import threading | |
import queue | |
import subprocess | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import os | |
import signal | |
import re | |
import time | |
from typing import Dict, Any, Optional | |
from .command import build_command_list | |
import logging | |
import json | |
import io | |
import base64 | |
class TrainingMonitor: | |
def __init__(self):
    """Initialize training monitor.

    Sets up thread-safe queues for exchanging data with the output-reader
    thread, the progress-state dictionary consumed by the UI, and the regex
    patterns used to parse the training subprocess's log output.
    """
    # Queues for thread-safe data exchange
    self.stats_queue = queue.Queue()    # parsed model-parameter statistics
    self.message_queue = queue.Queue()  # raw log lines destined for the UI console
    self.is_training = False
    self.stop_thread = False
    self.process = None          # subprocess.Popen handle, set by start_training
    self.training_thread = None  # daemon thread draining process stdout
    self.debug_progress = False  # Enable for debug info
    # Metrics tracking
    self._reset_tracking()
    # Progress tracking
    self.current_progress = {
        'stage': 'Waiting',  # Training/Validation/Testing
        'progress': '',      # Progress bar text
        'epoch': 0,
        'current': 0,
        'total': 100,
        'total_epochs': 0,   # total number of training epochs requested
        'val_accuracy': 0.0,
        'best_accuracy': 0.0,
        'best_epoch': 0,
        'best_metric_name': 'accuracy',  # Name of the best metric
        'best_metric_value': 0.0,        # Value of the best metric
        'progress_detail': '',  # Detailed progress information
        'elapsed_time': '',     # Elapsed time
        'remaining_time': '',   # Remaining time
        'it_per_sec': 0.0,      # Iterations per second
        'grad_step': 0,         # Gradient steps
        'loss': 0.0,            # Loss value
        'test_metrics': {},     # container for parsed test metrics
        'test_progress': 0.0,   # Test progress percentage
        'test_results_html': '',  # HTML formatted test results
        'lines': []             # rolling buffer of raw output lines
    }
    self.error_message = None
    # Noisy multi-line log sections that should not be forwarded to the UI.
    self.skip_output_patterns = [
        r"Model Parameters Statistics:",
        r"Dataset Statistics:",
        r"Sample \d+ data points from train dataset:"
    ]
    # Simplified regex patterns
    self.patterns = {
        # Basic training log patterns
        'train': r'Epoch (\d+) Train Loss: ([\d.]+)',
        'val': r'Epoch (\d+) Val Loss: ([\d.]+)',
        'val_metric': r'Epoch (\d+) Val ([a-zA-Z_\-]+(?:\s[a-zA-Z_\-]+)*): ([\d.]+)',
        'epoch_header': r'---------- Epoch (\d+) ----------',
        'best_save': r'Saving model with best val ([a-zA-Z_\-]+(?:\s[a-zA-Z_\-]+)*): ([\d.]+)',
        # Test result patterns - improved to match log format exactly
        'test_header': r'Test Results:',
        'test_phase_start': r'---------- Starting Test Phase ----------',
        # Generic test-metric pattern (matches any metric name)
        'test_metric': r'Test\s+([a-zA-Z0-9_\-]+):\s+([\d.]+)',
        # Dedicated pattern for the f1 metric
        'test_f1': r'Test\s+f1:\s+([\d.]+)',
        # Other common named metrics
        'test_common_metrics': r'Test\s+((?:accuracy|precision|recall|auroc|mcc)):\s*([\d.]+)',
        # Dedicated loss pattern
        'test_loss': r'Test\s+Loss:\s*([\d.]+)',
        # Alternative "<metric> on test: <value>" format
        'test_alt_format': r'([a-zA-Z0-9_\-]+(?:\s[a-zA-Z0-9_\-]+)*)\s+on\s+test:\s*([\d.]+)',
        # Model parameter statistics
        'model_param': r'([A-Za-z\s]+):\s*([\d,.]+[KM]?)',
    }
    # Progress bar patterns - Updated to handle both Validating and Testing phases
    self.progress_patterns = {
        'train': r'Training:\s*(\d+)%\|[^|]*\|\s*(\d+)/(\d+)\s*\[([\d:]+)<([\d:]+),\s*([\d.]+)it/s(?:,\s*grad_step=(\d+),\s*train_loss=([\d.]+))?\]',
        # Combined pattern for both Validating and Testing since they use same tqdm format
        'valid_or_test': r'(?:Validating|Valid|Testing|Test):\s*(\d+)%\|[^|]*\|\s*(\d+)/(\d+)\s*\[([\d:]+)<([\d:]+),\s*([\d.]+)it/s(?:[^\]]*)\]',
    }
    # Test results storage
    self.test_results = {}
    self.parsing_test_results = False
    self.test_results_table = None
    self.test_results_html = None
def _should_skip_line(self, line: str) -> bool: | |
"""Check if the line should be skipped from output.""" | |
for pattern in self.skip_output_patterns: | |
if re.search(pattern, line): | |
return True | |
return False | |
def _process_output(self, process):
    """Drain *process* stdout line by line until EOF or a stop request.

    Runs on the daemon reader thread started by start_training.  Each
    non-skipped line is pushed onto the message queue for the UI and
    fed to the parser (_process_output_line).  The loop exits when
    stop_thread is set, or when readline returns '' and the process has
    terminated (poll() is not None).
    """
    while True:
        if self.stop_thread:
            break
        output = process.stdout.readline()
        # '' from readline means EOF only once the child has exited.
        if output == '' and process.poll() is not None:
            break
        if output:
            line = output.strip()
            if not self._should_skip_line(line):
                self.message_queue.put(line)
                self._process_output_line(line)
    process.stdout.close()
def start_training(self, args: Dict[str, Any]):
    """Launch the training subprocess and begin streaming its output.

    Args:
        args: Training configuration; passed to build_command_list to
            produce the command line. 'num_epochs' (default 100) is also
            stored for progress calculation.

    Side effects: resets all tracking state, spawns the subprocess and a
    daemon reader thread. On failure, records error_message and posts an
    ERROR line to the message queue instead of raising.
    """
    if self.is_training:
        self.message_queue.put("Training already in progress")
        return
    self.is_training = True
    self.stop_thread = False
    self._reset_tracking()
    self._reset_stats()
    self.error_message = None
    # Store total epochs for progress calculation
    self.current_progress['total_epochs'] = args.get('num_epochs', 100)
    try:
        # Build command
        cmd = build_command_list(args)
        # Log command
        self.message_queue.put(f"Starting training with command: {' '.join(cmd)}")
        # start_new_session=True puts the child in its own session/process
        # group, so abort_training's os.killpg(os.getpgid(child)) signals
        # only the trainer and its children — without it the child inherits
        # this process's group and killpg would SIGTERM the monitor itself.
        # (POSIX-only, like the killpg-based abort path.)
        self.process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,      # universal_newlines is a redundant alias of this
            bufsize=1,      # line-buffered so progress lines arrive promptly
            start_new_session=True
        )
        # Start daemon thread to process output (won't block interpreter exit)
        self.training_thread = threading.Thread(
            target=self._process_output,
            args=(self.process,)
        )
        self.training_thread.daemon = True
        self.training_thread.start()
    except Exception as e:
        self.error_message = f"Error starting training: {str(e)}"
        self.is_training = False
        self.message_queue.put(f"ERROR: {self.error_message}")
def abort_training(self): | |
"""Abort the training process.""" | |
if self.process: | |
# Save completed state before termination | |
was_completed = self.current_progress.get('is_completed', False) | |
# Terminate process | |
try: | |
os.killpg(os.getpgid(self.process.pid), signal.SIGTERM) | |
except: | |
self.process.terminate() | |
# Mark as not training | |
self.is_training = False | |
# Fully reset the tracking state | |
self._reset_tracking() | |
self._reset_stats() | |
# Create fresh progress state | |
self.current_progress = { | |
'stage': 'Aborted', | |
'progress': '', | |
'epoch': 0, | |
'current': 0, | |
'total': 0, | |
'total_epochs': 0, | |
'val_accuracy': 0.0, | |
'best_accuracy': 0.0, | |
'best_epoch': -1, | |
'best_metric_name': '', | |
'best_metric_value': 0.0, | |
'progress_detail': '', | |
'elapsed_time': '', | |
'remaining_time': '', | |
'it_per_sec': 0.0, | |
'grad_step': 0, | |
'loss': 0.0, | |
'test_metrics': {}, | |
'test_progress': 0.0, | |
'test_results_html': '', | |
'is_completed': False, | |
'lines': [] | |
} | |
# Clear process reference | |
self.process = None | |
# Return reset state | |
return { | |
'progress_status': "Training aborted by user.", | |
'best_model': "Training aborted by user.", | |
'test_results': "", | |
'plot': None | |
} | |
def get_messages(self) -> str: | |
"""Get all messages from queue.""" | |
messages = [] | |
while not self.message_queue.empty(): | |
try: | |
messages.append(self.message_queue.get_nowait()) | |
except queue.Empty: | |
break | |
message_text = "\n".join(messages) | |
if self.error_message: | |
message_text += f"\n\nERROR: {self.error_message}" | |
return message_text | |
def get_loss_plot(self):
    """
    Generate a static plot showing training and validation loss.
    Returns:
        matplotlib Figure object for display in gr.Plot, or None when
        there is no data yet or plotting fails.
    """
    # Return None if insufficient data
    if not self.epochs or (not self.train_losses and not self.val_losses):
        return None
    try:
        import matplotlib.pyplot as plt
        import matplotlib
        # Close any existing figures to prevent memory leaks
        plt.close('all')
        # Publication-style matplotlib theme
        plt.style.use('seaborn-v0_8-whitegrid')
        matplotlib.rcParams.update({
            'font.family': ['serif', 'DejaVu Serif'],
            'font.size': 12,
            'axes.labelsize': 14,
            'axes.titlesize': 16,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 10,
            'figure.titlesize': 18,
            'figure.figsize': (8, 6),
            'figure.dpi': 150,
            'axes.grid': True,
            'grid.alpha': 0.3,
            'axes.axisbelow': True,
            'axes.edgecolor': '#888888',
            'axes.linewidth': 1.5,
            'axes.spines.top': False,
            'axes.spines.right': False,
        })
        # Create the figure
        fig, ax = plt.subplots(figsize=(8, 6))
        # Plot training loss (skipping epochs whose loss is still None)
        if self.train_losses:
            valid_indices = [i for i, loss in enumerate(self.train_losses) if loss is not None]
            if valid_indices:  # make sure there is valid data
                valid_epochs = [self.epochs[i] for i in valid_indices]
                valid_losses = [self.train_losses[i] for i in valid_indices]
                ax.plot(valid_epochs, valid_losses, 'o-', label='Train Loss',
                        color='#1f77b4', linewidth=2, markersize=6, markeredgecolor='white',
                        markeredgewidth=1.5)
        # Plot validation loss
        if self.val_losses:
            valid_indices = [i for i, loss in enumerate(self.val_losses) if loss is not None]
            if valid_indices:  # make sure there is valid data
                valid_epochs = [self.epochs[i] for i in valid_indices]
                valid_losses = [self.val_losses[i] for i in valid_indices]
                ax.plot(valid_epochs, valid_losses, 'o-', label='Validation Loss',
                        color='#ff7f0e', linewidth=2, markersize=6, markeredgecolor='white',
                        markeredgewidth=1.5)
        # Style the loss chart
        ax.set_title('Training and Validation Loss', fontweight='bold', pad=15)
        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('Loss', fontweight='bold')
        # Only add a legend once there is something to label
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc='upper right', frameon=True, fancybox=True,
                      framealpha=0.9, edgecolor='gray', facecolor='white')
        # Integer x-axis ticks (epochs are whole numbers)
        ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
        # If all loss values are non-negative, start the y axis at 0
        if self.train_losses and self.val_losses:
            all_losses = [l for l in self.train_losses + self.val_losses if l is not None]
            if all_losses and min(all_losses) >= 0:
                ax.set_ylim(bottom=0)
        # Tidy up the layout
        plt.tight_layout()
        # Return the figure
        return fig
    except Exception as e:
        print(f"Error generating loss plot: {str(e)}")
        plt.close('all')  # Close any open figures in case of error
        return None
def get_metrics_plot(self):
    """
    Generate a static plot showing validation metrics.
    Returns:
        matplotlib Figure object for display in gr.Plot, or None when
        no metric has any valid data or plotting fails.
    """
    # Return None if insufficient data
    if not self.epochs or not self.val_metrics:
        return None
    try:
        import matplotlib.pyplot as plt
        import matplotlib
        # Close any existing figures to prevent memory leaks
        plt.close('all')
        # Publication-style matplotlib theme
        plt.style.use('seaborn-v0_8-whitegrid')
        matplotlib.rcParams.update({
            'font.family': ['serif', 'DejaVu Serif'],
            'font.size': 12,
            'axes.labelsize': 14,
            'axes.titlesize': 16,
            'xtick.labelsize': 12,
            'ytick.labelsize': 12,
            'legend.fontsize': 10,
            'figure.titlesize': 18,
            'figure.figsize': (8, 6),
            'figure.dpi': 150,
            'axes.grid': True,
            'grid.alpha': 0.3,
            'axes.axisbelow': True,
            'axes.edgecolor': '#888888',
            'axes.linewidth': 1.5,
            'axes.spines.top': False,
            'axes.spines.right': False,
        })
        # Create the figure
        fig, ax = plt.subplots(figsize=(8, 6))
        # Colour cycle for the per-metric lines
        colors = ['#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
        # Whether any metric contributed at least one plottable point
        has_valid_data = False
        # Draw one line per tracked validation metric
        for i, (metric_name, values) in enumerate(self.val_metrics.items()):
            if values:
                # NOTE: the comprehension's `i` is scoped to the comprehension
                # and does not clobber the outer loop index used for colors.
                valid_indices = [i for i, val in enumerate(values) if val is not None]
                if valid_indices:  # make sure there is valid data
                    has_valid_data = True
                    valid_epochs = [self.epochs[i] for i in valid_indices]
                    valid_values = [values[i] for i in valid_indices]
                    # Clamp all values to at most 1.0
                    valid_values = [min(val, 1.0) for val in valid_values]
                    # Format the metric name: acronyms upper-case, others capitalised
                    formatted_name = metric_name
                    if metric_name.lower() in ['acc', 'f1', 'mcc', 'auroc']:
                        formatted_name = metric_name.upper()
                    else:
                        formatted_name = metric_name.capitalize()
                    ax.plot(valid_epochs, valid_values, 'o-',
                            label=formatted_name,
                            color=colors[i % len(colors)],
                            linewidth=2,
                            markersize=6,
                            markeredgecolor='white',
                            markeredgewidth=1.5)
        # Nothing plottable: return None
        if not has_valid_data:
            plt.close(fig)
            return None
        # Style the metrics chart
        ax.set_title('Validation Metrics', fontweight='bold', pad=15)
        ax.set_xlabel('Epoch', fontweight='bold')
        ax.set_ylabel('Value', fontweight='bold')
        handles, labels = ax.get_legend_handles_labels()
        if handles:
            ax.legend(loc='lower right', frameon=True, fancybox=True,
                      framealpha=0.9, edgecolor='gray', facecolor='white')
        # Integer x-axis ticks (epochs are whole numbers)
        ax.xaxis.set_major_locator(matplotlib.ticker.MaxNLocator(integer=True))
        # Strictly limit the y axis to [0, 1]
        ax.set_ylim(0, 1.0)
        # (A best-epoch star/annotation marker used to live here as
        # commented-out code; recover it from version control if needed.)
        plt.tight_layout()
        # Return the figure
        return fig
    except Exception as e:
        print(f"Error generating metrics plot: {str(e)}")
        plt.close('all')  # Close any open figures in case of error
        return None
def get_plot(self):
    """Deprecated compatibility shim.

    Kept so older callers don't break; always returns None.  Use
    get_loss_plot() and get_metrics_plot() instead.
    """
    return None
def get_progress(self) -> Dict[str, Any]: | |
"""Return current progress information.""" | |
# Ensure we're returning a deep copy to prevent reference issues | |
progress_copy = self.current_progress.copy() | |
# Ensure all expected keys have default values if missing | |
default_progress = { | |
'stage': 'Waiting', | |
'progress': '', | |
'epoch': 0, | |
'current': 0, | |
'total': 0, | |
'total_epochs': 0, | |
'val_accuracy': 0.0, | |
'best_accuracy': 0.0, | |
'best_epoch': -1, | |
'best_metric_name': '', | |
'best_metric_value': 0.0, | |
'progress_detail': '', | |
'elapsed_time': '', | |
'remaining_time': '', | |
'it_per_sec': 0.0, | |
'grad_step': 0, | |
'loss': 0.0, | |
'test_metrics': {}, | |
'test_progress': 0.0, | |
'test_results_html': '', | |
'lines': [] | |
} | |
# Update with defaults for any missing keys | |
for key, value in default_progress.items(): | |
if key not in progress_copy: | |
progress_copy[key] = value | |
return progress_copy | |
def _process_output_line(self, line: str):
    """Parse one line of trainer output and update tracked state.

    Dispatches the line against, in order: the raw-line buffer, test-phase
    markers, epoch headers, test-result metrics, model-parameter statistics,
    tqdm progress bars, train/val losses, validation metrics, and best-model
    save messages.  Parsing errors are recorded (not raised) so the reader
    thread keeps running.
    """
    try:
        # Keep every raw output line in the progress info for the UI
        if 'lines' not in self.current_progress:
            self.current_progress['lines'] = []
        self.current_progress['lines'].append(line)
        # Cap the number of stored lines to bound memory use
        max_lines = 1000  # keep only the most recent 1000 lines
        if len(self.current_progress['lines']) > max_lines:
            self.current_progress['lines'] = self.current_progress['lines'][-max_lines:]
        # Always check for test progress if in Testing stage
        # NOTE(review): _process_test_progress is defined elsewhere in this
        # class (not visible in this chunk).
        if self.current_progress.get('stage') == 'Testing':
            if self._process_test_progress(line):
                return
        # Check for test phase start
        if re.search(self.patterns['test_phase_start'], line):
            self.current_progress['stage'] = 'Testing'
            # Reset test metrics at the start of test phase
            self.current_progress['test_metrics'] = {}
            self.current_progress['test_results_html'] = ''
            return
        # Check for epoch header pattern (e.g., "---------- Epoch 1 ----------")
        epoch_header_match = re.search(self.patterns['epoch_header'], line)
        if epoch_header_match:
            new_epoch = int(epoch_header_match.group(1))
            # Update current epoch
            self.current_epoch = new_epoch
            self.current_progress['epoch'] = new_epoch
            if self.debug_progress:
                print(f"Detected epoch header, setting current epoch to: {new_epoch}")
            return
        # Detect test results header
        if re.search(self.patterns['test_header'], line):
            self.parsing_test_results = True
            self.test_results = {}
            # Set stage to 'Testing' when we see the test results header
            self.current_progress['stage'] = 'Testing'
            return
        # Extract the actual content part of the log line if it contains timestamp and INFO
        log_content = line
        log_match = re.search(r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2} - [a-zA-Z]+ - INFO - (.*)', line)
        if log_match:
            log_content = log_match.group(1)
        if self.parsing_test_results:
            collected_new_metric = False
            # Try to match the test loss value
            test_loss_match = re.search(self.patterns['test_loss'], log_content)
            if test_loss_match:
                loss_value = float(test_loss_match.group(1))
                self.test_results['loss'] = loss_value
                collected_new_metric = True
                if self.debug_progress:
                    print(f"Matched test loss: {loss_value}")
            # Special-case the f1 metric
            test_f1_match = re.search(self.patterns['test_f1'], log_content)
            if test_f1_match and not test_loss_match:
                f1_value = float(test_f1_match.group(1))
                self.test_results['f1'] = f1_value
                collected_new_metric = True
                if self.debug_progress:
                    print(f"Matched test f1: {f1_value}")
            # Try to match the common named metrics
            if not test_loss_match and not test_f1_match:
                common_metric_match = re.search(self.patterns['test_common_metrics'], log_content)
                if common_metric_match:
                    metric_name, metric_value = common_metric_match.groups()
                    metric_name = metric_name.strip().lower()
                    try:
                        value = float(metric_value)
                        self.test_results[metric_name] = value
                        collected_new_metric = True
                        if self.debug_progress:
                            print(f"Matched common test metric: {metric_name} = {value}")
                    except ValueError:
                        if self.debug_progress:
                            print(f"Failed to parse value for common metric {metric_name}: {metric_value}")
            # Try to match any other test metric
            # NOTE(review): locals().get('common_metric_match') guards against
            # a NameError since the name is only bound in the branch above.
            if not test_loss_match and not test_f1_match and not (locals().get('common_metric_match')):
                test_metric_match = re.search(self.patterns['test_metric'], log_content)
                if test_metric_match:
                    metric_name, metric_value = test_metric_match.groups()
                    metric_name = metric_name.strip().lower()
                    try:
                        value = float(metric_value)
                        self.test_results[metric_name] = value
                        collected_new_metric = True
                        if self.debug_progress:
                            print(f"Matched test metric: {metric_name} = {value}")
                    except ValueError:
                        if self.debug_progress:
                            print(f"Failed to parse value for metric {metric_name}: {metric_value}")
            # Refresh the rendered results whenever a new metric arrived
            if collected_new_metric:
                self._update_test_results_display()
            # Determine if we should end test results parsing
            # Only end parsing when line doesn't start with "Test", is not empty, and we've collected metrics, or if line is empty
            if ((not log_content.strip().startswith("Test") and
                    len(log_content.strip()) > 0 and
                    self.test_results) or
                    log_content.strip() == ""):
                # Ensure we've collected at least some metrics before ending parsing
                if self.test_results:
                    self.parsing_test_results = False
                    # Final update of the display
                    self._update_test_results_display()
            return
        # Parse model parameter statistics
        if "Model Parameters Statistics:" in line:
            self.current_stats = {}
            self.parsing_stats = True
            self.skipped_first_separator = False
            return
        if self.parsing_stats:
            # Handle separator line
            if "------------------------" in line:
                # If this is the first separator line, skip it
                if not self.skipped_first_separator:
                    self.skipped_first_separator = True
                    return
                # If it's the last separator line, check if we have enough information
                required_keys = ["adapter_total", "adapter_trainable",
                                 "pretrain_total", "pretrain_trainable",
                                 "combined_total", "combined_trainable",
                                 "trainable_percentage"]
                missing_keys = [key for key in required_keys if key not in self.current_stats]
                if not missing_keys:
                    # Put statistics in queue
                    self.stats_queue.put(self.current_stats.copy())
                    # Update cache
                    self.last_stats.update(self.current_stats)
                self.parsing_stats = False
                self.current_model = None
                self.skipped_first_separator = False
                return
            # If first separator not yet skipped, don't process other lines
            if not self.skipped_first_separator:
                return
            # Match model name sections
            if "Adapter Model:" in line:
                self.current_model = "adapter"
                return
            elif "Pre-trained Model:" in line:
                self.current_model = "pretrain"
                return
            elif "Combined:" in line:
                self.current_model = "combined"
                return
            # Parse parameter information
            param_match = re.search(self.patterns['model_param'], line)
            if param_match and self.current_model:
                stat_name, stat_value = param_match.groups()
                stat_name = stat_name.strip().lower()
                if "total parameters" in stat_name:
                    self.current_stats[f"{self.current_model}_total"] = stat_value
                elif "trainable parameters" in stat_name:
                    self.current_stats[f"{self.current_model}_trainable"] = stat_value
                elif "trainable percentage" in stat_name and self.current_model == "combined":
                    self.current_stats["trainable_percentage"] = stat_value
            return
        # Process training progress
        train_progress_match = re.search(self.progress_patterns['train'], line)
        if train_progress_match:
            percentage, current, total, elapsed, remaining, it_per_sec = train_progress_match.groups()[:6]
            grad_step = train_progress_match.group(7) if len(train_progress_match.groups()) >= 7 and train_progress_match.group(7) else "0"
            loss = train_progress_match.group(8) if len(train_progress_match.groups()) >= 8 and train_progress_match.group(8) else "0.0"
            # Update progress information
            self.current_progress['stage'] = 'Training'
            self.current_progress['current'] = int(current)
            self.current_progress['total'] = int(total)
            self.current_progress['progress_detail'] = f"{current}/{total}[{elapsed}<{remaining},{it_per_sec}it/s"
            if grad_step:
                self.current_progress['progress_detail'] += f",grad_step={grad_step}"
            self.current_progress['progress_detail'] += f",train_loss={loss}]"
            self.current_progress['elapsed_time'] = elapsed
            self.current_progress['remaining_time'] = remaining
            self.current_progress['it_per_sec'] = float(it_per_sec)
            if grad_step:
                self.current_progress['grad_step'] = int(grad_step)
            if loss and float(loss) > 0:
                self.current_progress['loss'] = float(loss)
            return
        # Validation or Testing progress - consolidated since they use same tqdm format
        valid_or_test_match = re.search(self.progress_patterns['valid_or_test'], line)
        if valid_or_test_match:
            percentage, current, total, elapsed, remaining, it_per_sec = valid_or_test_match.groups()
            # Determine stage based on current context and line content
            # If line contains 'Test' or we've already detected test phase, set to 'Testing'
            if 'Test' in line or self.current_progress.get('stage') == 'Testing' or self.parsing_test_results:
                self.current_progress['stage'] = 'Testing'
            else:
                self.current_progress['stage'] = 'Validation'
            self.current_progress['current'] = int(current)
            self.current_progress['total'] = int(total)
            self.current_progress['progress_detail'] = f"{current}/{total}[{elapsed}<{remaining},{it_per_sec}it/s]"
            self.current_progress['elapsed_time'] = elapsed
            self.current_progress['remaining_time'] = remaining
            self.current_progress['it_per_sec'] = float(it_per_sec)
            return
        # Parse training loss
        train_match = re.search(self.patterns['train'], line)
        if train_match:
            epoch, loss = train_match.groups()
            current_epoch = int(epoch)
            self.current_progress['epoch'] = current_epoch
            self.current_progress['loss'] = float(loss)
            self.current_epoch = current_epoch
            # Add new epoch to epochs list
            if current_epoch not in self.epochs:
                self.epochs.append(current_epoch)
                self.train_losses.append(float(loss))
            else:
                # Update existing epoch
                idx = self.epochs.index(current_epoch)
                self.train_losses[idx] = float(loss)
            return
        # Parse validation loss
        val_match = re.search(self.patterns['val'], line)
        if val_match:
            epoch, loss = val_match.groups()
            current_epoch = int(epoch)
            # Ensure current epoch exists
            if current_epoch not in self.epochs:
                self.epochs.append(current_epoch)
                if len(self.train_losses) < len(self.epochs):
                    self.train_losses.append(None)
            idx = self.epochs.index(current_epoch)
            # Ensure val_losses list matches epochs list length
            while len(self.val_losses) < len(self.epochs):
                self.val_losses.append(None)
            # Update val_losses at correct position
            self.val_losses[idx] = float(loss)
            # Also update val_metrics dictionary
            if 'loss' not in self.val_metrics:
                self.val_metrics['loss'] = []
            # Ensure val_metrics['loss'] matches epochs length
            while len(self.val_metrics['loss']) < len(self.epochs):
                self.val_metrics['loss'].append(None)
            # Update val_metrics['loss'] at correct position
            self.val_metrics['loss'][idx] = float(loss)
            return
        # Parse validation metrics
        val_metric_match = re.search(self.patterns['val_metric'], line)
        if val_metric_match:
            epoch, metric_name, metric_value = val_metric_match.groups()
            current_epoch = int(epoch)
            metric_name = metric_name.strip().lower()
            # Handle different possible metrics
            if metric_name == 'accuracy' or metric_name == 'acc':
                self.current_progress['val_accuracy'] = float(metric_value)
            # Ensure current epoch exists
            if current_epoch not in self.epochs:
                self.epochs.append(current_epoch)
                if len(self.train_losses) < len(self.epochs):
                    self.train_losses.append(None)
            # Add to corresponding metric list
            if metric_name not in self.val_metrics:
                self.val_metrics[metric_name] = []
            # Ensure list length matches epochs
            while len(self.val_metrics[metric_name]) < len(self.epochs):
                self.val_metrics[metric_name].append(None)
            idx = self.epochs.index(current_epoch)
            self.val_metrics[metric_name][idx] = float(metric_value)
            return
        # First check whether the raw line announces a best-model save
        if "Saving model with best val" in line:
            # Extract directly from the raw line, without relying on the regex
            try:
                # Attempt to parse the line content directly
                parts = line.split("Saving model with best val ")[1].split(": ")
                if len(parts) == 2:
                    metric_name = parts[0].strip().lower()
                    metric_value = float(parts[1].strip())
                    # Update the Best Performance info
                    self.current_progress['best_metric_name'] = metric_name
                    self.current_progress['best_metric_value'] = metric_value
                    self.current_progress['best_epoch'] = self.current_epoch
                    # For the accuracy metric, also update best_accuracy
                    if metric_name == 'accuracy':
                        self.current_progress['best_accuracy'] = metric_value
                    # Log debug info
                    print(f"Best model updated - Metric: {metric_name}, Value: {metric_value}, Epoch: {self.current_epoch}")
                    # Push the best-model info to the message queue so the UI can show it
                    best_model_msg = f"Best model saved at epoch {self.current_epoch} with {metric_name}: {metric_value:.4f}"
                    self.message_queue.put(best_model_msg)
                    return
            except Exception as e:
                print(f"Error parsing best model info: {e}, line: {line}")
                # If direct parsing failed, fall through to the regex below
        # Match best model save info: e.g., "Saving model with best val accuracy: 0.9088"
        best_save_match = re.search(self.patterns['best_save'], log_content)
        if best_save_match:
            metric_name, metric_value = best_save_match.groups()
            metric_name = metric_name.strip().lower()
            metric_value = float(metric_value)
            # Update the Best Performance info
            self.current_progress['best_metric_name'] = metric_name
            self.current_progress['best_metric_value'] = metric_value
            self.current_progress['best_epoch'] = self.current_epoch
            # For the accuracy metric, also update best_accuracy
            if metric_name == 'accuracy':
                self.current_progress['best_accuracy'] = metric_value
            # Log debug info
            print(f"Best model updated (regex) - Metric: {metric_name}, Value: {metric_value}, Epoch: {self.current_epoch}")
            # Push the best-model info to the message queue so the UI can show it
            best_model_msg = f"Best model saved at epoch {self.current_epoch} with {metric_name}: {metric_value:.4f}"
            self.message_queue.put(best_model_msg)
            return
        # Detect subprocess completion
        if self.process and self.process.poll() is not None:
            self.is_training = False
            self.current_progress['is_completed'] = True
            print("Training process has completed. Setting is_completed flag.")
    except Exception as e:
        # Record the error and also keep it in the output-line buffer
        error_msg = f"Error parsing line: {str(e)}"
        self.error_message = error_msg
        if 'lines' not in self.current_progress:
            self.current_progress['lines'] = []
        self.current_progress['lines'].append(error_msg)
        if self.debug_progress:
            print(error_msg)
            print(f"Line content: {line}")
def _reset_tracking(self):
    """Reset all per-run tracking state (metrics, test results, progress)."""
    # Reset metric tracking
    self.train_losses = []
    self.val_losses = []
    self.val_metrics = {}
    self.epochs = []
    self.current_epoch = 0
    # Reset test results
    self.test_results = {}
    self.parsing_test_results = False
    self.test_results_html = ""
    # Force complete reset by creating a new dictionary instead of modifying existing one
    # This ensures no old keys remain in the dictionary
    self.current_progress = {
        'stage': 'Waiting',
        'progress': '',
        'epoch': 0,
        'current': 0,
        'total': 0,  # Set to 0 initially to avoid showing progress
        'total_epochs': 0,
        'val_accuracy': 0.0,
        'best_accuracy': 0.0,
        'best_epoch': -1,  # Set to -1 to indicate no best model
        'best_metric_name': '',
        'best_metric_value': 0.0,
        'progress_detail': '',
        'elapsed_time': '',
        'remaining_time': '',
        'it_per_sec': 0.0,
        'grad_step': 0,
        'loss': 0.0,
        'test_metrics': {},
        'test_progress': 0.0,
        'test_results_html': '',
        'lines': []  # rolling buffer of raw output lines
    }
    # Reset parameter-statistics parsing state
    self.current_stats = {}
    self.parsing_stats = False
    self.current_model = None
    self.skipped_first_separator = False
    # Reset cached statistics
    self.last_stats = {}
    # Reset error message (guarded: this is also called from __init__
    # before the attribute exists)
    if hasattr(self, 'error_message'):
        self.error_message = None
def get_stats(self) -> Dict: | |
"""Get collected statistics.""" | |
# Save last retrieved statistics to avoid emptying queue every time | |
if not hasattr(self, 'last_stats'): | |
self.last_stats = {} | |
try: | |
# Check if there's new data in the queue | |
if not self.stats_queue.empty(): | |
# Get the latest statistics data | |
while not self.stats_queue.empty(): | |
stat = self.stats_queue.get_nowait() | |
self.last_stats.update(stat) | |
except queue.Empty: | |
pass | |
except Exception as e: | |
print(f"Error getting statistics data: {str(e)}") | |
return self.last_stats | |
def _reset_stats(self):
    """Reset statistics tracking (queue, caches, and progress info)."""
    # Clear statistics queue
    while not self.stats_queue.empty():
        try:
            self.stats_queue.get_nowait()
        except queue.Empty:
            break
    # Reset current statistics with new dictionaries
    self.current_stats = {}
    self.parsing_stats = False
    self.current_model = None
    self.skipped_first_separator = False  # Reset flag
    # Reset cached statistics
    self.last_stats = {}
    # Reset stats property explicitly
    self.stats = {}
    # Reset training and validation metrics
    self._reset_tracking()
    # Reset progress info
    # NOTE(review): these defaults differ from _reset_tracking's
    # (total=100 vs 0, best_epoch=0 vs -1, best_metric_name='accuracy'
    # vs '') and this dict overwrites the one _reset_tracking just set —
    # confirm which defaults the UI actually expects.
    self.current_progress = {
        'stage': 'Waiting',
        'progress': '',
        'epoch': 0,
        'current': 0,
        'total': 100,
        'total_epochs': 0,  # Ensure total_epochs is reset
        'val_accuracy': 0.0,
        'best_accuracy': 0.0,
        'best_epoch': 0,
        'best_metric_name': 'accuracy',
        'best_metric_value': 0.0,
        'progress_detail': '',
        'elapsed_time': '',
        'remaining_time': '',
        'it_per_sec': 0.0,
        'grad_step': 0,
        'loss': 0.0,
        'test_metrics': {},
        'test_progress': 0.0,
        'test_results_html': '',
        'lines': []  # rolling buffer of raw output lines
    }
def _update_test_results_display(self): | |
"""Update the display of test results, in both HTML and text formats.""" | |
if not self.test_results: | |
return | |
# Count number of metrics | |
metrics_count = len(self.test_results) | |
# Create a more beautiful HTML table with summary information | |
html_content = f""" | |
<div style="max-width: 800px; margin: 0 auto; font-family: Arial, sans-serif;"> | |
<h3 style="text-align: center; margin-bottom: 15px; color: #333;">Test Results</h3> | |
<p style="text-align: center; margin-bottom: 15px; color: #666;">{metrics_count} metrics found</p> | |
<table style="width: 100%; border-collapse: collapse; font-size: 14px; border: 1px solid #ddd; box-shadow: 0 2px 3px rgba(0,0,0,0.1);"> | |
<thead> | |
<tr style="background-color: #f0f0f0;"> | |
<th style="padding: 12px; text-align: center; border: 1px solid #ddd; font-weight: bold; width: 50%;">Metric</th> | |
<th style="padding: 12px; text-align: center; border: 1px solid #ddd; font-weight: bold; width: 50%;">Value</th> | |
</tr> | |
</thead> | |
<tbody> | |
""" | |
# Sort by priority and alphabetically to ensure important metrics are displayed first | |
priority_metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc'] | |
# Build priority sorting key | |
def get_priority(item): | |
name = item[0] | |
if name in priority_metrics: | |
return priority_metrics.index(name) | |
return len(priority_metrics) | |
# Sort by priority | |
sorted_metrics = sorted(self.test_results.items(), key=get_priority) | |
# Add a row for each metric, using alternating row colors | |
for i, (metric_name, metric_value) in enumerate(sorted_metrics): | |
row_style = 'background-color: #f9f9f9;' if i % 2 == 0 else '' | |
# Use bold for priority metrics | |
is_priority = metric_name in priority_metrics | |
name_style = 'font-weight: bold;' if is_priority else '' | |
# 转换指标名称:缩写用大写,非缩写首字母大写 | |
display_name = metric_name | |
if metric_name.lower() in ['f1', 'mcc', 'auroc']: | |
display_name = metric_name.upper() | |
else: | |
display_name = metric_name.capitalize() | |
html_content += f""" | |
<tr style="{row_style}"> | |
<td style="padding: 10px; text-align: center; border: 1px solid #ddd; {name_style}">{display_name}</td> | |
<td style="padding: 10px; text-align: center; border: 1px solid #ddd;">{metric_value:.4f}</td> | |
</tr> | |
""" | |
html_content += """ | |
</tbody> | |
</table> | |
<p style="text-align: center; margin-top: 10px; color: #888; font-size: 12px;">Test completed at: """ + time.strftime("%Y-%m-%d %H:%M:%S") + """</p> | |
</div> | |
""" | |
# Save to current_progress for UI access | |
self.current_progress['test_metrics'] = self.test_results.copy() | |
self.current_progress['test_results_html'] = html_content | |
# Generate text representation for logging | |
text_results = "\nTest Results:\n" + "-" * 30 + "\n" | |
# Display in same order as HTML | |
for metric_name, metric_value in sorted_metrics: | |
# 转换指标名称:缩写用大写,非缩写首字母大写 | |
display_name = metric_name | |
if metric_name.lower() in ['f1', 'mcc', 'auroc']: | |
display_name = metric_name.upper() | |
else: | |
display_name = metric_name.capitalize() | |
text_results += f"{display_name.ljust(15)}: {metric_value:.4f}\n" | |
text_results += "-" * 30 | |
text_results += f"\nTotal {metrics_count} metrics" | |
# Add text results to message queue | |
self.message_queue.put(text_results) | |
# Generate CSV content for download | |
csv_content = "Metric,Value\n" | |
for metric_name, metric_value in sorted_metrics: | |
# 转换指标名称:缩写用大写,非缩写首字母大写 | |
display_name = metric_name | |
if metric_name.lower() in ['f1', 'mcc', 'auroc']: | |
display_name = metric_name.upper() | |
else: | |
display_name = metric_name.capitalize() | |
csv_content += f"{display_name},{metric_value:.6f}\n" | |
self.current_progress['test_results_csv'] = csv_content | |
def _process_test_progress(self, line: str): | |
"""Process test progress from output lines during testing phase.""" | |
# Parse intermediate test results if available | |
test_metric_interim_match = re.search(r'Batch (\d+)/(\d+): ([a-zA-Z_\-]+) = ([\d.]+)', line) | |
if test_metric_interim_match: | |
batch, total_batches, metric_name, metric_value = test_metric_interim_match.groups() | |
progress = int(batch) / int(total_batches) * 100 | |
self.current_progress['test_progress'] = progress | |
# Update test metrics with interim values | |
if 'interim_metrics' not in self.current_progress: | |
self.current_progress['interim_metrics'] = {} | |
self.current_progress['interim_metrics'][metric_name] = float(metric_value) | |
return True | |
return False | |
def check_process_status(self): | |
"""Check if the training process has completed.""" | |
if self.process and self.process.poll() is not None: | |
self.is_training = False | |
# Check for normal vs error termination based on return code | |
if self.process.returncode == 0: | |
# Normal termination | |
self.current_progress['is_completed'] = True | |
print("Training process has completed successfully. Setting is_completed flag.") | |
else: | |
# Error termination - ensure UI doesn't show "completed" | |
self.current_progress['is_completed'] = False | |
# Explicitly mark the stage as Error for proper UI handling | |
self.current_progress['stage'] = 'Error' | |
# Log the error more prominently | |
print(f"Training process terminated with error code {self.process.returncode}. Setting stage to 'Error'.") | |
# Clear the process reference | |
self.process = None | |
return True | |
return False |