import os
import cv2
import torch
import argparse
from tqdm import tqdm
from huggingface_hub import PyTorchModelHubMixin

from ddcolor_model import DDColor
from infer import ImageColorizationPipeline


class DDColorHF(DDColor, PyTorchModelHubMixin):
    """DDColor model that can be loaded through the Hugging Face Hub mixin.

    Args:
        config: Mapping of keyword arguments forwarded unchanged to
            ``DDColor.__init__``.
    """

    def __init__(self, config):
        super().__init__(**config)


class ImageColorizationPipelineHF(ImageColorizationPipeline):
    """Colorization pipeline built around an already-instantiated model.

    Args:
        model: Model instance; it is moved to the selected device and put
            into eval mode.
        input_size: Value stored on ``self.input_size`` (the CLI passes the
            ``--input_size`` integer).
    """

    def __init__(self, model, input_size):
        # NOTE(review): the parent __init__ is intentionally not called here;
        # presumably it builds/loads the model itself — confirm that
        # `process()` needs no other attribute set by the parent.
        self.input_size = input_size
        # Prefer the GPU when one is available, otherwise fall back to CPU.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )
        self.model = model.to(self.device)
        self.model.eval()


def main() -> None:
    """Colorize every readable image in ``--input`` and write it to ``--output``.

    Raises:
        ValueError: If the input folder contains no entries.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="ddcolor_modelscope")
    # NOTE(review): the help text mentions a video path, but only a folder of
    # images is handled below (the input is passed to os.listdir).
    parser.add_argument(
        "--input",
        type=str,
        default="figure/",
        help="input test image folder or video path",
    )
    parser.add_argument(
        "--output", type=str, default="results", help="output folder or video path"
    )
    parser.add_argument(
        "--input_size", type=int, default=512, help="input size for model"
    )
    args = parser.parse_args()

    # A model name that is not an existing local path is treated as a repo
    # under the "piddnad" namespace.
    if not os.path.exists(args.model_name):
        model_name = f"piddnad/{args.model_name}"
    else:
        model_name = args.model_name

    ddcolor_model = DDColorHF.from_pretrained(model_name)

    print(f"Output path: {args.output}")
    os.makedirs(args.output, exist_ok=True)

    # Sorted so the processing order is deterministic across platforms.
    img_list = sorted(os.listdir(args.input))
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not img_list:
        raise ValueError(f"No files found in input folder: {args.input}")

    colorizer = ImageColorizationPipelineHF(
        model=ddcolor_model, input_size=args.input_size
    )

    for name in tqdm(img_list):
        src_path = os.path.join(args.input, name)
        img = cv2.imread(src_path)
        if img is None:
            # cv2.imread returns None for entries it cannot decode
            # (non-image files, sub-directories); skip them instead of
            # passing None to the pipeline.
            tqdm.write(f"Skipping unreadable image: {src_path}")
            continue
        image_out = colorizer.process(img)
        cv2.imwrite(os.path.join(args.output, name), image_out)


if __name__ == "__main__":
    main()