Commit 892fa67
Parent(s): d78dede

working on fixes for session recovery

Files changed:
- vms/services/trainer.py  +179 -93
- vms/ui/video_trainer_ui.py  +56 -11
vms/services/trainer.py
CHANGED
@@ -361,8 +361,14 @@ class TrainingService:
         if model_type not in MODEL_TYPES.values():
             raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")

-
-
+        # Check if we're resuming or starting new
+        is_resuming = resume_from_checkpoint is not None
+        log_prefix = "Resuming" if is_resuming else "Initializing"
+        logger.info(f"{log_prefix} training with model_type={model_type}")
+        self.append_log(f"{log_prefix} training with model_type={model_type}")
+
+        if is_resuming:
+            self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")

         try:
             # Get absolute paths
@@ -395,7 +401,7 @@ class TrainingService:
                 return error_msg, "No training data available"


-
+            # Get preset configuration
             preset = TRAINING_PRESETS[preset_name]
             training_buckets = preset["training_buckets"]

@@ -524,13 +530,12 @@ class TrainingService:
             return success_msg, self.get_logs()

         except Exception as e:
-            error_msg = f"Error starting training: {str(e)}"
+            error_msg = f"Error {'resuming' if is_resuming else 'starting'} training: {str(e)}"
             self.append_log(error_msg)
             logger.exception("Training startup failed")
-            traceback.print_exc()
-            return "Error starting training", error_msg
-
-
+            traceback.print_exc()
+            return f"Error {'resuming' if is_resuming else 'starting'} training", error_msg
+
     def stop_training(self) -> Tuple[str, str]:
         """Stop training process"""
         if not self.pid_file.exists():
@@ -631,123 +636,204 @@ class TrainingService:
         status = self.get_status()
         ui_updates = {}

-        #
-
-
+        # Check for any checkpoints, even if status doesn't indicate training
+        checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+        has_checkpoints = len(checkpoints) > 0
+
+        # If status indicates training but process isn't running, or if we have checkpoints
+        # and no active training process, try to recover
+        if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
+           (has_checkpoints and not self.is_training_running()):
+
+            logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")

             # Get the latest checkpoint
             last_session = self.load_session()
+
             if not last_session:
-                logger.warning("No session data found for recovery")
-                #
-
-
-
-
-
-
-
-
-
+                logger.warning("No session data found for recovery, but will check for checkpoints")
+                # Try to create a default session based on UI state if we have checkpoints
+                if has_checkpoints:
+                    ui_state = self.load_ui_state()
+                    # Create a default session using UI state values
+                    last_session = {
+                        "params": {
+                            "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
+                            "lora_rank": ui_state.get("lora_rank", "128"),
+                            "lora_alpha": ui_state.get("lora_alpha", "128"),
+                            "num_epochs": ui_state.get("num_epochs", 70),
+                            "batch_size": ui_state.get("batch_size", 1),
+                            "learning_rate": ui_state.get("learning_rate", 3e-5),
+                            "save_iterations": ui_state.get("save_iterations", 500),
+                            "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
+                            "repo_id": ""  # Default empty repo ID
+                        }
+                    }
+                    logger.info("Created default session from UI state for recovery")
+                else:
+                    # Set buttons for no active training
+                    ui_updates = {
+                        "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
+                        "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    }
+                    return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
+
+            # Find the latest checkpoint if we have checkpoints
+            latest_checkpoint = None
+            checkpoint_step = 0
+
+            if has_checkpoints:
+                latest_checkpoint = max(checkpoints, key=os.path.getmtime)
+                checkpoint_step = int(latest_checkpoint.name.split("-")[1])
+                logger.info(f"Found checkpoint at step {checkpoint_step}")
+            else:
                 logger.warning("No checkpoints found for recovery")
                 # Set buttons for no active training
                 ui_updates = {
-                    "start_btn": {"interactive": True, "variant": "primary"},
-                    "stop_btn": {"interactive": False, "variant": "secondary"},
-                    "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                    "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
+                    "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
                 }
                 return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
-
-            latest_checkpoint = max(checkpoints, key=os.path.getmtime)
-            checkpoint_step = int(latest_checkpoint.name.split("-")[1])
-
-            logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")

             # Extract parameters from the saved session (not current UI state)
             # This ensures we use the original training parameters
             params = last_session.get('params', {})
-            initial_ui_state = last_session.get('initial_ui_state', {})

             # Add UI updates to restore the training parameters in the UI
             # This shows the user what values are being used for the resumed training
             ui_updates.update({
-                "model_type":
-                "lora_rank":
-                "lora_alpha":
-                "num_epochs":
-                "batch_size":
-                "learning_rate":
-                "save_iterations":
-                "training_preset":
+                "model_type": params.get('model_type', list(MODEL_TYPES.keys())[0]),
+                "lora_rank": params.get('lora_rank', "128"),
+                "lora_alpha": params.get('lora_alpha', "128"),
+                "num_epochs": params.get('num_epochs', 70),
+                "batch_size": params.get('batch_size', 1),
+                "learning_rate": params.get('learning_rate', 3e-5),
+                "save_iterations": params.get('save_iterations', 500),
+                "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
             })

-            #
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            # Check if we should auto-recover (immediate restart)
+            auto_recover = True  # Always auto-recover on startup
+
+            if auto_recover:
+                # Attempt to resume training using the ORIGINAL parameters
+                try:
+                    # Extract required parameters from the session
+                    model_type = params.get('model_type')
+                    lora_rank = params.get('lora_rank')
+                    lora_alpha = params.get('lora_alpha')
+                    num_epochs = params.get('num_epochs')
+                    batch_size = params.get('batch_size')
+                    learning_rate = params.get('learning_rate')
+                    save_iterations = params.get('save_iterations')
+                    repo_id = params.get('repo_id', '')
+                    preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
+
+                    # Log the recovery attempt
+                    self.append_log(f"Auto-recovering training from checkpoint {checkpoint_step}")
+                    gr.Info(f"Automatically resuming training from checkpoint {checkpoint_step}")
+
+                    # Attempt to resume training
+                    result = self.start_training(
+                        model_type=model_type,
+                        lora_rank=lora_rank,
+                        lora_alpha=lora_alpha,
+                        num_epochs=num_epochs,
+                        batch_size=batch_size,
+                        learning_rate=learning_rate,
+                        save_iterations=save_iterations,
+                        repo_id=repo_id,
+                        preset_name=preset_name,
+                        resume_from_checkpoint=str(latest_checkpoint)
+                    )
+
+                    # Set buttons for active training
+                    ui_updates.update({
+                        "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+
+                    return {
+                        "status": "recovered",
+                        "message": f"Training resumed from checkpoint {checkpoint_step}",
+                        "result": result,
+                        "ui_updates": ui_updates
+                    }
+                except Exception as e:
+                    logger.error(f"Failed to auto-resume training: {str(e)}")
+                    # Set buttons for manual recovery
+                    ui_updates.update({
+                        "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
+                        "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                        "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
+                    })
+                    return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
+            else:
+                # Set up UI for manual recovery
                 ui_updates.update({
-                    "start_btn": {"interactive": True, "variant": "primary"},
-                    "stop_btn": {"interactive": False, "variant": "secondary"},
-                    "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                    "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
+                    "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                    "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
                 })
-                return {"status": "
+                return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates}
+
         elif self.is_training_running():
             # Process is still running, set buttons accordingly
             ui_updates = {
-                "start_btn": {"interactive": False, "variant": "secondary"},
-                "stop_btn": {"interactive": True, "variant": "
-                "pause_resume_btn": {"interactive":
+                "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"},
+                "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
+                "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
             }
             return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
         else:
             # No training process, set buttons to default state
+            button_text = "Continue Training" if has_checkpoints else "Start Training"
             ui_updates = {
-                "start_btn": {"interactive": True, "variant": "primary"},
-                "stop_btn": {"interactive": False, "variant": "secondary"},
-                "pause_resume_btn": {"interactive": False, "variant": "secondary"}
+                "start_btn": {"interactive": True, "variant": "primary", "value": button_text},
+                "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
+                "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
             }
             return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
+
+    def delete_all_checkpoints(self) -> str:
+        """Delete all checkpoints in the output directory.
+
+        Returns:
+            Status message
+        """
+        if self.is_training_running():
+            return "Cannot delete checkpoints while training is running. Stop training first."

+        try:
+            # Find all checkpoint directories
+            checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
+
+            if not checkpoints:
+                return "No checkpoints found to delete."
+
+            # Delete each checkpoint directory
+            for checkpoint in checkpoints:
+                if checkpoint.is_dir():
+                    shutil.rmtree(checkpoint)
+
+            # Also delete session.json which contains previous training info
+            if self.session_file.exists():
+                self.session_file.unlink()
+
+            # Reset status file to idle
+            self.save_status(state='idle', message='No training in progress')
+
+            self.append_log(f"Deleted {len(checkpoints)} checkpoint(s)")
+            return f"Successfully deleted {len(checkpoints)} checkpoint(s)"
+
+        except Exception as e:
+            error_msg = f"Error deleting checkpoints: {str(e)}"
+            self.append_log(error_msg)
+            return error_msg
+
     def clear_training_data(self) -> str:
         """Clear all training data"""
         if self.is_training_running():
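Note: the following is a minimal consumption sketch, not part of this commit. It assumes TrainingService can be imported from vms.services.trainer and constructed with no arguments (both assumptions); the "status", "message" and "ui_updates" keys mirror the return values introduced in recover_interrupted_training() above.

# Hypothetical startup hook (illustrative names): run recovery once and
# hand the per-button props back to whatever builds the UI.
from vms.services.trainer import TrainingService  # assumed import path

def bootstrap_training_service() -> dict:
    trainer = TrainingService()  # assumed zero-argument constructor
    recovery = trainer.recover_interrupted_training()

    if recovery["status"] == "recovered":
        # Auto-recovery already re-invoked start_training() with the original parameters.
        print(recovery["message"])
    elif recovery["status"] == "error":
        # No usable checkpoint, or auto-resume failed; a clean slate is one option here.
        print(recovery["message"])
        print(trainer.delete_all_checkpoints())
    else:
        # "idle", "running" or "ready_to_recover": just surface the message.
        print(recovery["message"])

    # Each entry maps a button name to kwargs like interactive/variant/value/visible,
    # which the UI layer can splat into gr.Button(**props).
    return recovery.get("ui_updates", {})
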
vms/ui/video_trainer_ui.py
CHANGED
@@ -36,7 +36,7 @@ class VideoTrainerUI:

         # Initialize log parser
         self.log_parser = TrainingLogParser()
-
+
         # Shared state for tabs
         self.state = {
             "recovery_result": recovery_result
@@ -45,6 +45,9 @@ class VideoTrainerUI:
         # Initialize tabs dictionary (will be populated in create_ui)
         self.tabs = {}
         self.tabs_component = None
+
+        # Log recovery status
+        logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")

     def create_ui(self):
         """Create the main Gradio UI"""
@@ -104,7 +107,7 @@ class VideoTrainerUI:
                     self.tabs["train_tab"].components["log_box"],
                     self.tabs["train_tab"].components["start_btn"],
                     self.tabs["train_tab"].components["stop_btn"],
-                    self.tabs["train_tab"].components["
+                    self.tabs["train_tab"].components["delete_checkpoints_btn"]  # Replace pause_resume_btn
                 ]
             )

@@ -135,14 +138,33 @@ class VideoTrainerUI:
             video_list = self.tabs["split_tab"].list_unprocessed_videos()
             training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()

-            # Get button states
+            # Get button states based on recovery status
             button_states = self.get_initial_button_states()
             start_btn = button_states[0]
             stop_btn = button_states[1]
-
+            delete_checkpoints_btn = button_states[2]  # This replaces pause_resume_btn in the response tuple
+
+            # Get UI form values - possibly from the recovery
+            if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
+                recovery_ui = self.state["recovery_result"]["ui_updates"]
+
+                # If we recovered training parameters from the original session
+                ui_state = {}
+                for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
+                              "batch_size", "learning_rate", "save_iterations", "training_preset"]:
+                    if param in recovery_ui:
+                        ui_state[param] = recovery_ui[param]
+
+                # Merge with existing UI state if needed
+                if ui_state:
+                    current_state = self.load_ui_values()
+                    current_state.update(ui_state)
+                    self.trainer.save_ui_state(current_state)
+                    logger.info(f"Updated UI state from recovery: {ui_state}")

-            #
+            # Load values (potentially with recovery updates applied)
             ui_state = self.load_ui_values()
+
             training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
             model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
             lora_rank_val = ui_state.get("lora_rank", "128")
@@ -158,7 +180,7 @@ class VideoTrainerUI:
                 training_dataset,
                 start_btn,
                 stop_btn,
-                pause_resume_btn
+                delete_checkpoints_btn,  # Replaces pause_resume_btn
                 training_preset,
                 model_type_val,
                 lora_rank_val,
@@ -210,16 +232,39 @@ class VideoTrainerUI:
     # Add this new method to get initial button states:
     def get_initial_button_states(self):
        """Get the initial states for training buttons based on recovery status"""
-        recovery_result = self.trainer.recover_interrupted_training()
+        recovery_result = self.state.get("recovery_result") or self.trainer.recover_interrupted_training()
         ui_updates = recovery_result.get("ui_updates", {})

+        # Check for checkpoints to determine start button text
+        has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
+
+        # Default button states if recovery didn't provide any
+        if not ui_updates or not ui_updates.get("start_btn"):
+            is_training = self.trainer.is_training_running()
+
+            if is_training:
+                # Active training detected
+                start_btn_props = {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"}
+                stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
+                delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
+            else:
+                # No active training
+                start_btn_props = {"interactive": True, "variant": "primary", "value": "Continue Training" if has_checkpoints else "Start Training"}
+                stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
+                delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
+        else:
+            # Use button states from recovery
+            start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start Training"})
+            stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
+            delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
+
         # Return button states in the correct order
         return (
-            gr.Button(**
-            gr.Button(**
-            gr.Button(**
+            gr.Button(**start_btn_props),
+            gr.Button(**stop_btn_props),
+            gr.Button(**delete_btn_props)
         )
-
+
     def update_titles(self) -> Tuple[Any]:
         """Update all dynamic titles with current counts

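Note: a minimal wiring sketch for the new delete_checkpoints_btn, not part of this commit. It assumes a standalone Gradio Blocks layout and a trainer object exposing delete_all_checkpoints() as added above; the component names here are illustrative rather than the ones used in the train tab.

import gradio as gr

def build_checkpoint_controls(trainer, button_props: dict):
    # button_props is one "ui_updates" entry produced by get_initial_button_states(),
    # e.g. {"interactive": True, "variant": "stop", "value": "Delete All Checkpoints"}.
    with gr.Blocks() as demo:
        status_box = gr.Textbox(label="Status", interactive=False)
        delete_btn = gr.Button(**button_props)

        # delete_all_checkpoints() returns a status string, which lands in the textbox.
        delete_btn.click(fn=trainer.delete_all_checkpoints, outputs=status_box)
    return demo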