jbilcke-hf HF Staff commited on
Commit
892fa67
·
1 Parent(s): d78dede

working on fixes for session recovery

Browse files
Files changed (2) hide show
  1. vms/services/trainer.py +179 -93
  2. vms/ui/video_trainer_ui.py +56 -11
vms/services/trainer.py CHANGED
@@ -361,8 +361,14 @@ class TrainingService:
361
  if model_type not in MODEL_TYPES.values():
362
  raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
363
 
364
-
365
- logger.info(f"Initializing training with model_type={model_type}")
 
 
 
 
 
 
366
 
367
  try:
368
  # Get absolute paths
@@ -395,7 +401,7 @@ class TrainingService:
395
  return error_msg, "No training data available"
396
 
397
 
398
- # Get preset configuration
399
  preset = TRAINING_PRESETS[preset_name]
400
  training_buckets = preset["training_buckets"]
401
 
@@ -524,13 +530,12 @@ class TrainingService:
524
  return success_msg, self.get_logs()
525
 
526
  except Exception as e:
527
- error_msg = f"Error starting training: {str(e)}"
528
  self.append_log(error_msg)
529
  logger.exception("Training startup failed")
530
- traceback.print_exc() # Added for better error debugging
531
- return "Error starting training", error_msg
532
-
533
-
534
  def stop_training(self) -> Tuple[str, str]:
535
  """Stop training process"""
536
  if not self.pid_file.exists():
@@ -631,123 +636,204 @@ class TrainingService:
631
  status = self.get_status()
632
  ui_updates = {}
633
 
634
- # If status indicates training but process isn't running, try to recover
635
- if status.get('status') == 'training' and not self.is_training_running():
636
- logger.info("Detected interrupted training session, attempting to recover...")
 
 
 
 
 
 
 
637
 
638
  # Get the latest checkpoint
639
  last_session = self.load_session()
 
640
  if not last_session:
641
- logger.warning("No session data found for recovery")
642
- # Set buttons for no active training
643
- ui_updates = {
644
- "start_btn": {"interactive": True, "variant": "primary"},
645
- "stop_btn": {"interactive": False, "variant": "secondary"},
646
- "pause_resume_btn": {"interactive": False, "variant": "secondary"}
647
- }
648
- return {"status": "error", "message": "No session data found", "ui_updates": ui_updates}
649
-
650
- # Find the latest checkpoint
651
- checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
652
- if not checkpoints:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
653
  logger.warning("No checkpoints found for recovery")
654
  # Set buttons for no active training
655
  ui_updates = {
656
- "start_btn": {"interactive": True, "variant": "primary"},
657
- "stop_btn": {"interactive": False, "variant": "secondary"},
658
- "pause_resume_btn": {"interactive": False, "variant": "secondary"}
659
  }
660
  return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
661
-
662
- latest_checkpoint = max(checkpoints, key=os.path.getmtime)
663
- checkpoint_step = int(latest_checkpoint.name.split("-")[1])
664
-
665
- logger.info(f"Found checkpoint at step {checkpoint_step}, attempting to resume")
666
 
667
  # Extract parameters from the saved session (not current UI state)
668
  # This ensures we use the original training parameters
669
  params = last_session.get('params', {})
670
- initial_ui_state = last_session.get('initial_ui_state', {})
671
 
672
  # Add UI updates to restore the training parameters in the UI
673
  # This shows the user what values are being used for the resumed training
674
  ui_updates.update({
675
- "model_type": gr.update(value=params.get('model_type', list(MODEL_TYPES.keys())[0])),
676
- "lora_rank": gr.update(value=params.get('lora_rank', "128")),
677
- "lora_alpha": gr.update(value=params.get('lora_alpha', "128")),
678
- "num_epochs": gr.update(value=params.get('num_epochs', 70)),
679
- "batch_size": gr.update(value=params.get('batch_size', 1)),
680
- "learning_rate": gr.update(value=params.get('learning_rate', 3e-5)),
681
- "save_iterations": gr.update(value=params.get('save_iterations', 500)),
682
- "training_preset": gr.update(value=params.get('preset_name', list(TRAINING_PRESETS.keys())[0]))
683
  })
