import gradio as gr import json import os import subprocess import sys import signal import threading import queue import time import pandas as pd import tempfile import csv from pathlib import Path import traceback import re from web.utils.command import preview_predict_command import select def create_predict_tab(constant): plm_models = constant["plm_models"] is_predicting = False current_process = None output_queue = queue.Queue() stop_thread = False process_aborted = False # Flag indicating if the process was manually terminated def process_output(process, queue): """Process output from subprocess and put it in queue""" nonlocal stop_thread while True: if stop_thread: break output = process.stdout.readline() if output == '' and process.poll() is not None: break if output: queue.put(output.strip()) process.stdout.close() def generate_status_html(status_info): """Generate HTML for single sequence prediction status""" stage = status_info.get("current_step", "Preparing") status = status_info.get("status", "running") # Determine status color and icon if status == "running": status_color = "#4285f4" # Blue icon = "⏳" animation = """ @keyframes pulse { 0% { transform: scale(1); } 50% { transform: scale(1.05); } 100% { transform: scale(1); } } """ animation_style = "animation: pulse 1.5s infinite ease-in-out;" elif status == "completed": status_color = "#2ecc71" # Green icon = "✅" animation = "" animation_style = "" else: # failed status_color = "#e74c3c" # Red icon = "❌" animation = "" animation_style = "" # Create a clean, centered notification return f"""
{icon}

{stage}

{status.capitalize()}

""" def predict_sequence(plm_model, model_path, aa_seq, eval_method, eval_structure_seq, pooling_method, problem_type, num_labels): """Predict for a single protein sequence""" nonlocal is_predicting, current_process, stop_thread, process_aborted # Check if we're already predicting if is_predicting: return gr.HTML("""

A prediction is already running. Please wait or abort it.

""") # If the process was aborted but not reset properly, ensure we're in a clean state if process_aborted: process_aborted = False # Set the prediction flag is_predicting = True stop_thread = False # Ensure this is reset # Create a status info object, similar to batch prediction status_info = { "status": "running", "current_step": "Starting prediction" } # Show initial status yield generate_status_html(status_info) try: # Validate inputs if not model_path: is_predicting = False return gr.HTML("""
Please provide a model path
""") if not os.path.exists(os.path.dirname(model_path)): is_predicting = False return gr.HTML("""
Invalid model path - directory does not exist
""") if not aa_seq: is_predicting = False return gr.HTML("""
Amino acid sequence is required
""") # Update status status_info["current_step"] = "Preparing model and parameters" yield generate_status_html(status_info) # Prepare command args_dict = { "model_path": model_path, "plm_model": plm_models[plm_model], "aa_seq": aa_seq, "pooling_method": pooling_method, "problem_type": problem_type, "num_labels": num_labels, "eval_method": eval_method } if eval_method == "ses-adapter": # Handle structure sequence selection from multi-select dropdown args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None # Set flags based on selected structure sequences if eval_structure_seq: if "foldseek_seq" in eval_structure_seq: args_dict["use_foldseek"] = True if "ss8_seq" in eval_structure_seq: args_dict["use_ss8"] = True else: args_dict["structure_seq"] = None args_dict["use_foldseek"] = False args_dict["use_ss8"] = False # Build command line final_cmd = [sys.executable, "src/predict.py"] for k, v in args_dict.items(): if v is True: final_cmd.append(f"--{k}") elif v is not False and v is not None: final_cmd.append(f"--{k}") final_cmd.append(str(v)) # Update status status_info["current_step"] = "Starting prediction process" yield generate_status_html(status_info) # Start prediction process try: current_process = subprocess.Popen( final_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True, preexec_fn=os.setsid if hasattr(os, "setsid") else None ) except Exception as e: is_predicting = False return gr.HTML(f"""
Error starting prediction process: {str(e)}
""") output_thread = threading.Thread(target=process_output, args=(current_process, output_queue)) output_thread.daemon = True output_thread.start() # Collect output result_output = "" prediction_data = None json_str = "" in_json_block = False json_lines = [] # Update status status_info["current_step"] = "Processing sequence" yield generate_status_html(status_info) while current_process.poll() is None: # Check if the process was aborted if process_aborted or stop_thread: break try: while not output_queue.empty(): line = output_queue.get_nowait() result_output += line + "\n" # Update status with more meaningful messages if "Loading model" in line: status_info["current_step"] = "Loading model and tokenizer" elif "Processing sequence" in line: status_info["current_step"] = "Processing protein sequence" elif "Tokenizing" in line: status_info["current_step"] = "Tokenizing sequence" elif "Forward pass" in line: status_info["current_step"] = "Running model inference" elif "Making prediction" in line: status_info["current_step"] = "Calculating final prediction" elif "Prediction Results" in line: status_info["current_step"] = "Finalizing results" # Update status display yield generate_status_html(status_info) # Detect start of JSON results block if "---------- Prediction Results ----------" in line: in_json_block = True json_lines = [] continue # If in JSON block, collect JSON lines if in_json_block and line.strip(): json_lines.append(line.strip()) # Try to parse the complete JSON when we have multiple lines if line.strip() == "}": # Potential end of JSON object try: complete_json = " ".join(json_lines) # Clean up the JSON string by removing line breaks and extra spaces complete_json = re.sub(r'\s+', ' ', complete_json).strip() prediction_data = json.loads(complete_json) print(f"Successfully parsed complete JSON: {prediction_data}") except json.JSONDecodeError as e: print(f"Failed to parse complete JSON: {e}") time.sleep(0.1) except Exception as e: yield gr.HTML(f"""
⚠️
Warning reading output: {str(e)}
""") # Check if the process was aborted if process_aborted: # Show aborted message abort_html = """

