from typing import OrderedDict, Optional

from PIL import Image

from toolkit.config_modules import LoggingConfig


# Base logger class
# This class does nothing, it's just a placeholder
class EmptyLogger:
    """No-op logger used when no logging backend is enabled.

    Every method accepts and ignores its arguments, so training code can call
    the logger unconditionally. Subclasses override the methods they support.
    """

    def __init__(self, *args, **kwargs) -> None:
        pass

    # start logging the training
    def start(self):
        pass

    # collect the log to send
    def log(self, *args, **kwargs):
        pass

    # send the log
    def commit(self, step: Optional[int] = None):
        pass

    # log image
    def log_image(self, *args, **kwargs):
        pass

    # finish logging
    def finish(self):
        pass


# Wandb logger class
# This class logs the data to wandb
class WandbLogger(EmptyLogger):
    """Logger that forwards metrics and sample images to Weights & Biases.

    ``wandb`` is imported lazily in :meth:`start`, so constructing this class
    does not require wandb to be installed. :meth:`start` must be called
    before :meth:`log`, :meth:`commit` or :meth:`log_image`.
    """

    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        super().__init__()
        self.project = project
        self.run_name = run_name
        self.config = config
        # Set by start(); None until then so finish() can tell whether a run exists.
        self.run = None

    def start(self):
        """Initialise the wandb run and send the whole config to it.

        Raises:
            ImportError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "Failed to import wandb. Please install wandb by running `pip install wandb`"
            ) from err

        # send the whole config to wandb
        run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        self.run = run
        self._log = wandb.log  # log function
        self._image = wandb.Image  # image object

    def log(self, *args, **kwargs):
        """Stage values for the current step without advancing the step counter."""
        # wandb.log advances its step whenever commit is True (the default).
        # Several log() calls may belong to one training step, so stage them
        # with commit=False and let commit() finalise the step.
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        """Finalise the current step, uploading everything staged since the last commit.

        Args:
            step: explicit step number to record the data under; when None,
                wandb's own counter is used.
        """
        # Logging an empty dict with commit=True flushes the staged values.
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # sample index; name kept for keyword callers despite shadowing the builtin
        caption: str | None = None,  # positive prompt
        *args,
        **kwargs,
    ):
        """Stage a sample image under the key ``sample_{id}``.

        Extra positional/keyword arguments are forwarded to ``wandb.Image``.
        The image is staged with commit=False, so it is uploaded on the next
        :meth:`commit`.
        """
        # Positional extras follow the image (same binding as before); caption
        # is always passed by keyword.
        wandb_image = self._image(image, *args, caption=caption, **kwargs)
        self._log({f"sample_{id}": wandb_image}, commit=False)

    def finish(self):
        """Finish the wandb run, if one was started; otherwise do nothing."""
        # Guard so finish() without a successful start() (or a second finish())
        # is a no-op instead of raising AttributeError.
        if self.run is not None:
            self.run.finish()
            self.run = None


# create logger based on the logging config
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict) -> EmptyLogger:
    """Build the logger selected by ``logging_config``.

    Args:
        logging_config: reads ``use_wandb``, ``project_name`` and ``run_name``.
        all_config: full training config, sent to wandb as the run config.

    Returns:
        A :class:`WandbLogger` when ``use_wandb`` is truthy, otherwise a no-op
        :class:`EmptyLogger`.
    """
    if not logging_config.use_wandb:
        return EmptyLogger()
    return WandbLogger(
        project=logging_config.project_name,
        run_name=logging_config.run_name,
        config=all_config,
    )