684
 
685
- # Attempt to resume training using the ORIGINAL parameters
686
- try:
687
- # Extract required parameters from the session
688
- model_type = params.get('model_type')
689
- lora_rank = params.get('lora_rank')
690
- lora_alpha = params.get('lora_alpha')
691
- num_epochs = params.get('num_epochs')
692
- batch_size = params.get('batch_size')
693
- learning_rate = params.get('learning_rate')
694
- save_iterations = params.get('save_iterations')
695
- repo_id = params.get('repo_id')
696
- preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
697
-
698
- # Attempt to resume training
699
- result = self.start_training(
700
- model_type=model_type,
701
- lora_rank=lora_rank,
702
- lora_alpha=lora_alpha,
703
- num_epochs=num_epochs,
704
- batch_size=batch_size,
705
- learning_rate=learning_rate,
706
- save_iterations=save_iterations,
707
- repo_id=repo_id,
708
- preset_name=preset_name,
709
- resume_from_checkpoint=str(latest_checkpoint)
710
- )
711
-
712
- # Set buttons for active training
713
- ui_updates.update({
714
- "start_btn": {"interactive": False, "variant": "secondary"},
715
- "stop_btn": {"interactive": True, "variant": "stop"},
716
- "pause_resume_btn": {"interactive": True, "variant": "secondary"}
717
- })
718
-
719
- return {
720
- "status": "recovered",
721
- "message": f"Training resumed from checkpoint {checkpoint_step}",
722
- "result": result,
723
- "ui_updates": ui_updates
724
- }
725
- except Exception as e:
726
- logger.error(f"Failed to resume training: {str(e)}")
727
- # Set buttons for no active training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
  ui_updates.update({
729
- "start_btn": {"interactive": True, "variant": "primary"},
730
- "stop_btn": {"interactive": False, "variant": "secondary"},
731
- "pause_resume_btn": {"interactive": False, "variant": "secondary"}
732
  })
733
- return {"status": "error", "message": f"Failed to resume: {str(e)}", "ui_updates": ui_updates}
 
734
  elif self.is_training_running():
735
  # Process is still running, set buttons accordingly
736
  ui_updates = {
737
- "start_btn": {"interactive": False, "variant": "secondary"},
738
- "stop_btn": {"interactive": True, "variant": "stop"},
739
- "pause_resume_btn": {"interactive": True, "variant": "secondary"}
740
  }
741
  return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
742
  else:
743
  # No training process, set buttons to default state
 
744
  ui_updates = {
745
- "start_btn": {"interactive": True, "variant": "primary"},
746
- "stop_btn": {"interactive": False, "variant": "secondary"},
747
- "pause_resume_btn": {"interactive": False, "variant": "secondary"}
748
  }
749
  return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
 
 
 
 
 
 
 
 
 
750
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
751
  def clear_training_data(self) -> str:
752
  """Clear all training data"""
753
  if self.is_training_running():
 
361
  if model_type not in MODEL_TYPES.values():
362
  raise ValueError(f"Invalid model_type: {model_type}. Must be one of {list(MODEL_TYPES.values())}")
363
 
364
+ # Check if we're resuming or starting new
365
+ is_resuming = resume_from_checkpoint is not None
366
+ log_prefix = "Resuming" if is_resuming else "Initializing"
367
+ logger.info(f"{log_prefix} training with model_type={model_type}")
368
+ self.append_log(f"{log_prefix} training with model_type={model_type}")
369
+
370
+ if is_resuming:
371
+ self.append_log(f"Resuming from checkpoint: {resume_from_checkpoint}")
372
 
373
  try:
374
  # Get absolute paths
 
401
  return error_msg, "No training data available"
402
 
403
 
404
+ # Get preset configuration
405
  preset = TRAINING_PRESETS[preset_name]
406
  training_buckets = preset["training_buckets"]
407
 
 
530
  return success_msg, self.get_logs()
531
 
532
  except Exception as e:
533
+ error_msg = f"Error {'resuming' if is_resuming else 'starting'} training: {str(e)}"
534
  self.append_log(error_msg)