Prediction was aborted by user

""" yield gr.HTML(abort_html) is_predicting = False return # Process has completed if current_process and current_process.returncode == 0: # Update status status_info["status"] = "completed" status_info["current_step"] = "Prediction completed successfully" yield generate_status_html(status_info) # If no prediction data found, try to parse from complete output if not prediction_data: try: # Find the JSON block in the output results_marker = "---------- Prediction Results ----------" if results_marker in result_output: json_part = result_output.split(results_marker)[1].strip() # Try to extract the JSON object json_match = re.search(r'(\{.*?\})', json_part.replace('\n', ' '), re.DOTALL) if json_match: try: json_str = json_match.group(1) # Clean up the JSON string json_str = re.sub(r'\s+', ' ', json_str).strip() prediction_data = json.loads(json_str) print(f"Parsed prediction data from regex: {prediction_data}") except json.JSONDecodeError as e: print(f"JSON parse error from regex: {e}") except Exception as e: print(f"Error parsing JSON from complete output: {e}") if prediction_data: # Create styled HTML table based on problem type if problem_type == "regression": html_result = f"""

Regression Prediction Results

OutputValue
Predicted Value{prediction_data['prediction']:.4f}
""" elif problem_type == "single_label_classification": # Create probability table prob_rows = "" if isinstance(prediction_data.get('probabilities'), list): prob_rows = "".join([ f"Class {i}{prob:.4f}" for i, prob in enumerate(prediction_data['probabilities']) ]) elif isinstance(prediction_data.get('probabilities'), dict): prob_rows = "".join([ f"Class {label}{prob:.4f}" for label, prob in prediction_data['probabilities'].items() ]) else: # Handle case where probabilities is not a list or dict prob_value = prediction_data.get('probabilities', 0) prob_rows = f"Class 0{prob_value:.4f}" html_result = f"""

Single-Label Classification Results

OutputValue
Predicted Class{prediction_data['predicted_class']}

Class Probabilities

{prob_rows}
ClassProbability
""" else: # multi_label_classification # Create prediction table pred_rows = "" if 'predictions' in prediction_data and 'probabilities' in prediction_data: # Handle different formats of predictions and probabilities if (isinstance(prediction_data['predictions'], list) and isinstance(prediction_data['probabilities'], list)): pred_rows = "".join([ f"Label {i}{pred}{prob:.4f}" for i, (pred, prob) in enumerate(zip(prediction_data['predictions'], prediction_data['probabilities'])) ]) elif (isinstance(prediction_data['predictions'], dict) and isinstance(prediction_data['probabilities'], dict)): pred_rows = "".join([ f"Label {label}{pred}{prediction_data['probabilities'].get(label, 0):.4f}" for label, pred in prediction_data['predictions'].items() ]) else: # Handle case where predictions or probabilities is not a list or dict pred = prediction_data['predictions'] if 'predictions' in prediction_data else "N/A" prob = prediction_data['probabilities'] if 'probabilities' in prediction_data else 0.0 pred_rows = f"Label 0{pred}{prob:.4f}" else: # Handle other prediction data formats for key, value in prediction_data.items(): if 'label' in key.lower() or 'class' in key.lower(): label_name = key label_value = value prob_value = prediction_data.get(f"{key}_prob", 0.0) pred_rows += f"{label_name}{label_value}{prob_value:.4f}" html_result = f"""

