import logging
import os
import re
from collections.abc import Sequence
from datetime import datetime, timezone
from pathlib import Path

import torch
import yaml


def batchify(seq: Sequence, batch_size: int):
    """Yield consecutive slices of *seq* containing at most *batch_size* items.

    The last batch is shorter when ``len(seq)`` is not a multiple of
    *batch_size*. Batches are produced by slicing, so each one has whatever
    type ``seq[i:j]`` returns (e.g. a list for a list input).

    Args:
        seq: A sliceable sequence. An empty sequence yields no batches.
        batch_size: Maximum number of items per batch; must be positive.

    Raises:
        ValueError: If *batch_size* is not positive. Because this is a
            generator, the error surfaces when iteration starts.
    """
    # A negative step would make range() empty and silently drop every item;
    # a zero step would fail with an unhelpful message. Reject both clearly.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    for start in range(0, len(seq), batch_size):
        yield seq[start : start + batch_size]


def get_device() -> str:
    """Return the name of the preferred torch device.

    Checks Apple MPS first, then CUDA, and falls back to ``"cpu"``.
    """
    if torch.backends.mps.is_available():
        return "mps"  # mac GPU
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


def init_logger() -> None:
    """Configure root logging at INFO level with a timestamped format.

    Delegates to ``logging.basicConfig``, which does nothing if the root
    logger already has handlers configured.
    """
    logging.basicConfig(
        level=logging.INFO,  # Set the logging level
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # Define the log format
    )


def get_timestamp() -> str:
    """Return the current UTC time formatted as ``YYYY_MM_DD-HH_MM_SS``.

    The fields are zero-padded and ordered from year to second, so these
    strings sort lexicographically in chronological order.
    """
    return datetime.now(timezone.utc).strftime("%Y_%m_%d-%H_%M_%S")


# Matches exactly the strings produced by get_timestamp().
TIMESTAMP_PATTERN = re.compile(r"^\d{4}_\d{2}_\d{2}-\d{2}_\d{2}_\d{2}$")


def get_last_timestamp(path: Path):
    """Return the most recent timestamp-named entry directly inside *path*.

    Entry names are kept only if they fully match ``TIMESTAMP_PATTERN`` (the
    format produced by ``get_timestamp``). Such names compare
    lexicographically in chronological order, so the maximum is the latest.

    Args:
        path: Directory to scan (a ``str`` path also works).

    Returns:
        The greatest matching entry name, or ``None`` when *path* is not an
        existing directory or contains no matching entries.
    """
    # isdir (rather than exists) avoids NotADirectoryError from os.listdir
    # when *path* points at a regular file.
    if not os.path.isdir(path):
        return None
    # fullmatch: with plain match(), the "$" anchor would also accept a name
    # that ends in a trailing newline.
    timestamps = [name for name in os.listdir(path) if TIMESTAMP_PATTERN.fullmatch(name)]
    return max(timestamps, default=None)