535
  logger.exception("Training startup failed")
536
+ traceback.print_exc()
537
+ return f"Error {'resuming' if is_resuming else 'starting'} training", error_msg
538
+
 
539
  def stop_training(self) -> Tuple[str, str]:
540
  """Stop training process"""
541
  if not self.pid_file.exists():
 
636
  status = self.get_status()
637
  ui_updates = {}
638
 
639
+ # Check for any checkpoints, even if status doesn't indicate training
640
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
641
+ has_checkpoints = len(checkpoints) > 0
642
+
643
+ # If status indicates training but process isn't running, or if we have checkpoints
644
+ # and no active training process, try to recover
645
+ if (status.get('status') in ['training', 'paused'] and not self.is_training_running()) or \
646
+ (has_checkpoints and not self.is_training_running()):
647
+
648
+ logger.info("Detected interrupted training session or existing checkpoints, attempting to recover...")
649
 
650
  # Get the latest checkpoint
651
  last_session = self.load_session()
652
+
653
  if not last_session:
654
+ logger.warning("No session data found for recovery, but will check for checkpoints")
655
+ # Try to create a default session based on UI state if we have checkpoints
656
+ if has_checkpoints:
657
+ ui_state = self.load_ui_state()
658
+ # Create a default session using UI state values
659
+ last_session = {
660
+ "params": {
661
+ "model_type": MODEL_TYPES.get(ui_state.get("model_type", list(MODEL_TYPES.keys())[0])),
662
+ "lora_rank": ui_state.get("lora_rank", "128"),
663
+ "lora_alpha": ui_state.get("lora_alpha", "128"),
664
+ "num_epochs": ui_state.get("num_epochs", 70),
665
+ "batch_size": ui_state.get("batch_size", 1),
666
+ "learning_rate": ui_state.get("learning_rate", 3e-5),
667
+ "save_iterations": ui_state.get("save_iterations", 500),
668
+ "preset_name": ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0]),
669
+ "repo_id": "" # Default empty repo ID
670
+ }
671
+ }
672
+ logger.info("Created default session from UI state for recovery")
673
+ else:
674
+ # Set buttons for no active training
675
+ ui_updates = {
676
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
677
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
678
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
679
+ }
680
+ return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
681
+
682
+ # Find the latest checkpoint if we have checkpoints
683
+ latest_checkpoint = None
684
+ checkpoint_step = 0
685
+
686
+ if has_checkpoints:
687
+ latest_checkpoint = max(checkpoints, key=os.path.getmtime)
688
+ checkpoint_step = int(latest_checkpoint.name.split("-")[1])
689
+ logger.info(f"Found checkpoint at step {checkpoint_step}")
690
+ else:
691
  logger.warning("No checkpoints found for recovery")
692
  # Set buttons for no active training
693
  ui_updates = {
694
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Start Training"},
695
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
696
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
697
  }
698
  return {"status": "error", "message": "No checkpoints found", "ui_updates": ui_updates}
 
 
 
 
 
699
 
700
  # Extract parameters from the saved session (not current UI state)
701
  # This ensures we use the original training parameters
702
  params = last_session.get('params', {})
 
703
 
704
  # Add UI updates to restore the training parameters in the UI
705
  # This shows the user what values are being used for the resumed training
706
  ui_updates.update({
707
+ "model_type": params.get('model_type', list(MODEL_TYPES.keys())[0]),
708
+ "lora_rank": params.get('lora_rank', "128"),
709
+ "lora_alpha": params.get('lora_alpha', "128"),
710
+ "num_epochs": params.get('num_epochs', 70),
711
+ "batch_size": params.get('batch_size', 1),
712
+ "learning_rate": params.get('learning_rate', 3e-5),
713
+ "save_iterations": params.get('save_iterations', 500),
714
+ "training_preset": params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
715
  })
716
 