Multi-Label Classification Results

{pred_rows}
LabelPredictionProbability
""" # Add CSS styling html_result += """ """ yield gr.HTML(html_result) else: # If no prediction data found, display raw output yield gr.HTML(f"""

Prediction Completed

No prediction results found in output.

{result_output}
""") else: # Update status status_info["status"] = "failed" status_info["current_step"] = "Prediction failed" yield generate_status_html(status_info) stderr_output = "" if current_process and hasattr(current_process, 'stderr') and current_process.stderr: stderr_output = current_process.stderr.read() yield gr.HTML(f"""

Prediction Failed

Error code: {current_process.returncode if current_process else 'Unknown'}

{stderr_output}\n{result_output}
""") except Exception as e: # Update status status_info["status"] = "failed" status_info["current_step"] = "Error occurred" yield generate_status_html(status_info) yield gr.HTML(f"""

Error

{str(e)}

{traceback.format_exc()}
""") finally: # Reset state is_predicting = False # Properly clean up the process if current_process and current_process.poll() is None: try: # Use process group ID to kill all related processes if possible if hasattr(os, "killpg") and hasattr(os, "getpgid"): os.killpg(os.getpgid(current_process.pid), signal.SIGTERM) else: # On Windows or if killpg is not available current_process.terminate() # Wait briefly for termination try: current_process.wait(timeout=1) except subprocess.TimeoutExpired: # Force kill if necessary if hasattr(os, "killpg") and hasattr(os, "getpgid"): os.killpg(os.getpgid(current_process.pid), signal.SIGKILL) else: current_process.kill() except Exception as e: # Ignore errors during process cleanup print(f"Error cleaning up process: {e}") # Reset process reference current_process = None stop_thread = False def predict_batch(plm_model, model_path, eval_method, input_file, eval_structure_seq, pooling_method, problem_type, num_labels, batch_size): """Batch predict multiple protein sequences""" nonlocal is_predicting, current_process, stop_thread, process_aborted # Check if we're already predicting (this check is performed first) if is_predicting: return gr.HTML("""

A prediction is already running. Please wait or abort it.

"""), gr.update(visible=False) # If the process was aborted but not reset properly, ensure we're in a clean state if process_aborted: process_aborted = False # Reset all state completely is_predicting = True stop_thread = False # Clear the output queue while not output_queue.empty(): try: output_queue.get_nowait() except queue.Empty: break # Initialize progress tracking with completely fresh state progress_info = { "total": 0, "completed": 0, "current_step": "Initializing", "status": "running", "lines": [] # Store lines for error handling } # Generate completely empty initial progress display initial_progress_html = """
Initializing prediction environment...
0%

Sequences: 0/0

""" # Always ensure the download button is hidden when starting a new prediction yield gr.HTML(initial_progress_html), gr.update(visible=False) try: # Check abort state before continuing if process_aborted: is_predicting = False return gr.HTML("""

Process was aborted.

"""), gr.update(visible=False) # Validate inputs if not model_path: is_predicting = False yield gr.HTML("""

Error: Model path is required

"""), gr.update(visible=False) return if not os.path.exists(os.path.dirname(model_path)): is_predicting = False yield gr.HTML("""

Error: Invalid model path - directory does not exist

"""), gr.update(visible=False) return if not input_file: is_predicting = False yield gr.HTML("""

Error: Input file is required

"""), gr.update(visible=False) return # Update progress progress_info["current_step"] = "Preparing input file" yield generate_progress_html(progress_info), gr.update(visible=False) # Create temporary file to save uploaded file temp_dir = tempfile.mkdtemp() input_path = os.path.join(temp_dir, "input.csv") output_dir = temp_dir # Use the same temporary directory as output directory output_file = "predictions.csv" output_path = os.path.join(output_dir, output_file) # Save uploaded file try: with open(input_path, "wb") as f: # Fix file upload error, correctly handle files uploaded through gradio if hasattr(input_file, "name"): # If it's a NamedString object, read the file content with open(input_file.name, "rb") as uploaded: f.write(uploaded.read()) else: # If it's a bytes object, write directly f.write(input_file) # Verify file was saved correctly if not os.path.exists(input_path): is_predicting = False yield gr.HTML("""

