ramimu's picture
Upload 586 files
1c72248 verified
raw
history blame
9 kB
from collections import OrderedDict
import os
import sqlite3
import asyncio
import concurrent.futures
from extensions_built_in.sd_trainer.SDTrainer import SDTrainer
from typing import Literal, Optional
# Closed set of job status strings that UITrainer writes to the `status`
# column of the `Job` table.
AITK_Status = Literal["running", "stopped", "error", "completed"]
class UITrainer(SDTrainer):
    """SDTrainer that mirrors job progress into a SQLite ``Job`` table.

    The row identified by the ``AITK_JOB_ID`` environment variable is kept
    up to date with the job's ``status``, ``info``, ``step`` and
    ``speed_string`` columns, and its ``stop`` column is polled at every
    hook so the job can be aborted from outside the process.

    Writes are only performed on the main accelerator process. Each
    database call opens (and closes) its own connection.
    """

    def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
        """Validate the database path and job id, then mark the job running.

        Raises:
            Exception: if the SQLite file does not exist, or if
                ``AITK_JOB_ID`` is unset or blank.
        """
        super().__init__(process_id, job, config, **kwargs)
        self.sqlite_db_path = self.config.get("sqlite_db_path", "./aitk_db.db")
        if not os.path.exists(self.sqlite_db_path):
            raise Exception(
                f"SQLite database not found at {self.sqlite_db_path}")
        print(f"Using SQLite database at {self.sqlite_db_path}")
        self.job_id = os.environ.get("AITK_JOB_ID", None)
        self.job_id = self.job_id.strip() if self.job_id is not None else None
        print(f"Job ID: \"{self.job_id}\"")
        # A blank id can never match a row, so treat it like a missing one.
        if not self.job_id:
            raise Exception("AITK_JOB_ID not set")
        self.is_stopping = False
        # Single worker thread so database writes are serialized.
        self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        # Tasks scheduled on an already-running loop that may still be pending.
        self._async_tasks = []
        # Private loop used to drive coroutines when no loop is running.
        self._loop = None
        # Initialize the status
        self._run_async_operation(self._update_status("running", "Starting"))

    # ------------------------------------------------------------------
    # async / event-loop plumbing
    # ------------------------------------------------------------------
    def _get_loop(self):
        """Return the trainer's private event loop, creating it on demand."""
        loop = getattr(self, "_loop", None)
        if loop is None or loop.is_closed():
            loop = asyncio.new_event_loop()
            self._loop = loop
        return loop

    def _run_async_operation(self, coro):
        """Run ``coro`` to completion, or schedule it if a loop is running.

        When called from inside a running event loop the coroutine is
        scheduled as a task and tracked in ``self._async_tasks`` so that
        ``wait_for_all_async`` can await it later. Otherwise the call
        blocks until the coroutine has finished and any exception it
        raised propagates to the caller.
        """
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None

        if running_loop is not None:
            self._async_tasks.append(running_loop.create_task(coro))
            return

        # The coroutine is finished by the time this returns, so there is
        # nothing left to track; tracking it would grow the list by one
        # entry per training step.
        self._get_loop().run_until_complete(coro)

    async def _execute_db_operation(self, operation_func):
        """Execute a database operation on the worker thread and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self.thread_pool, operation_func)

    async def wait_for_all_async(self):
        """Wait for all tracked async operations to complete.

        Failures are reported with ``print`` instead of being raised, so
        shutdown is never aborted by a late database error. The tracking
        list is always emptied.
        """
        if not self._async_tasks:
            return
        try:
            results = await asyncio.gather(
                *self._async_tasks, return_exceptions=True)
        except Exception as e:
            print(f"Error while waiting for async operations: {e}")
        else:
            for result in results:
                if isinstance(result, BaseException):
                    print(f"Async operation failed: {result!r}")
        finally:
            # Clear the task list after completion
            self._async_tasks.clear()

    def _release_resources(self):
        """Shut down the worker thread and close the private event loop."""
        self.thread_pool.shutdown(wait=True)
        loop = getattr(self, "_loop", None)
        if loop is not None and not loop.is_closed() and not loop.is_running():
            loop.close()
        self._loop = None

    # ------------------------------------------------------------------
    # database helpers
    # ------------------------------------------------------------------
    def _db_connect(self):
        """Create a new connection for each operation to avoid locking.

        The connection is in autocommit mode; callers that need a
        transaction issue ``BEGIN``/``COMMIT`` themselves and are
        responsible for closing the connection.
        """
        conn = sqlite3.connect(self.sqlite_db_path, timeout=10.0)
        conn.isolation_level = None  # Enable autocommit mode
        return conn

    def _execute_write(self, sql, params):
        """Run one write statement inside a ``BEGIN IMMEDIATE`` transaction.

        The transaction is committed on success and rolled back on
        failure; the connection is closed in either case.
        """
        conn = self._db_connect()
        try:
            cursor = conn.cursor()
            cursor.execute("BEGIN IMMEDIATE")
            try:
                cursor.execute(sql, params)
            except BaseException:
                # SQLite may already have rolled back on its own.
                if conn.in_transaction:
                    cursor.execute("ROLLBACK")
                raise
            else:
                cursor.execute("COMMIT")
        finally:
            # sqlite3's context manager does not close the connection.
            conn.close()

    def should_stop(self):
        """Return True when the job's ``stop`` column equals 1.

        Returns False when no row exists for ``self.job_id``.
        """
        conn = self._db_connect()
        try:
            row = conn.execute(
                "SELECT stop FROM Job WHERE id = ?", (self.job_id,)
            ).fetchone()
        finally:
            conn.close()
        return False if row is None else row[0] == 1

    def maybe_stop(self):
        """Abort the job if a stop was requested.

        Raises:
            Exception: ``"Job stopped"`` after recording the ``stopped``
                status and setting ``self.is_stopping``.
        """
        if self.should_stop():
            self._run_async_operation(
                self._update_status("stopped", "Job stopped"))
            self.is_stopping = True
            raise Exception("Job stopped")

    async def _update_key(self, key, value):
        """Set column ``key`` of this job's row to ``str(value)``.

        Does nothing on non-main processes.

        Raises:
            ValueError: if ``key`` is not a plain ASCII identifier. Column
                names cannot be bound as SQL parameters, so the name is
                validated before being placed in the statement text.
        """
        if not self.accelerator.is_main_process:
            return
        if not (isinstance(key, str) and key.isascii() and key.isidentifier()):
            raise ValueError(f"Invalid Job column name: {key!r}")

        # Convert the value to string if it's not already
        value_to_insert = value if isinstance(value, str) else str(value)
        # Only the values are bound; the validated column name is quoted.
        update_query = f'UPDATE Job SET "{key}" = ? WHERE id = ?'
        params = (value_to_insert, self.job_id)
        await self._execute_db_operation(
            lambda: self._execute_write(update_query, params))

    async def _update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Write ``status`` (and ``info`` when given) to this job's row.

        Does nothing on non-main processes.
        """
        if not self.accelerator.is_main_process:
            return
        if info is not None:
            sql = "UPDATE Job SET status = ?, info = ? WHERE id = ?"
            params = (status, info, self.job_id)
        else:
            sql = "UPDATE Job SET status = ? WHERE id = ?"
            params = (status, self.job_id)
        await self._execute_db_operation(
            lambda: self._execute_write(sql, params))

    # ------------------------------------------------------------------
    # public update API
    # ------------------------------------------------------------------
    def update_step(self):
        """Write the current step count (main process only).

        Blocks until written unless called from a running event loop.
        """
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key("step", self.step_num))

    def update_db_key(self, key, value):
        """Write ``value`` to column ``key`` (main process only).

        Blocks until written unless called from a running event loop.
        """
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_key(key, value))

    def update_status(self, status: AITK_Status, info: Optional[str] = None):
        """Write the job status (main process only).

        Blocks until written unless called from a running event loop.
        """
        if self.accelerator.is_main_process:
            self._run_async_operation(self._update_status(status, info))

    # ------------------------------------------------------------------
    # trainer hooks
    # ------------------------------------------------------------------
    def on_error(self, e: Exception):
        """Record the error in the database and release resources.

        Nothing is recorded when the job is stopping on request, because
        ``maybe_stop`` has already written the ``stopped`` status.
        """
        super().on_error(e)
        try:
            if self.accelerator.is_main_process and not self.is_stopping:
                self.update_status("error", str(e))
                self.update_db_key("step", self.last_save_step)
                asyncio.run(self.wait_for_all_async())
        finally:
            # Release the worker even if recording the error itself failed.
            self._release_resources()

    def handle_timing_print_hook(self, timing_dict):
        """Publish the training speed from ``timing_dict["train_loop"]``.

        Values below one second are reported as ``iter/sec``, others as
        ``sec/iter``.
        """
        if "train_loop" not in timing_dict:
            print("train_loop not found in timing_dict", timing_dict)
            return
        seconds_per_iter = timing_dict["train_loop"]
        # A zero timing would divide by zero below; skip this report.
        if seconds_per_iter <= 0:
            return
        # determine iter/sec or sec/iter
        if seconds_per_iter < 1:
            iters_per_sec = 1 / seconds_per_iter
            self.update_db_key("speed_string", f"{iters_per_sec:.2f} iter/sec")
        else:
            self.update_db_key(
                "speed_string", f"{seconds_per_iter:.2f} sec/iter")

    def done_hook(self):
        """Mark the job completed and release resources."""
        super().done_hook()
        self.update_status("completed", "Training completed")
        # Wait for all async operations to finish before shutting down
        asyncio.run(self.wait_for_all_async())
        self._release_resources()

    def end_step_hook(self):
        """Publish the step count, then honor any pending stop request."""
        super().end_step_hook()
        self.update_step()
        self.maybe_stop()

    def hook_before_model_load(self):
        """Check for a stop request and report "Loading model"."""
        super().hook_before_model_load()
        self.maybe_stop()
        self.update_status("running", "Loading model")

    def before_dataset_load(self):
        """Check for a stop request and report "Loading dataset"."""
        super().before_dataset_load()
        self.maybe_stop()
        self.update_status("running", "Loading dataset")

    def hook_before_train_loop(self):
        """Report "Training" and register the timing print hook."""
        super().hook_before_train_loop()
        self.maybe_stop()
        self.update_step()
        self.update_status("running", "Training")
        self.timer.add_after_print_hook(self.handle_timing_print_hook)

    def status_update_hook_func(self, string):
        """Forward a status message as the ``info`` of a running job."""
        self.update_status("running", string)

    def hook_after_sd_init_before_load(self):
        """Check for a stop request and register the status update hook on ``self.sd``."""
        super().hook_after_sd_init_before_load()
        self.maybe_stop()
        self.sd.add_status_update_hook(self.status_update_hook_func)

    def sample_step_hook(self, img_num, total_imgs):
        """Report sampling progress as ``img_num + 1`` of ``total_imgs``."""
        super().sample_step_hook(img_num, total_imgs)
        self.maybe_stop()
        self.update_status(
            "running", f"Generating images - {img_num + 1}/{total_imgs}")

    def sample(self, step=None, is_first=False):
        """Run sampling, reporting progress before and "Training" after.

        A stop request is honored both before and after sampling.
        """
        self.maybe_stop()
        total_imgs = len(self.sample_config.prompts)
        self.update_status("running", f"Generating images - 0/{total_imgs}")
        super().sample(step, is_first)
        self.maybe_stop()
        self.update_status("running", "Training")

    def save(self, step=None):
        """Save the model, reporting "Saving model" before and "Training" after.

        A stop request is honored both before and after saving.
        """
        self.maybe_stop()
        self.update_status("running", "Saving model")
        super().save(step)
        self.maybe_stop()
        self.update_status("running", "Training")