717
+ # Check if we should auto-recover (immediate restart)
718
+ auto_recover = True # Always auto-recover on startup
719
+
720
+ if auto_recover:
721
+ # Attempt to resume training using the ORIGINAL parameters
722
+ try:
723
+ # Extract required parameters from the session
724
+ model_type = params.get('model_type')
725
+ lora_rank = params.get('lora_rank')
726
+ lora_alpha = params.get('lora_alpha')
727
+ num_epochs = params.get('num_epochs')
728
+ batch_size = params.get('batch_size')
729
+ learning_rate = params.get('learning_rate')
730
+ save_iterations = params.get('save_iterations')
731
+ repo_id = params.get('repo_id', '')
732
+ preset_name = params.get('preset_name', list(TRAINING_PRESETS.keys())[0])
733
+
734
+ # Log the recovery attempt
735
+ self.append_log(f"Auto-recovering training from checkpoint {checkpoint_step}")
736
+ gr.Info(f"Automatically resuming training from checkpoint {checkpoint_step}")
737
+
738
+ # Attempt to resume training
739
+ result = self.start_training(
740
+ model_type=model_type,
741
+ lora_rank=lora_rank,
742
+ lora_alpha=lora_alpha,
743
+ num_epochs=num_epochs,
744
+ batch_size=batch_size,
745
+ learning_rate=learning_rate,
746
+ save_iterations=save_iterations,
747
+ repo_id=repo_id,
748
+ preset_name=preset_name,
749
+ resume_from_checkpoint=str(latest_checkpoint)
750
+ )
751
+
752
+ # Set buttons for active training
753
+ ui_updates.update({
754
+ "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training"},
755
+ "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
756
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
757
+ })
758
+
759
+ return {
760
+ "status": "recovered",
761
+ "message": f"Training resumed from checkpoint {checkpoint_step}",
762
+ "result": result,
763
+ "ui_updates": ui_updates
764
+ }
765
+ except Exception as e:
766
+ logger.error(f"Failed to auto-resume training: {str(e)}")
767
+ # Set buttons for manual recovery
768
+ ui_updates.update({
769
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
770
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
771
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
772
+ })
773
+ return {"status": "error", "message": f"Failed to auto-resume: {str(e)}", "ui_updates": ui_updates}
774
+ else:
775
+ # Set up UI for manual recovery
776
  ui_updates.update({
777
+ "start_btn": {"interactive": True, "variant": "primary", "value": "Continue Training"},
778
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
779
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
780
  })
781
+ return {"status": "ready_to_recover", "message": f"Ready to resume from checkpoint {checkpoint_step}", "ui_updates": ui_updates}
782
+
783
  elif self.is_training_running():
784
  # Process is still running, set buttons accordingly