Error: Failed to save input file

"""), gr.update(visible=False) progress_info["status"] = "failed" progress_info["current_step"] = "Failed to save input file" return # Count sequences in input file try: df = pd.read_csv(input_path) progress_info["total"] = len(df) progress_info["current_step"] = f"Found {len(df)} sequences to process" yield generate_progress_html(progress_info), gr.update(visible=False) except Exception as e: is_predicting = False yield gr.HTML(f"""

Error reading CSV file:

{str(e)}
"""), gr.update(visible=False) progress_info["status"] = "failed" progress_info["current_step"] = "Error reading CSV file" return except Exception as e: is_predicting = False yield gr.HTML(f"""

Error saving input file:

{str(e)}
"""), gr.update(visible=False) progress_info["status"] = "failed" progress_info["current_step"] = "Failed to save input file" return # Update progress progress_info["current_step"] = "Preparing model and parameters" yield generate_progress_html(progress_info), gr.update(visible=False) # Prepare command args_dict = { "model_path": model_path, "plm_model": plm_models[plm_model], "input_file": input_path, "output_dir": output_dir, # Update to output directory "output_file": output_file, # Output filename "pooling_method": pooling_method, "problem_type": problem_type, "num_labels": num_labels, "eval_method": eval_method, "batch_size": batch_size } if eval_method == "ses-adapter": args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None if eval_structure_seq: if "foldseek_seq" in eval_structure_seq: args_dict["use_foldseek"] = True if "ss8_seq" in eval_structure_seq: args_dict["use_ss8"] = True else: args_dict["structure_seq"] = None # Build command line final_cmd = [sys.executable, "src/predict_batch.py"] for k, v in args_dict.items(): if v is True: final_cmd.append(f"--{k}") elif v is not False and v is not None: final_cmd.append(f"--{k}") final_cmd.append(str(v)) # Update progress progress_info["current_step"] = "Starting batch prediction process" yield generate_progress_html(progress_info), gr.update(visible=False) # Start prediction process try: current_process = subprocess.Popen( final_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, universal_newlines=True, preexec_fn=os.setsid if hasattr(os, "setsid") else None ) except Exception as e: is_predicting = False yield gr.HTML(f"""

Error starting prediction process:

{str(e)}
"""), gr.update(visible=False) return output_thread = threading.Thread(target=process_output, args=(current_process, output_queue)) output_thread.daemon = True output_thread.start() # Start monitoring loop last_update_time = time.time() result_output = "" # Modified processing loop with abort check while True: # Check if process was aborted or completed if process_aborted or current_process is None or current_process.poll() is not None: break # Check for new output try: # Get new lines new_lines = [] for _ in range(10): # Process up to 10 lines at once try: line = output_queue.get_nowait() new_lines.append(line) result_output += line + "\n" progress_info["lines"].append(line) # Update progress based on output if "Predicting:" in line: try: # Extract progress from tqdm output match = re.search(r'(\d+)/(\d+)', line) if match: current, total = map(int, match.groups()) progress_info["completed"] = current progress_info["total"] = total progress_info["current_step"] = f"Processing sequence {current}/{total}" except: pass elif "Loading Model and Tokenizer" in line: progress_info["current_step"] = "Loading model and tokenizer" elif "Processing sequences" in line: progress_info["current_step"] = "Processing sequences" elif "Saving results" in line: progress_info["current_step"] = "Saving results" except queue.Empty: break # Check if the process has been aborted before updating UI if process_aborted: break # Check if we need to update the UI current_time = time.time() if new_lines or (current_time - last_update_time >= 0.5): yield generate_progress_html(progress_info), gr.update(visible=False) last_update_time = current_time # Small sleep to avoid busy waiting if not new_lines: time.sleep(0.1) except Exception as e: # Check if the process has been aborted before showing error if process_aborted: break error_html = f"""

Warning reading output:

