ramimu's picture
Upload 586 files
1c72248 verified
from typing import OrderedDict, Optional
from PIL import Image
from toolkit.config_modules import LoggingConfig
class EmptyLogger:
    """Base logger that deliberately does nothing.

    Serves as a placeholder when no logging backend is configured: every
    method accepts its arguments and returns ``None``, so training code can
    call the logger unconditionally.
    """

    def __init__(self, *args, **kwargs) -> None:
        """Accept and ignore any constructor arguments."""

    def start(self):
        """Start logging the training (no-op)."""

    def log(self, *args, **kwargs):
        """Collect values to send later (no-op)."""

    def commit(self, step: Optional[int] = None):
        """Send the collected log (no-op)."""

    def log_image(self, *args, **kwargs):
        """Log an image (no-op)."""

    def finish(self):
        """Finish logging (no-op)."""
class WandbLogger(EmptyLogger):
    """Logger that sends metrics and sample images to Weights & Biases.

    ``wandb`` is imported lazily in :meth:`start`, so constructing this class
    does not require wandb to be installed. :meth:`start` must be called
    before :meth:`log`, :meth:`commit` or :meth:`log_image`.

    Values passed to :meth:`log` / :meth:`log_image` are queued without
    committing; :meth:`commit` finalizes them as one step.
    """

    def __init__(self, project: str, run_name: str | None, config: OrderedDict) -> None:
        """Store run settings; no connection to wandb is made here.

        Args:
            project: wandb project name.
            run_name: Display name of the run, or ``None`` to leave naming to wandb.
            config: Full training config, sent to wandb when the run starts.
        """
        super().__init__()
        self.project = project
        self.run_name = run_name
        self.config = config
        # Populated by start(); None until then so finish() can tell that no
        # run was ever started.
        self.run = None

    def start(self):
        """Create the wandb run and bind wandb's logging helpers.

        Raises:
            ImportError: If the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ImportError as err:
            raise ImportError(
                "Failed to import wandb. Please install wandb by running `pip install wandb`"
            ) from err
        # send the whole config to wandb
        self.run = wandb.init(project=self.project, name=self.run_name, config=self.config)
        self._log = wandb.log  # log function
        self._image = wandb.Image  # image object

    def log(self, *args, **kwargs):
        """Queue values for the current step without committing them.

        All arguments are forwarded to ``wandb.log``; ``commit`` is always
        set to False here, so callers must not pass it themselves.
        """
        # By default wandb.log commits and advances its step counter. Several
        # log() calls should land on the same step, so we pass commit=False
        # and only advance the step in commit().
        self._log(*args, **kwargs, commit=False)

    def commit(self, step: Optional[int] = None):
        """Finalize everything queued since the last commit.

        Args:
            step: Step to record the data under; ``None`` lets wandb use its
                own counter.
        """
        # after overall one step is done, we commit the log
        # by logging an empty object with commit=True
        self._log({}, step=step, commit=True)

    def log_image(
        self,
        image: Image.Image,
        id,  # sample index
        caption: str | None = None,  # positive prompt
        *args,
        **kwargs,
    ):
        """Queue a sample image under the key ``sample_{id}``.

        Args:
            image: Image to upload (anything ``wandb.Image`` accepts).
            id: Sample index, used to build the log key.
            caption: Optional caption, typically the positive prompt.
            *args, **kwargs: Extra arguments forwarded to ``wandb.Image``.
        """
        # Extra positional args follow the image; caption stays a keyword.
        wb_image = self._image(image, *args, caption=caption, **kwargs)
        self._log({f"sample_{id}": wb_image}, commit=False)

    def finish(self):
        """Close the wandb run; does nothing if start() was never called."""
        if self.run is None:
            return
        self.run.finish()
def create_logger(logging_config: LoggingConfig, all_config: OrderedDict):
    """Build the logger selected by ``logging_config``.

    Args:
        logging_config: Logging section of the config; ``use_wandb`` selects
            the backend, and ``project_name`` / ``run_name`` configure wandb.
        all_config: Complete training config, handed to wandb when enabled.

    Returns:
        A ``WandbLogger`` when wandb is enabled, otherwise a no-op
        ``EmptyLogger``.
    """
    if not logging_config.use_wandb:
        return EmptyLogger()

    return WandbLogger(
        project=logging_config.project_name,
        run_name=logging_config.run_name,
        config=all_config,
    )