785
  ui_updates = {
786
+ "start_btn": {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"},
787
+ "stop_btn": {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"},
788
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
789
  }
790
  return {"status": "running", "message": "Training process is running", "ui_updates": ui_updates}
791
  else:
792
  # No training process, set buttons to default state
793
+ button_text = "Continue Training" if has_checkpoints else "Start Training"
794
  ui_updates = {
795
+ "start_btn": {"interactive": True, "variant": "primary", "value": button_text},
796
+ "stop_btn": {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"},
797
+ "pause_resume_btn": {"interactive": False, "variant": "secondary", "visible": False}
798
  }
799
  return {"status": "idle", "message": "No training in progress", "ui_updates": ui_updates}
800
+
801
+ def delete_all_checkpoints(self) -> str:
802
+ """Delete all checkpoints in the output directory.
803
+
804
+ Returns:
805
+ Status message
806
+ """
807
+ if self.is_training_running():
808
+ return "Cannot delete checkpoints while training is running. Stop training first."
809
 
810
+ try:
811
+ # Find all checkpoint directories
812
+ checkpoints = list(OUTPUT_PATH.glob("checkpoint-*"))
813
+
814
+ if not checkpoints:
815
+ return "No checkpoints found to delete."
816
+
817
+ # Delete each checkpoint directory
818
+ for checkpoint in checkpoints:
819
+ if checkpoint.is_dir():
820
+ shutil.rmtree(checkpoint)
821
+
822
+ # Also delete session.json which contains previous training info
823
+ if self.session_file.exists():
824
+ self.session_file.unlink()
825
+
826
+ # Reset status file to idle
827
+ self.save_status(state='idle', message='No training in progress')
828
+
829
+ self.append_log(f"Deleted {len(checkpoints)} checkpoint(s)")
830
+ return f"Successfully deleted {len(checkpoints)} checkpoint(s)"
831
+
832
+ except Exception as e:
833
+ error_msg = f"Error deleting checkpoints: {str(e)}"
834
+ self.append_log(error_msg)
835
+ return error_msg
836
+
837
  def clear_training_data(self) -> str:
838
  """Clear all training data"""
839
  if self.is_training_running():
vms/ui/video_trainer_ui.py CHANGED
@@ -36,7 +36,7 @@ class VideoTrainerUI:
36
 
37
  # Initialize log parser
38
  self.log_parser = TrainingLogParser()
39
-
40
  # Shared state for tabs
41
  self.state = {
42
  "recovery_result": recovery_result
@@ -45,6 +45,9 @@ class VideoTrainerUI:
45
  # Initialize tabs dictionary (will be populated in create_ui)
46
  self.tabs = {}
47
  self.tabs_component = None
 
 
 
48
 
49
  def create_ui(self):
50
  """Create the main Gradio UI"""
@@ -104,7 +107,7 @@ class VideoTrainerUI:
104
  self.tabs["train_tab"].components["log_box"],
105
  self.tabs["train_tab"].components["start_btn"],
106
  self.tabs["train_tab"].components["stop_btn"],
107
- self.tabs["train_tab"].components["pause_resume_btn"]
108
  ]
109
  )
110
 
@@ -135,14 +138,33 @@ class VideoTrainerUI:
135
  video_list = self.tabs["split_tab"].list_unprocessed_videos()
136
  training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
137
 
138
- # Get button states
139
  button_states = self.get_initial_button_states()
140
  start_btn = button_states[0]
141
  stop_btn = button_states[1]
142
- pause_resume_btn = button_states[2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
- # Get UI form values
145
  ui_state = self.load_ui_values()
 
146
  training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
147
  model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
148
  lora_rank_val = ui_state.get("lora_rank", "128")
@@ -158,7 +180,7 @@ class VideoTrainerUI:
158
  training_dataset,
159
  start_btn,
160
  stop_btn,
161
- pause_resume_btn,
162
  training_preset,
163
  model_type_val,
164
  lora_rank_val,
@@ -210,16 +232,39 @@ class VideoTrainerUI:
210
  # Add this new method to get initial button states:
211
  def get_initial_button_states(self):
212
  """Get the initial states for training buttons based on recovery status"""
213
- recovery_result = self.trainer.recover_interrupted_training()
214
  ui_updates = recovery_result.get("ui_updates", {})
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  # Return button states in the correct order
217
  return (
218
- gr.Button(**ui_updates.get("start_btn", {"interactive": True, "variant": "primary"})),
219
- gr.Button(**ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary"})),
220
- gr.Button(**ui_updates.get("pause_resume_btn", {"interactive": False, "variant": "secondary"}))
221
  )
222
-
223
  def update_titles(self) -> Tuple[Any]:
224
  """Update all dynamic titles with current counts
225
 
 
36
 
37
  # Initialize log parser
38
  self.log_parser = TrainingLogParser()
39
+
40
  # Shared state for tabs
41
  self.state = {
42
  "recovery_result": recovery_result
 
45
  # Initialize tabs dictionary (will be populated in create_ui)
46
  self.tabs = {}
47
  self.tabs_component = None
48
+
49
+ # Log recovery status
50
+ logger.info(f"Initialization complete. Recovery status: {self.recovery_status}")
51
 
52
  def create_ui(self):
53
  """Create the main Gradio UI"""
 
107
  self.tabs["train_tab"].components["log_box"],
108
  self.tabs["train_tab"].components["start_btn"],
109
  self.tabs["train_tab"].components["stop_btn"],
110
+ self.tabs["train_tab"].components["delete_checkpoints_btn"] # Replace pause_resume_btn
111
  ]
112
  )
113
 
 
138
  video_list = self.tabs["split_tab"].list_unprocessed_videos()
139
  training_dataset = self.tabs["caption_tab"].list_training_files_to_caption()
140
 
141
+ # Get button states based on recovery status
142
  button_states = self.get_initial_button_states()
143
  start_btn = button_states[0]
144
  stop_btn = button_states[1]
145
+ delete_checkpoints_btn = button_states[2] # This replaces pause_resume_btn in the response tuple
146
+
147
+ # Get UI form values - possibly from the recovery
148
+ if self.recovery_status in ["recovered", "ready_to_recover", "running"] and "ui_updates" in self.state["recovery_result"]:
149
+ recovery_ui = self.state["recovery_result"]["ui_updates"]
150
+
151
+ # If we recovered training parameters from the original session
152
+ ui_state = {}
153
+ for param in ["model_type", "lora_rank", "lora_alpha", "num_epochs",
154
+ "batch_size", "learning_rate", "save_iterations", "training_preset"]:
155
+ if param in recovery_ui:
156
+ ui_state[param] = recovery_ui[param]
157
+
158
+ # Merge with existing UI state if needed
159
+ if ui_state:
160
+ current_state = self.load_ui_values()
161
+ current_state.update(ui_state)
162
+ self.trainer.save_ui_state(current_state)
163
+ logger.info(f"Updated UI state from recovery: {ui_state}")
164
 
165
+ # Load values (potentially with recovery updates applied)
166
  ui_state = self.load_ui_values()
167
+
168
  training_preset = ui_state.get("training_preset", list(TRAINING_PRESETS.keys())[0])
169
  model_type_val = ui_state.get("model_type", list(MODEL_TYPES.keys())[0])
170
  lora_rank_val = ui_state.get("lora_rank", "128")
 
180
  training_dataset,
181
  start_btn,
182
  stop_btn,
183
+ delete_checkpoints_btn, # Replaces pause_resume_btn
184
  training_preset,
185
  model_type_val,
186
  lora_rank_val,
 
232
  # Add this new method to get initial button states:
233
  def get_initial_button_states(self):
234
  """Get the initial states for training buttons based on recovery status"""
235
+ recovery_result = self.state.get("recovery_result") or self.trainer.recover_interrupted_training()
236
  ui_updates = recovery_result.get("ui_updates", {})
237
 
238
+ # Check for checkpoints to determine start button text
239
+ has_checkpoints = len(list(OUTPUT_PATH.glob("checkpoint-*"))) > 0
240
+
241
+ # Default button states if recovery didn't provide any
242
+ if not ui_updates or not ui_updates.get("start_btn"):
243
+ is_training = self.trainer.is_training_running()
244
+
245
+ if is_training:
246
+ # Active training detected
247
+ start_btn_props = {"interactive": False, "variant": "secondary", "value": "Continue Training" if has_checkpoints else "Start Training"}
248
+ stop_btn_props = {"interactive": True, "variant": "primary", "value": "Stop at Last Checkpoint"}
249
+ delete_btn_props = {"interactive": False, "variant": "stop", "value": "Delete All Checkpoints"}
250
+ else:
251
+ # No active training
252
+ start_btn_props = {"interactive": True, "variant": "primary", "value": "Continue Training" if has_checkpoints else "Start Training"}
253
+ stop_btn_props = {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"}
254
+ delete_btn_props = {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"}
255
+ else:
256
+ # Use button states from recovery
257
+ start_btn_props = ui_updates.get("start_btn", {"interactive": True, "variant": "primary", "value": "Start Training"})
258
+ stop_btn_props = ui_updates.get("stop_btn", {"interactive": False, "variant": "secondary", "value": "Stop at Last Checkpoint"})
259
+ delete_btn_props = ui_updates.get("delete_checkpoints_btn", {"interactive": has_checkpoints, "variant": "stop", "value": "Delete All Checkpoints"})
260
+
261
  # Return button states in the correct order
262
  return (
263
+ gr.Button(**start_btn_props),
264
+ gr.Button(**stop_btn_props),
265
+ gr.Button(**delete_btn_props)
266
  )
267
+
268
  def update_titles(self) -> Tuple[Any]:
269
  """Update all dynamic titles with current counts
270