{str(e)}
""" yield gr.HTML(error_html), gr.update(visible=False) # Check if aborted instead of completed if process_aborted: is_predicting = False aborted_html = """

Prediction was manually terminated.

All prediction state has been reset.

""" yield gr.HTML(aborted_html), gr.update(visible=False) return # Process has completed if os.path.exists(output_path): if current_process and current_process.returncode == 0: progress_info["status"] = "completed" # Generate final success HTML success_html = f"""

Prediction completed successfully!

Results saved to: {output_path}

Total sequences processed: {progress_info.get('total', 0)}

""" # Read prediction results try: df = pd.read_csv(output_path) # Create summary statistics based on problem type summary_html = "" if problem_type == "regression": summary_html = f"""
{len(df)}
Predictions
{df['prediction'].mean():.4f}
Mean
{df['prediction'].min():.4f}
Min
{df['prediction'].max():.4f}
Max
""" elif problem_type == "single_label_classification": if 'predicted_class' in df.columns: class_counts = df['predicted_class'].value_counts() class_stats = "".join([ f"""
{count}
Class {class_label}
""" for class_label, count in class_counts.items() ]) summary_html = f"""
{len(df)}
Predictions
{class_stats}
""" elif problem_type == "multi_label_classification": label_cols = [col for col in df.columns if col.startswith('label_') and not col.endswith('_prob')] if label_cols: label_stats = "".join([ f"""
{df[col].sum()}
{col}
""" for col in label_cols ]) summary_html = f"""
{len(df)}
Predictions
{label_stats}
""" # Create table preview with style consistent with dataset preview html_table = f"""

Batch Prediction Results Preview

{summary_html}
{' '.join([f'' for col in df.columns])} {generate_table_rows(df)}
{col}

You can download the complete prediction results using the button below.

""" # Add CSS styles final_html = success_html + f""" {html_table} """ # Return results preview and download link yield gr.HTML(final_html), gr.update(value=output_path, visible=True) except Exception as e: # If reading results file fails, show error but still provide download link error_html = f""" {success_html}

Unable to load preview results: {str(e)}

You can still download the complete prediction results file.

""" yield gr.HTML(error_html), gr.update(value=output_path, visible=True) else: # Process failed error_html = f"""

Prediction failed to complete

Process return code: {current_process.returncode if current_process else 'Unknown'}

{result_output}
""" yield gr.HTML(error_html), gr.update(visible=False) else: progress_info["status"] = "failed" error_html = f"""

Prediction completed, but output file not found at {output_path}

{result_output}
""" yield gr.HTML(error_html), gr.update(visible=False) except Exception as e: # Capture the full error with traceback error_traceback = traceback.format_exc() # Display error with traceback in UI error_html = f"""

Error during batch prediction: {str(e)}

