import gradio as gr
import json
import os
import subprocess
import sys
import signal
import threading
import queue
import time
import pandas as pd
import tempfile
import csv
from pathlib import Path
import traceback
import re
from web.utils.command import preview_predict_command
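
# Gradio "Prediction" tab: wires the single-sequence and batch prediction UIs
# to the src/predict.py and src/predict_batch.py subprocesses, streaming their
# stdout back into live status and progress displays.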
def create_predict_tab(constant):
plm_models = constant["plm_models"]
is_predicting = False
current_process = None
output_queue = queue.Queue()
stop_thread = False
process_aborted = False # Flag indicating if the process was manually terminated
def process_output(process, queue):
"""Process output from subprocess and put it in queue"""
nonlocal stop_thread
while True:
if stop_thread:
break
output = process.stdout.readline()
if output == '' and process.poll() is not None:
break
if output:
queue.put(output.strip())
process.stdout.close()
def generate_status_html(status_info):
"""Generate HTML for single sequence prediction status"""
stage = status_info.get("current_step", "Preparing")
status = status_info.get("status", "running")
# Determine status color and icon
if status == "running":
status_color = "#4285f4" # Blue
icon = "⏳"
animation = """
@keyframes pulse {
0% { transform: scale(1); }
50% { transform: scale(1.05); }
100% { transform: scale(1); }
}
"""
animation_style = "animation: pulse 1.5s infinite ease-in-out;"
elif status == "completed":
status_color = "#2ecc71" # Green
icon = "✅"
animation = ""
animation_style = ""
else: # failed
status_color = "#e74c3c" # Red
icon = "❌"
animation = ""
animation_style = ""
# Create a clean, centered notification
return f"""
{icon}
{stage}
{status.capitalize()}
"""
def predict_sequence(plm_model, model_path, aa_seq, eval_method, eval_structure_seq, pooling_method, problem_type, num_labels):
"""Predict for a single protein sequence"""
nonlocal is_predicting, current_process, stop_thread, process_aborted
# Check if we're already predicting
if is_predicting:
return gr.HTML("""
A prediction is already running. Please wait or abort it.
""")
# If the process was aborted but not reset properly, ensure we're in a clean state
if process_aborted:
process_aborted = False
# Set the prediction flag
is_predicting = True
stop_thread = False # Ensure this is reset
# Create a status info object, similar to batch prediction
status_info = {
"status": "running",
"current_step": "Starting prediction"
}
# Show initial status
yield generate_status_html(status_info)
try:
# Validate inputs
            # A value returned from a generator never reaches the UI, so yield then return.
            if not model_path:
                is_predicting = False
                yield gr.HTML("""
                ❌ Please provide a model path
                """)
                return
            if not os.path.exists(os.path.dirname(model_path)):
                is_predicting = False
                yield gr.HTML("""
                ❌ Invalid model path - directory does not exist
                """)
                return
            if not aa_seq:
                is_predicting = False
                yield gr.HTML("""
                ❌ Amino acid sequence is required
                """)
                return
# Update status
status_info["current_step"] = "Preparing model and parameters"
yield generate_status_html(status_info)
# Prepare command
args_dict = {
"model_path": model_path,
"plm_model": plm_models[plm_model],
"aa_seq": aa_seq,
"pooling_method": pooling_method,
"problem_type": problem_type,
"num_labels": num_labels,
"eval_method": eval_method
}
if eval_method == "ses-adapter":
# Handle structure sequence selection from multi-select dropdown
args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None
# Set flags based on selected structure sequences
if eval_structure_seq:
if "foldseek_seq" in eval_structure_seq:
args_dict["use_foldseek"] = True
if "ss8_seq" in eval_structure_seq:
args_dict["use_ss8"] = True
else:
args_dict["structure_seq"] = None
args_dict["use_foldseek"] = False
args_dict["use_ss8"] = False
# Build command line
final_cmd = [sys.executable, "src/predict.py"]
for k, v in args_dict.items():
if v is True:
final_cmd.append(f"--{k}")
elif v is not False and v is not None:
final_cmd.append(f"--{k}")
final_cmd.append(str(v))
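            # Illustrative only -- values below are placeholders, not a recorded run:
            #   python src/predict.py --model_path ckpt/demo/demo_provided.pt \
            #       --plm_model <path-from-plm_models> --aa_seq MKT... \
            #       --eval_method freeze --pooling_method mean \
            #       --problem_type single_label_classification --num_labels 2
            # Boolean flags such as --use_foldseek are emitted bare, without a value.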
# Update status
status_info["current_step"] = "Starting prediction process"
yield generate_status_html(status_info)
# Start prediction process
try:
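                # Start the child in its own process group (preexec_fn=os.setsid on
                # POSIX) so an abort can signal the whole group via os.killpg rather
                # than just the immediate python process.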
current_process = subprocess.Popen(
final_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None
)
            except Exception as e:
                is_predicting = False
                yield gr.HTML(f"""
                ❌ Error starting prediction process: {str(e)}
                """)
                return
output_thread = threading.Thread(target=process_output, args=(current_process, output_queue))
output_thread.daemon = True
output_thread.start()
# Collect output
result_output = ""
prediction_data = None
json_str = ""
in_json_block = False
json_lines = []
# Update status
status_info["current_step"] = "Processing sequence"
yield generate_status_html(status_info)
while current_process.poll() is None:
# Check if the process was aborted
if process_aborted or stop_thread:
break
try:
while not output_queue.empty():
line = output_queue.get_nowait()
result_output += line + "\n"
# Update status with more meaningful messages
if "Loading model" in line:
status_info["current_step"] = "Loading model and tokenizer"
elif "Processing sequence" in line:
status_info["current_step"] = "Processing protein sequence"
elif "Tokenizing" in line:
status_info["current_step"] = "Tokenizing sequence"
elif "Forward pass" in line:
status_info["current_step"] = "Running model inference"
elif "Making prediction" in line:
status_info["current_step"] = "Calculating final prediction"
elif "Prediction Results" in line:
status_info["current_step"] = "Finalizing results"
# Update status display
yield generate_status_html(status_info)
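                        # The child's stdout doubles as a progress channel (the keyword
                        # matching above) and a data channel: a banner line followed by
                        # a JSON object carries the final result, parsed below.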
# Detect start of JSON results block
if "---------- Prediction Results ----------" in line:
in_json_block = True
json_lines = []
continue
# If in JSON block, collect JSON lines
if in_json_block and line.strip():
json_lines.append(line.strip())
# Try to parse the complete JSON when we have multiple lines
if line.strip() == "}": # Potential end of JSON object
try:
complete_json = " ".join(json_lines)
# Clean up the JSON string by removing line breaks and extra spaces
complete_json = re.sub(r'\s+', ' ', complete_json).strip()
prediction_data = json.loads(complete_json)
print(f"Successfully parsed complete JSON: {prediction_data}")
except json.JSONDecodeError as e:
print(f"Failed to parse complete JSON: {e}")
time.sleep(0.1)
except Exception as e:
yield gr.HTML(f"""
⚠️
Warning reading output: {str(e)}
""")
# Check if the process was aborted
if process_aborted:
# Show aborted message
abort_html = """
Prediction was aborted by user
"""
yield gr.HTML(abort_html)
is_predicting = False
return
# Process has completed
if current_process and current_process.returncode == 0:
# Update status
status_info["status"] = "completed"
status_info["current_step"] = "Prediction completed successfully"
yield generate_status_html(status_info)
# If no prediction data found, try to parse from complete output
if not prediction_data:
try:
# Find the JSON block in the output
results_marker = "---------- Prediction Results ----------"
if results_marker in result_output:
json_part = result_output.split(results_marker)[1].strip()
# Try to extract the JSON object
json_match = re.search(r'(\{.*?\})', json_part.replace('\n', ' '), re.DOTALL)
if json_match:
try:
json_str = json_match.group(1)
# Clean up the JSON string
json_str = re.sub(r'\s+', ' ', json_str).strip()
prediction_data = json.loads(json_str)
print(f"Parsed prediction data from regex: {prediction_data}")
except json.JSONDecodeError as e:
print(f"JSON parse error from regex: {e}")
except Exception as e:
print(f"Error parsing JSON from complete output: {e}")
if prediction_data:
# Create styled HTML table based on problem type
if problem_type == "regression":
html_result = f"""
Regression Prediction Results
Output | Value |
Predicted Value | {prediction_data['prediction']:.4f} |
"""
elif problem_type == "single_label_classification":
                        # Build probability table rows (markup reconstructed)
                        prob_rows = ""
                        if isinstance(prediction_data.get('probabilities'), list):
                            prob_rows = "".join([
                                f"<tr><td>Class {i}</td><td>{prob:.4f}</td></tr>"
                                for i, prob in enumerate(prediction_data['probabilities'])
                            ])
                        elif isinstance(prediction_data.get('probabilities'), dict):
                            prob_rows = "".join([
                                f"<tr><td>Class {label}</td><td>{prob:.4f}</td></tr>"
                                for label, prob in prediction_data['probabilities'].items()
                            ])
                        else:
                            # Handle case where probabilities is not a list or dict
                            prob_value = prediction_data.get('probabilities', 0)
                            prob_rows = f"<tr><td>Class 0</td><td>{prob_value:.4f}</td></tr>"
                        html_result = f"""
                        <h3>Single-Label Classification Results</h3>
                        <table>
                            <tr><th>Output</th><th>Value</th></tr>
                            <tr><td>Predicted Class</td><td>{prediction_data['predicted_class']}</td></tr>
                        </table>
                        <h4>Class Probabilities</h4>
                        <table>
                            <tr><th>Class</th><th>Probability</th></tr>
                            {prob_rows}
                        </table>
                        """
else: # multi_label_classification
                        # Build prediction table rows (markup reconstructed)
                        pred_rows = ""
                        if 'predictions' in prediction_data and 'probabilities' in prediction_data:
                            # Handle different formats of predictions and probabilities
                            if (isinstance(prediction_data['predictions'], list) and
                                isinstance(prediction_data['probabilities'], list)):
                                pred_rows = "".join([
                                    f"<tr><td>Label {i}</td><td>{pred}</td><td>{prob:.4f}</td></tr>"
                                    for i, (pred, prob) in enumerate(zip(prediction_data['predictions'], prediction_data['probabilities']))
                                ])
                            elif (isinstance(prediction_data['predictions'], dict) and
                                  isinstance(prediction_data['probabilities'], dict)):
                                pred_rows = "".join([
                                    f"<tr><td>Label {label}</td><td>{pred}</td><td>{prediction_data['probabilities'].get(label, 0):.4f}</td></tr>"
                                    for label, pred in prediction_data['predictions'].items()
                                ])
                            else:
                                # Handle case where predictions or probabilities is a scalar
                                pred = prediction_data.get('predictions', "N/A")
                                prob = prediction_data.get('probabilities', 0.0)
                                pred_rows = f"<tr><td>Label 0</td><td>{pred}</td><td>{prob:.4f}</td></tr>"
                        else:
                            # Handle other prediction data formats
                            for key, value in prediction_data.items():
                                if 'label' in key.lower() or 'class' in key.lower():
                                    prob_value = prediction_data.get(f"{key}_prob", 0.0)
                                    pred_rows += f"<tr><td>{key}</td><td>{value}</td><td>{prob_value:.4f}</td></tr>"
                        html_result = f"""
                        <h3>Multi-Label Classification Results</h3>
                        <table>
                            <tr><th>Label</th><th>Prediction</th><th>Probability</th></tr>
                            {pred_rows}
                        </table>
                        """
                    # The original appended a CSS block here that was lost in
                    # extraction; the tables above rely on inline/default styling.
                    yield gr.HTML(html_result)
else:
# If no prediction data found, display raw output
yield gr.HTML(f"""
Prediction Completed
No prediction results found in output.
""")
else:
# Update status
status_info["status"] = "failed"
status_info["current_step"] = "Prediction failed"
yield generate_status_html(status_info)
stderr_output = ""
if current_process and hasattr(current_process, 'stderr') and current_process.stderr:
stderr_output = current_process.stderr.read()
yield gr.HTML(f"""
Prediction Failed
Error code: {current_process.returncode if current_process else 'Unknown'}
{stderr_output}\n{result_output}
""")
except Exception as e:
# Update status
status_info["status"] = "failed"
status_info["current_step"] = "Error occurred"
yield generate_status_html(status_info)
yield gr.HTML(f"""
""")
finally:
# Reset state
is_predicting = False
# Properly clean up the process
if current_process and current_process.poll() is None:
try:
# Use process group ID to kill all related processes if possible
if hasattr(os, "killpg") and hasattr(os, "getpgid"):
os.killpg(os.getpgid(current_process.pid), signal.SIGTERM)
else:
# On Windows or if killpg is not available
current_process.terminate()
# Wait briefly for termination
try:
current_process.wait(timeout=1)
except subprocess.TimeoutExpired:
# Force kill if necessary
if hasattr(os, "killpg") and hasattr(os, "getpgid"):
os.killpg(os.getpgid(current_process.pid), signal.SIGKILL)
else:
current_process.kill()
except Exception as e:
# Ignore errors during process cleanup
print(f"Error cleaning up process: {e}")
# Reset process reference
current_process = None
stop_thread = False
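
    # For reference (inferred from the parsing above, not a guaranteed contract),
    # src/predict.py is expected to emit JSON shaped roughly like:
    #   regression:   {"prediction": 0.4213}
    #   single-label: {"predicted_class": 1, "probabilities": [0.12, 0.88]}
    #   multi-label:  {"predictions": [1, 0], "probabilities": [0.91, 0.07]}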
def predict_batch(plm_model, model_path, eval_method, input_file, eval_structure_seq, pooling_method, problem_type, num_labels, batch_size):
"""Batch predict multiple protein sequences"""
nonlocal is_predicting, current_process, stop_thread, process_aborted
# Check if we're already predicting (this check is performed first)
if is_predicting:
return gr.HTML("""
A prediction is already running. Please wait or abort it.
"""), gr.update(visible=False)
# If the process was aborted but not reset properly, ensure we're in a clean state
if process_aborted:
process_aborted = False
# Reset all state completely
is_predicting = True
stop_thread = False
# Clear the output queue
while not output_queue.empty():
try:
output_queue.get_nowait()
except queue.Empty:
break
# Initialize progress tracking with completely fresh state
progress_info = {
"total": 0,
"completed": 0,
"current_step": "Initializing",
"status": "running",
"lines": [] # Store lines for error handling
}
# Generate completely empty initial progress display
initial_progress_html = """
Initializing prediction environment...
0%
"""
# Always ensure the download button is hidden when starting a new prediction
yield gr.HTML(initial_progress_html), gr.update(visible=False)
try:
# Check abort state before continuing
if process_aborted:
is_predicting = False
return gr.HTML("""
"""), gr.update(visible=False)
# Validate inputs
if not model_path:
is_predicting = False
yield gr.HTML("""
Error: Model path is required
"""), gr.update(visible=False)
return
if not os.path.exists(os.path.dirname(model_path)):
is_predicting = False
yield gr.HTML("""
Error: Invalid model path - directory does not exist
"""), gr.update(visible=False)
return
if not input_file:
is_predicting = False
yield gr.HTML("""
Error: Input file is required
"""), gr.update(visible=False)
return
# Update progress
progress_info["current_step"] = "Preparing input file"
yield generate_progress_html(progress_info), gr.update(visible=False)
# Create temporary file to save uploaded file
temp_dir = tempfile.mkdtemp()
input_path = os.path.join(temp_dir, "input.csv")
output_dir = temp_dir # Use the same temporary directory as output directory
output_file = "predictions.csv"
output_path = os.path.join(output_dir, output_file)
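            # The child process is expected to write its results to
            # <temp_dir>/predictions.csv; this same path later feeds the
            # "Download Predictions" button once the run completes.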
# Save uploaded file
try:
with open(input_path, "wb") as f:
# Fix file upload error, correctly handle files uploaded through gradio
if hasattr(input_file, "name"):
# If it's a NamedString object, read the file content
with open(input_file.name, "rb") as uploaded:
f.write(uploaded.read())
else:
# If it's a bytes object, write directly
f.write(input_file)
# Verify file was saved correctly
if not os.path.exists(input_path):
is_predicting = False
yield gr.HTML("""
Error: Failed to save input file
"""), gr.update(visible=False)
progress_info["status"] = "failed"
progress_info["current_step"] = "Failed to save input file"
return
# Count sequences in input file
try:
df = pd.read_csv(input_path)
progress_info["total"] = len(df)
progress_info["current_step"] = f"Found {len(df)} sequences to process"
yield generate_progress_html(progress_info), gr.update(visible=False)
except Exception as e:
is_predicting = False
yield gr.HTML(f"""
Error reading CSV file:
{str(e)}
"""), gr.update(visible=False)
progress_info["status"] = "failed"
progress_info["current_step"] = "Error reading CSV file"
return
except Exception as e:
is_predicting = False
yield gr.HTML(f"""
Error saving input file:
{str(e)}
"""), gr.update(visible=False)
progress_info["status"] = "failed"
progress_info["current_step"] = "Failed to save input file"
return
# Update progress
progress_info["current_step"] = "Preparing model and parameters"
yield generate_progress_html(progress_info), gr.update(visible=False)
# Prepare command
args_dict = {
"model_path": model_path,
"plm_model": plm_models[plm_model],
"input_file": input_path,
"output_dir": output_dir, # Update to output directory
"output_file": output_file, # Output filename
"pooling_method": pooling_method,
"problem_type": problem_type,
"num_labels": num_labels,
"eval_method": eval_method,
"batch_size": batch_size
}
if eval_method == "ses-adapter":
args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None
if eval_structure_seq:
if "foldseek_seq" in eval_structure_seq:
args_dict["use_foldseek"] = True
if "ss8_seq" in eval_structure_seq:
args_dict["use_ss8"] = True
else:
args_dict["structure_seq"] = None
# Build command line
final_cmd = [sys.executable, "src/predict_batch.py"]
for k, v in args_dict.items():
if v is True:
final_cmd.append(f"--{k}")
elif v is not False and v is not None:
final_cmd.append(f"--{k}")
final_cmd.append(str(v))
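            # Illustrative only -- paths below are placeholders for the tempfile
            # locations created above, not a recorded run:
            #   python src/predict_batch.py --model_path ckpt/demo/demo_provided.pt \
            #       --plm_model <path-from-plm_models> --input_file <tmp>/input.csv \
            #       --output_dir <tmp> --output_file predictions.csv \
            #       --eval_method freeze --pooling_method mean \
            #       --problem_type single_label_classification --num_labels 2 --batch_size 8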
# Update progress
progress_info["current_step"] = "Starting batch prediction process"
yield generate_progress_html(progress_info), gr.update(visible=False)
# Start prediction process
try:
current_process = subprocess.Popen(
final_cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True,
bufsize=1,
universal_newlines=True,
preexec_fn=os.setsid if hasattr(os, "setsid") else None
)
except Exception as e:
is_predicting = False
yield gr.HTML(f"""
Error starting prediction process:
{str(e)}
"""), gr.update(visible=False)
return
output_thread = threading.Thread(target=process_output, args=(current_process, output_queue))
output_thread.daemon = True
output_thread.start()
# Start monitoring loop
last_update_time = time.time()
result_output = ""
# Modified processing loop with abort check
while True:
# Check if process was aborted or completed
if process_aborted or current_process is None or current_process.poll() is not None:
break
# Check for new output
try:
# Get new lines
new_lines = []
for _ in range(10): # Process up to 10 lines at once
try:
line = output_queue.get_nowait()
new_lines.append(line)
result_output += line + "\n"
progress_info["lines"].append(line)
# Update progress based on output
if "Predicting:" in line:
try:
# Extract progress from tqdm output
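                                    # Typical tqdm line (illustrative):
                                    # "Predicting:  45%|####5     | 45/100 [00:12<00:15, 3.6it/s]"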
match = re.search(r'(\d+)/(\d+)', line)
if match:
current, total = map(int, match.groups())
progress_info["completed"] = current
progress_info["total"] = total
progress_info["current_step"] = f"Processing sequence {current}/{total}"
                                except Exception:
                                    # Malformed progress line; ignore and keep streaming
                                    pass
elif "Loading Model and Tokenizer" in line:
progress_info["current_step"] = "Loading model and tokenizer"
elif "Processing sequences" in line:
progress_info["current_step"] = "Processing sequences"
elif "Saving results" in line:
progress_info["current_step"] = "Saving results"
except queue.Empty:
break
# Check if the process has been aborted before updating UI
if process_aborted:
break
# Check if we need to update the UI
current_time = time.time()
if new_lines or (current_time - last_update_time >= 0.5):
yield generate_progress_html(progress_info), gr.update(visible=False)
last_update_time = current_time
# Small sleep to avoid busy waiting
if not new_lines:
time.sleep(0.1)
except Exception as e:
# Check if the process has been aborted before showing error
if process_aborted:
break
error_html = f"""
Warning reading output:
{str(e)}
"""
yield gr.HTML(error_html), gr.update(visible=False)
# Check if aborted instead of completed
if process_aborted:
is_predicting = False
aborted_html = """
Prediction was manually terminated.
All prediction state has been reset.
"""
yield gr.HTML(aborted_html), gr.update(visible=False)
return
# Process has completed
if os.path.exists(output_path):
if current_process and current_process.returncode == 0:
progress_info["status"] = "completed"
# Generate final success HTML
success_html = f"""
Prediction completed successfully!
Results saved to: {output_path}
Total sequences processed: {progress_info.get('total', 0)}
"""
# Read prediction results
try:
df = pd.read_csv(output_path)
# Create summary statistics based on problem type
summary_html = ""
if problem_type == "regression":
summary_html = f"""
{df['prediction'].mean():.4f}
Mean
{df['prediction'].min():.4f}
Min
{df['prediction'].max():.4f}
Max
"""
elif problem_type == "single_label_classification":
                            if 'predicted_class' in df.columns:
                                class_counts = df['predicted_class'].value_counts()
                                class_stats = "".join([
                                    f"<div><b>{count}</b><br>Class {class_label}</div>"
                                    for class_label, count in class_counts.items()
                                ])
                                summary_html = f"""
                                <div style="display: flex; gap: 16px; text-align: center;">{class_stats}</div>
                                """
elif problem_type == "multi_label_classification":
                            label_cols = [col for col in df.columns if col.startswith('label_') and not col.endswith('_prob')]
                            if label_cols:
                                # The per-label card markup was lost in extraction; as a
                                # minimal stand-in, show the positive count per label.
                                label_stats = "".join([
                                    f"<div><b>{int(df[col].sum())}</b><br>{col}</div>"
                                    for col in label_cols
                                ])
                                summary_html = f"""
                                <div style="display: flex; gap: 16px; text-align: center;">{label_stats}</div>
                                """
                        # Create table preview with style consistent with dataset preview
                        html_table = f"""
                        <h3>Batch Prediction Results Preview</h3>
                        {summary_html}
                        <table class="preview-table">
                            <thead>
                                <tr>{''.join([f'<th>{col}</th>' for col in df.columns])}</tr>
                            </thead>
                            <tbody>
                                {generate_table_rows(df)}
                            </tbody>
                        </table>
                        <p>You can download the complete prediction results using the button below.</p>
                        """
                        # The original appended a CSS block here that was lost in
                        # extraction; inline/default styling stands in for it.
                        final_html = success_html + html_table
# Return results preview and download link
yield gr.HTML(final_html), gr.update(value=output_path, visible=True)
except Exception as e:
# If reading results file fails, show error but still provide download link
error_html = f"""
{success_html}
Unable to load preview results: {str(e)}
You can still download the complete prediction results file.
"""
yield gr.HTML(error_html), gr.update(value=output_path, visible=True)
else:
# Process failed
error_html = f"""
Prediction failed to complete
Process return code: {current_process.returncode if current_process else 'Unknown'}
{result_output}
"""
yield gr.HTML(error_html), gr.update(visible=False)
else:
progress_info["status"] = "failed"
error_html = f"""
Prediction completed, but output file not found at {output_path}
{result_output}
"""
yield gr.HTML(error_html), gr.update(visible=False)
except Exception as e:
# Capture the full error with traceback
error_traceback = traceback.format_exc()
# Display error with traceback in UI
error_html = f"""
Error during batch prediction: {str(e)}
{error_traceback}
"""
yield gr.HTML(error_html), gr.update(visible=False)
finally:
# Always reset prediction state
is_predicting = False
            current_process = None  # assigning None unconditionally; the old guard was redundant
process_aborted = False # Reset abort flag
def generate_progress_html(progress_info):
"""Generate HTML progress bar similar to eval_tab"""
current = progress_info.get("completed", 0)
total = max(progress_info.get("total", 1), 1) # Avoid division by zero
percentage = min(100, int((current / total) * 100))
stage = progress_info.get("current_step", "Preparing")
        # Clamp the percentage to the 0-100 range
        percentage = max(0, min(100, percentage))
        # Prepare detail text
details = []
if total > 0:
details.append(f"Total sequences: {total}")
if current > 0 and total > 0:
details.append(f"Current progress: {current}/{total}")
details_text = ", ".join(details)
        # Render a modern progress bar, matching the eval_tab style
        # (markup reconstructed minimally; the original styled block was lost)
        return f"""
        <div style="padding: 8px;">
            <div style="font-weight: 600;">Prediction Status: {stage}</div>
            <div style="background: #e9ecef; border-radius: 4px; overflow: hidden; margin: 6px 0;">
                <div style="width: {percentage}%; background: #4285f4; color: white; text-align: center; min-width: 30px;">
                    {percentage:.1f}%
                </div>
            </div>
            {f'<div>{details_text}</div>' if details_text else ''}
            {f'<div>Status: {progress_info.get("status", "").capitalize()}</div>' if "status" in progress_info else ''}
        </div>
        """
def generate_table_rows(df, max_rows=100):
"""Generate HTML table rows with special handling for sequence data, maintaining consistent style with eval_tab"""
rows = []
for i, row in df.iterrows():
if i >= max_rows:
break
cells = []
            for col in df.columns:
                value = row[col]
                # Special handling for sequence type columns
                if col in ['aa_seq', 'foldseek_seq', 'ss8_seq'] and isinstance(value, str) and len(value) > 30:
                    # Add title attribute to show full sequence on hover
                    cell = f'<td title="{value}">{value[:30]}...</td>'
                # Format numeric values to 4 decimal places
                elif isinstance(value, (int, float)) and not isinstance(value, bool):
                    formatted_value = f"{value:.4f}" if isinstance(value, float) else value
                    cell = f'<td>{formatted_value}</td>'
                else:
                    cell = f'<td>{value}</td>'
                cells.append(cell)
            # Add alternating row background color
            bg_color = "#f9f9f9" if i % 2 == 1 else "white"
            rows.append(f'<tr style="background-color: {bg_color};">{" ".join(cells)}</tr>')
        if len(df) > max_rows:
            cols_count = len(df.columns)
            rows.append(f'<tr><td colspan="{cols_count}" style="text-align: center;">Showing {max_rows} of {len(df)} rows</td></tr>')
        return '\n'.join(rows)
def handle_abort():
"""Handle abortion of the prediction process for both single and batch prediction"""
nonlocal is_predicting, current_process, stop_thread, process_aborted
if not is_predicting or current_process is None:
empty_html = """
No prediction process is currently running.
"""
# Return full HTML value (not gr.HTML component)
return empty_html
try:
# Set the abort flag before terminating the process
process_aborted = True
stop_thread = True
# Kill the process group
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGTERM)
else:
current_process.terminate()
# Wait for process to terminate (with timeout)
try:
current_process.wait(timeout=5)
except subprocess.TimeoutExpired:
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGKILL)
else:
current_process.kill()
# Reset state
is_predicting = False
current_process = None
# Clear output queue
while not output_queue.empty():
try:
output_queue.get_nowait()
except queue.Empty:
break
success_html = """
Prediction successfully terminated!
All prediction state has been reset.
"""
# Return full HTML value (not gr.HTML component)
return success_html
except Exception as e:
# Reset states even on error
is_predicting = False
current_process = None
process_aborted = False
# Clear queue
while not output_queue.empty():
try:
output_queue.get_nowait()
except queue.Empty:
break
error_html = f"""
Failed to terminate prediction: {str(e)}
Prediction state has been reset.
"""
# Return full HTML value (not gr.HTML component)
return error_html
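
    # handle_abort is exposed via the returned dict for programmatic use; the UI
    # buttons below are wired to the tab-specific variants handle_abort_single
    # and handle_abort_batch instead.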
# Create handler functions for each tab
def handle_abort_single():
"""Handle abort for single sequence prediction tab"""
# Flag the process for abortion first
nonlocal stop_thread, process_aborted, is_predicting, current_process
# Only proceed if there's an active prediction
if not is_predicting or current_process is None:
return gr.HTML("""
No prediction process is currently running.
""")
# Set the abort flags
process_aborted = True
stop_thread = True
# Terminate the process
try:
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGTERM)
else:
current_process.terminate()
# Wait briefly for termination
try:
current_process.wait(timeout=1)
except subprocess.TimeoutExpired:
# Force kill if necessary
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGKILL)
else:
current_process.kill()
        except Exception:
            pass  # Ignore any errors raised while terminating the process
# Reset state
is_predicting = False
current_process = None
# Return the success message
return gr.HTML("""
Prediction successfully terminated!
All prediction state has been reset.
""")
def handle_abort_batch():
"""Handle abort for batch prediction tab"""
# Flag the process for abortion first
nonlocal stop_thread, process_aborted, is_predicting, current_process
# Only proceed if there's an active prediction
if not is_predicting or current_process is None:
return gr.HTML("""
No prediction process is currently running.
"""), gr.update(visible=False)
# Set the abort flags
process_aborted = True
stop_thread = True
# Terminate the process
try:
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGTERM)
else:
current_process.terminate()
# Wait briefly for termination
try:
current_process.wait(timeout=1)
except subprocess.TimeoutExpired:
# Force kill if necessary
if hasattr(os, "killpg"):
os.killpg(os.getpgid(current_process.pid), signal.SIGKILL)
else:
current_process.kill()
        except Exception:
            pass  # Ignore any errors raised while terminating the process
# Reset state
is_predicting = False
current_process = None
# Clear output queue
while not output_queue.empty():
try:
output_queue.get_nowait()
except queue.Empty:
break
# Return the success message and hide the download button
return gr.HTML("""
Prediction successfully terminated!
All prediction state has been reset.
"""), gr.update(visible=False)
def handle_preview(plm_model, model_path, eval_method, aa_seq, foldseek_seq, ss8_seq, eval_structure_seq, pooling_method, problem_type, num_labels):
"""处理单序列预测命令预览"""
# 构建参数字典
args_dict = {
"model_path": model_path,
"plm_model": plm_models[plm_model],
"aa_seq": aa_seq,
"foldseek_seq": foldseek_seq if foldseek_seq else "",
"ss8_seq": ss8_seq if ss8_seq else "",
"pooling_method": pooling_method,
"problem_type": problem_type,
"num_labels": num_labels,
"eval_method": eval_method
}
if eval_method == "ses-adapter":
args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None
if eval_structure_seq:
if "foldseek_seq" in eval_structure_seq:
args_dict["use_foldseek"] = True
if "ss8_seq" in eval_structure_seq:
args_dict["use_ss8"] = True
        # Generate the preview command
preview_text = preview_predict_command(args_dict, is_batch=False)
return gr.update(value=preview_text, visible=True)
def handle_batch_preview(plm_model, model_path, eval_method, input_file, eval_structure_seq, pooling_method, problem_type, num_labels, batch_size):
"""处理批量预测命令预览"""
if not input_file:
return gr.update(value="Please upload a file first", visible=True)
        # Use a temporary directory as the output directory
temp_dir = "temp_predictions"
output_file = "predictions.csv"
args_dict = {
"model_path": model_path,
"plm_model": plm_models[plm_model],
"input_file": input_file.name if hasattr(input_file, "name") else "input.csv",
"output_dir": temp_dir, # 新增输出目录参数
"output_file": output_file, # 输出文件名
"pooling_method": pooling_method,
"problem_type": problem_type,
"num_labels": num_labels,
"eval_method": eval_method,
"batch_size": batch_size
}
if eval_method == "ses-adapter":
args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None
if eval_structure_seq:
if "foldseek_seq" in eval_structure_seq:
args_dict["use_foldseek"] = True
if "ss8_seq" in eval_structure_seq:
args_dict["use_ss8"] = True
        # Generate the preview command
preview_text = preview_predict_command(args_dict, is_batch=True)
return gr.update(value=preview_text, visible=True)
with gr.Tab("Prediction"):
with gr.Row():
with gr.Column():
gr.Markdown("## Protein Function Prediction")
gr.Markdown("### Model Configuration")
with gr.Group():
with gr.Row():
model_path = gr.Textbox(
label="Model Path",
value="ckpt/demo/demo_provided.pt",
placeholder="Path to the trained model"
)
plm_model = gr.Dropdown(
choices=list(plm_models.keys()),
label="Protein Language Model"
)
with gr.Row():
eval_method = gr.Dropdown(
choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"],
label="Evaluation Method",
value="freeze"
)
pooling_method = gr.Dropdown(
choices=["mean", "attention1d", "light_attention"],
label="Pooling Method",
value="mean"
)
# Settings for different training methods
with gr.Row(visible=False) as structure_seq_row:
structure_seq = gr.Dropdown(
choices=["foldseek_seq", "ss8_seq"],
label="Structure Sequences",
multiselect=True,
value=["foldseek_seq", "ss8_seq"],
info="Select the structure sequences to use for prediction"
)
with gr.Row():
problem_type = gr.Dropdown(
choices=["single_label_classification", "multi_label_classification", "regression"],
label="Problem Type",
value="single_label_classification"
)
num_labels = gr.Number(
value=2,
label="Number of Labels",
precision=0,
minimum=1
)
with gr.Tabs():
with gr.Tab("Sequence Prediction"):
gr.Markdown("### Input Sequences")
with gr.Row():
aa_seq = gr.Textbox(
label="Amino Acid Sequence",
placeholder="Enter protein sequence",
lines=3
)
# Put the structure input rows in a row with controllable visibility
with gr.Row(visible=False) as structure_input_row:
foldseek_seq = gr.Textbox(
label="Foldseek Sequence",
placeholder="Enter foldseek sequence if available",
lines=3
)
ss8_seq = gr.Textbox(
label="SS8 Sequence",
placeholder="Enter secondary structure sequence if available",
lines=3
)
with gr.Row():
preview_single_button = gr.Button("Preview Command")
predict_button = gr.Button("Predict", variant="primary")
abort_button = gr.Button("Abort", variant="stop")
                    # Command preview area
command_preview = gr.Code(
label="Command Preview",
language="shell",
interactive=False,
visible=False
)
predict_output = gr.HTML(label="Prediction Results")
predict_button.click(
fn=predict_sequence,
inputs=[
plm_model,
model_path,
aa_seq,
eval_method,
structure_seq,
pooling_method,
problem_type,
num_labels
],
outputs=predict_output
)
abort_button.click(
fn=handle_abort_single,
inputs=[],
outputs=[predict_output]
)
with gr.Tab("Batch Prediction"):
gr.Markdown("### Batch Prediction")
                    # Display CSV format information (the original styled block was
                    # lost in extraction; this is a minimal reconstruction of its content)
                    gr.HTML("""
                    <div class="csv-format-info">
                        <p><b>CSV format:</b> include an <code>aa_seq</code> column with the
                        protein sequences; add <code>foldseek_seq</code> and/or
                        <code>ss8_seq</code> columns when predicting with ses-adapter.</p>
                    </div>
                    """)
with gr.Row():
input_file = gr.UploadButton(
label="Upload CSV File",
file_types=[".csv"],
file_count="single"
)
# File preview accordion
with gr.Accordion("File Preview", open=False) as file_preview_accordion:
# File info area
with gr.Row():
file_info = gr.HTML("", elem_classes=["dataset-stats"])
# Table area
with gr.Row():
file_preview = gr.Dataframe(
headers=["name", "sequence"],
value=[["No file uploaded", "-"]],
wrap=True,
interactive=False,
row_count=5,
elem_classes=["preview-table"]
)
# Add file preview function
def update_file_preview(file):
if file is None:
return gr.update(value="No file uploaded
"), gr.update(value=[["No file uploaded", "-"]], headers=["name", "sequence"]), gr.update(open=False)
try:
df = pd.read_csv(file.name)
info_html = f"""
File |
Total Sequences |
Columns |
{file.name.split('/')[-1]} |
{len(df)} |
{', '.join(df.columns.tolist())} |
"""
return gr.update(value=info_html), gr.update(value=df.head(5).values.tolist(), headers=df.columns.tolist()), gr.update(open=True)
except Exception as e:
error_html = f"""
Error reading file
{str(e)}
"""
return gr.update(value=error_html), gr.update(value=[["Error", str(e)]], headers=["Error", "Message"]), gr.update(open=True)
# Use upload event instead of click event
input_file.upload(
fn=update_file_preview,
inputs=[input_file],
outputs=[file_info, file_preview, file_preview_accordion]
)
with gr.Row():
with gr.Column(scale=1):
batch_size = gr.Slider(
minimum=1,
maximum=32,
value=8,
step=1,
label="Batch Size",
info="Number of sequences to process at once"
)
with gr.Row():
preview_batch_button = gr.Button("Preview Command")
batch_predict_button = gr.Button("Start Batch Prediction", variant="primary")
batch_abort_button = gr.Button("Abort", variant="stop")
                    # Command preview area
batch_command_preview = gr.Code(
label="Command Preview",
language="shell",
interactive=False,
visible=False
)
batch_predict_output = gr.HTML(label="Prediction Progress")
result_file = gr.DownloadButton(label="Download Predictions", visible=False)
        # Visibility control for the command preview in the UI section
        def toggle_preview(button_text):
            """Toggle the visibility of the command preview"""
if "Preview" in button_text:
return gr.update(visible=True)
return gr.update(visible=False)
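        # Note: a gr.Button passed as an input supplies its label string, so
        # toggle_preview receives "Preview Command" from both preview buttons
        # and therefore always reveals the preview pane.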
        # Wire up the single-sequence preview button
preview_single_button.click(
fn=toggle_preview,
inputs=[preview_single_button],
outputs=[command_preview]
).then(
fn=handle_preview,
inputs=[
plm_model,
model_path,
eval_method,
aa_seq,
foldseek_seq,
ss8_seq,
structure_seq,
pooling_method,
problem_type,
num_labels
],
outputs=[command_preview]
)
        # Wire up the batch preview button
preview_batch_button.click(
fn=toggle_preview,
inputs=[preview_batch_button],
outputs=[batch_command_preview]
).then(
fn=handle_batch_preview,
inputs=[
plm_model,
model_path,
eval_method,
input_file,
structure_seq,
pooling_method,
problem_type,
num_labels,
batch_size
],
outputs=[batch_command_preview]
)
batch_predict_button.click(
fn=predict_batch,
inputs=[
plm_model,
model_path,
eval_method,
input_file,
structure_seq,
pooling_method,
problem_type,
num_labels,
batch_size
],
outputs=[batch_predict_output, result_file]
)
batch_abort_button.click(
fn=handle_abort_batch,
inputs=[],
outputs=[batch_predict_output, result_file]
)
# Add this code after all UI components are defined
def update_eval_method(method):
return {
structure_seq_row: gr.update(visible=method == "ses-adapter"),
structure_input_row: gr.update(visible=method == "ses-adapter")
}
eval_method.change(
fn=update_eval_method,
inputs=[eval_method],
outputs=[structure_seq_row, structure_input_row]
)
# Add a new function to control the visibility of the structure sequence input boxes
def update_structure_inputs(structure_seq_choices):
return {
foldseek_seq: gr.update(visible="foldseek_seq" in structure_seq_choices),
ss8_seq: gr.update(visible="ss8_seq" in structure_seq_choices)
}
# Add event handling to the UI definition section
structure_seq.change(
fn=update_structure_inputs,
inputs=[structure_seq],
outputs=[foldseek_seq, ss8_seq]
)
return {
"predict_sequence": predict_sequence,
"predict_batch": predict_batch,
"handle_abort": handle_abort
}