{error_traceback}
""" yield gr.HTML(error_html), gr.update(visible=False) finally: # Always reset prediction state is_predicting = False if current_process: current_process = None process_aborted = False # Reset abort flag def generate_progress_html(progress_info): """Generate HTML progress bar similar to eval_tab""" current = progress_info.get("completed", 0) total = max(progress_info.get("total", 1), 1) # Avoid division by zero percentage = min(100, int((current / total) * 100)) stage = progress_info.get("current_step", "Preparing") # 确保进度在0-100之间 percentage = max(0, min(100, percentage)) # 准备详细信息 details = [] if total > 0: details.append(f"Total sequences: {total}") if current > 0 and total > 0: details.append(f"Current progress: {current}/{total}") details_text = ", ".join(details) # 创建更现代化的进度条 - 完全匹配eval_tab的样式 return f"""
Prediction Status: {stage}
{percentage:.1f}%
{f'
Total sequences: {total}
' if total > 0 else ''} {f'
Progress: {current}/{total}
' if current > 0 and total > 0 else ''} {f'
Status: {progress_info.get("status", "").capitalize()}
' if "status" in progress_info else ''}
""" def generate_table_rows(df, max_rows=100): """Generate HTML table rows with special handling for sequence data, maintaining consistent style with eval_tab""" rows = [] for i, row in df.iterrows(): if i >= max_rows: break cells = [] for col in df.columns: value = row[col] # Special handling for sequence type columns if col in ['aa_seq', 'foldseek_seq', 'ss8_seq'] and isinstance(value, str) and len(value) > 30: # Add title attribute to show full sequence on hover cell = f'{value[:30]}...' # Format numeric values to 4 decimal places elif isinstance(value, (int, float)) and not isinstance(value, bool): formatted_value = f"{value:.4f}" if isinstance(value, float) else value cell = f'{formatted_value}' else: cell = f'{value}' cells.append(cell) # Add alternating row background color bg_color = "#f9f9f9" if i % 2 == 1 else "white" rows.append(f'{" ".join(cells)}') if len(df) > max_rows: cols_count = len(df.columns) rows.append(f'Showing {max_rows} of {len(df)} rows') return '\n'.join(rows) def handle_abort(): """Handle abortion of the prediction process for both single and batch prediction""" nonlocal is_predicting, current_process, stop_thread, process_aborted if not is_predicting or current_process is None: empty_html = """

No prediction process is currently running.

""" # Return full HTML value (not gr.HTML component) return empty_html try: # Set the abort flag before terminating the process process_aborted = True stop_thread = True # Kill the process group if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGTERM) else: current_process.terminate() # Wait for process to terminate (with timeout) try: current_process.wait(timeout=5) except subprocess.TimeoutExpired: if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGKILL) else: current_process.kill() # Reset state is_predicting = False current_process = None # Clear output queue while not output_queue.empty(): try: output_queue.get_nowait() except queue.Empty: break success_html = """

Prediction successfully terminated!

All prediction state has been reset.

""" # Return full HTML value (not gr.HTML component) return success_html except Exception as e: # Reset states even on error is_predicting = False current_process = None process_aborted = False # Clear queue while not output_queue.empty(): try: output_queue.get_nowait() except queue.Empty: break error_html = f"""

Failed to terminate prediction: {str(e)}

Prediction state has been reset.

""" # Return full HTML value (not gr.HTML component) return error_html # Create handler functions for each tab def handle_abort_single(): """Handle abort for single sequence prediction tab""" # Flag the process for abortion first nonlocal stop_thread, process_aborted, is_predicting, current_process # Only proceed if there's an active prediction if not is_predicting or current_process is None: return gr.HTML("""

No prediction process is currently running.

""") # Set the abort flags process_aborted = True stop_thread = True # Terminate the process try: if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGTERM) else: current_process.terminate() # Wait briefly for termination try: current_process.wait(timeout=1) except subprocess.TimeoutExpired: # Force kill if necessary if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGKILL) else: current_process.kill() except Exception as e: pass # Catch any termination errors # Reset state is_predicting = False current_process = None # Return the success message return gr.HTML("""

Prediction successfully terminated!

All prediction state has been reset.

""") def handle_abort_batch(): """Handle abort for batch prediction tab""" # Flag the process for abortion first nonlocal stop_thread, process_aborted, is_predicting, current_process # Only proceed if there's an active prediction if not is_predicting or current_process is None: return gr.HTML("""

No prediction process is currently running.

"""), gr.update(visible=False) # Set the abort flags process_aborted = True stop_thread = True # Terminate the process try: if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGTERM) else: current_process.terminate() # Wait briefly for termination try: current_process.wait(timeout=1) except subprocess.TimeoutExpired: # Force kill if necessary if hasattr(os, "killpg"): os.killpg(os.getpgid(current_process.pid), signal.SIGKILL) else: current_process.kill() except Exception as e: pass # Catch any termination errors # Reset state is_predicting = False current_process = None # Clear output queue while not output_queue.empty(): try: output_queue.get_nowait() except queue.Empty: break # Return the success message and hide the download button return gr.HTML("""

Prediction successfully terminated!

All prediction state has been reset.

"""), gr.update(visible=False) def handle_preview(plm_model, model_path, eval_method, aa_seq, foldseek_seq, ss8_seq, eval_structure_seq, pooling_method, problem_type, num_labels): """处理单序列预测命令预览""" # 构建参数字典 args_dict = { "model_path": model_path, "plm_model": plm_models[plm_model], "aa_seq": aa_seq, "foldseek_seq": foldseek_seq if foldseek_seq else "", "ss8_seq": ss8_seq if ss8_seq else "", "pooling_method": pooling_method, "problem_type": problem_type, "num_labels": num_labels, "eval_method": eval_method } if eval_method == "ses-adapter": args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None if eval_structure_seq: if "foldseek_seq" in eval_structure_seq: args_dict["use_foldseek"] = True if "ss8_seq" in eval_structure_seq: args_dict["use_ss8"] = True # 生成预览命令 preview_text = preview_predict_command(args_dict, is_batch=False) return gr.update(value=preview_text, visible=True) def handle_batch_preview(plm_model, model_path, eval_method, input_file, eval_structure_seq, pooling_method, problem_type, num_labels, batch_size): """处理批量预测命令预览""" if not input_file: return gr.update(value="Please upload a file first", visible=True) # 创建临时目录作为输出目录 temp_dir = "temp_predictions" output_file = "predictions.csv" args_dict = { "model_path": model_path, "plm_model": plm_models[plm_model], "input_file": input_file.name if hasattr(input_file, "name") else "input.csv", "output_dir": temp_dir, # 新增输出目录参数 "output_file": output_file, # 输出文件名 "pooling_method": pooling_method, "problem_type": problem_type, "num_labels": num_labels, "eval_method": eval_method, "batch_size": batch_size } if eval_method == "ses-adapter": args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None if eval_structure_seq: if "foldseek_seq" in eval_structure_seq: args_dict["use_foldseek"] = True if "ss8_seq" in eval_structure_seq: args_dict["use_ss8"] = True # 生成预览命令 preview_text = preview_predict_command(args_dict, is_batch=True) return gr.update(value=preview_text, visible=True) with gr.Tab("Prediction"): with gr.Row(): with gr.Column(): gr.Markdown("## Protein Function Prediction") gr.Markdown("### Model Configuration") with gr.Group(): with gr.Row(): model_path = gr.Textbox( label="Model Path", value="ckpt/demo/demo_provided.pt", placeholder="Path to the trained model" ) plm_model = gr.Dropdown( choices=list(plm_models.keys()), label="Protein Language Model" ) with gr.Row(): eval_method = gr.Dropdown( choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"], label="Evaluation Method", value="freeze" ) pooling_method = gr.Dropdown( choices=["mean", "attention1d", "light_attention"], label="Pooling Method", value="mean" ) # Settings for different training methods with gr.Row(visible=False) as structure_seq_row: structure_seq = gr.Dropdown( choices=["foldseek_seq", "ss8_seq"], label="Structure Sequences", multiselect=True, value=["foldseek_seq", "ss8_seq"], info="Select the structure sequences to use for prediction" ) with gr.Row(): problem_type = gr.Dropdown( choices=["single_label_classification", "multi_label_classification", "regression"], label="Problem Type", value="single_label_classification" ) num_labels = gr.Number( value=2, label="Number of Labels", precision=0, minimum=1 ) with gr.Tabs(): with gr.Tab("Sequence Prediction"): gr.Markdown("### Input Sequences") with gr.Row(): aa_seq = gr.Textbox( label="Amino Acid Sequence", placeholder="Enter protein sequence", lines=3 ) # Put the structure input rows in a row with controllable visibility with gr.Row(visible=False) as structure_input_row: foldseek_seq = gr.Textbox( label="Foldseek Sequence", placeholder="Enter foldseek sequence if available", lines=3 ) ss8_seq = gr.Textbox( label="SS8 Sequence", placeholder="Enter secondary structure sequence if available", lines=3 ) with gr.Row(): preview_single_button = gr.Button("Preview Command") predict_button = gr.Button("Predict", variant="primary") abort_button = gr.Button("Abort", variant="stop") # 添加命令预览区域 command_preview = gr.Code( label="Command Preview", language="shell", interactive=False, visible=False ) predict_output = gr.HTML(label="Prediction Results") predict_button.click( fn=predict_sequence, inputs=[ plm_model, model_path, aa_seq, eval_method, structure_seq, pooling_method, problem_type, num_labels ], outputs=predict_output ) abort_button.click( fn=handle_abort_single, inputs=[], outputs=[predict_output] ) with gr.Tab("Batch Prediction"): gr.Markdown("### Batch Prediction") # Display CSV format information with improved styling gr.HTML("""

CSV File Format Requirements

Please prepare your input CSV file with the following columns:

aa_seq (required)
Amino acid sequence
id (optional)
Unique identifier for each sequence
foldseek_seq (optional)
Foldseek structure sequence
ss8_seq (optional)
Secondary structure sequence
""") with gr.Row(): input_file = gr.UploadButton( label="Upload CSV File", file_types=[".csv"], file_count="single" ) # File preview accordion with gr.Accordion("File Preview", open=False) as file_preview_accordion: # File info area with gr.Row(): file_info = gr.HTML("", elem_classes=["dataset-stats"]) # Table area with gr.Row(): file_preview = gr.Dataframe( headers=["name", "sequence"], value=[["No file uploaded", "-"]], wrap=True, interactive=False, row_count=5, elem_classes=["preview-table"] ) # Add file preview function def update_file_preview(file): if file is None: return gr.update(value="
No file uploaded
"), gr.update(value=[["No file uploaded", "-"]], headers=["name", "sequence"]), gr.update(open=False) try: df = pd.read_csv(file.name) info_html = f"""
File Total Sequences Columns
{file.name.split('/')[-1]} {len(df)} {', '.join(df.columns.tolist())}
""" return gr.update(value=info_html), gr.update(value=df.head(5).values.tolist(), headers=df.columns.tolist()), gr.update(open=True) except Exception as e: error_html = f"""

Error reading file

{str(e)}

""" return gr.update(value=error_html), gr.update(value=[["Error", str(e)]], headers=["Error", "Message"]), gr.update(open=True) # Use upload event instead of click event input_file.upload( fn=update_file_preview, inputs=[input_file], outputs=[file_info, file_preview, file_preview_accordion] ) with gr.Row(): with gr.Column(scale=1): batch_size = gr.Slider( minimum=1, maximum=32, value=8, step=1, label="Batch Size", info="Number of sequences to process at once" ) with gr.Row(): preview_batch_button = gr.Button("Preview Command") batch_predict_button = gr.Button("Start Batch Prediction", variant="primary") batch_abort_button = gr.Button("Abort", variant="stop") # 添加命令预览区域 batch_command_preview = gr.Code( label="Command Preview", language="shell", interactive=False, visible=False ) batch_predict_output = gr.HTML(label="Prediction Progress") result_file = gr.DownloadButton(label="Download Predictions", visible=False) # 在UI部分添加命令预览的可见性控制 def toggle_preview(button_text): """切换命令预览的可见性""" if "Preview" in button_text: return gr.update(visible=True) return gr.update(visible=False) # 连接预览按钮 preview_single_button.click( fn=toggle_preview, inputs=[preview_single_button], outputs=[command_preview] ).then( fn=handle_preview, inputs=[ plm_model, model_path, eval_method, aa_seq, foldseek_seq, ss8_seq, structure_seq, pooling_method, problem_type, num_labels ], outputs=[command_preview] ) # 连接预览按钮 preview_batch_button.click( fn=toggle_preview, inputs=[preview_batch_button], outputs=[batch_command_preview] ).then( fn=handle_batch_preview, inputs=[ plm_model, model_path, eval_method, input_file, structure_seq, pooling_method, problem_type, num_labels, batch_size ], outputs=[batch_command_preview] ) batch_predict_button.click( fn=predict_batch, inputs=[ plm_model, model_path, eval_method, input_file, structure_seq, pooling_method, problem_type, num_labels, batch_size ], outputs=[batch_predict_output, result_file] ) batch_abort_button.click( fn=handle_abort_batch, inputs=[], outputs=[batch_predict_output, result_file] ) # Add this code after all UI components are defined def update_eval_method(method): return { structure_seq_row: gr.update(visible=method == "ses-adapter"), structure_input_row: gr.update(visible=method == "ses-adapter") } eval_method.change( fn=update_eval_method, inputs=[eval_method], outputs=[structure_seq_row, structure_input_row] ) # Add a new function to control the visibility of the structure sequence input boxes def update_structure_inputs(structure_seq_choices): return { foldseek_seq: gr.update(visible="foldseek_seq" in structure_seq_choices), ss8_seq: gr.update(visible="ss8_seq" in structure_seq_choices) } # Add event handling to the UI definition section structure_seq.change( fn=update_structure_inputs, inputs=[structure_seq], outputs=[foldseek_seq, ss8_seq] ) return { "predict_sequence": predict_sequence, "predict_batch": predict_batch, "handle_abort": handle_abort }