# DDColor / infer_hf.py
# Author: Sulio
# Commit 00e6746 (verified): "Upload folder using huggingface_hub"
import os
import cv2
import torch
import argparse
from tqdm import tqdm
from huggingface_hub import PyTorchModelHubMixin
from ddcolor_model import DDColor
from infer import ImageColorizationPipeline
class DDColorHF(DDColor, PyTorchModelHubMixin):
    """DDColor network with Hugging Face Hub save/load support.

    ``PyTorchModelHubMixin`` supplies ``from_pretrained``, which the CLI uses
    to load weights either from a local path or from a Hub repository.
    """

    def __init__(self, config):
        """Construct the underlying DDColor model from a config mapping.

        Args:
            config: mapping whose items are forwarded unchanged as keyword
                arguments to ``DDColor.__init__``. The parameter must stay
                named ``config``. Presumably the Hub mixin passes the stored
                model config under that name when loading; confirm against
                the huggingface_hub version in use.
        """
        super(DDColorHF, self).__init__(**config)
class ImageColorizationPipelineHF(ImageColorizationPipeline):
    """Colorization pipeline that wraps an already-constructed model.

    The parent ``__init__`` is deliberately not called. It presumably builds
    and loads its own model (confirm in ``infer.py``). Instead, this
    constructor sets ``input_size``, ``device`` and ``model`` directly, and
    the inherited methods (e.g. ``process``) use those.
    """

    def __init__(self, model, input_size):
        """Attach *model* to the best available device and switch it to eval.

        Args:
            model: the network to run. It is moved with ``.to(device)``, and
                ``eval()`` is called on the result.
            input_size: stored as ``self.input_size``. The CLI describes it as
                the "input size for model".
        """
        self.input_size = input_size
        # Use the GPU whenever torch can see one; otherwise fall back to CPU.
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = target
        self.model = model.to(target)
        self.model.eval()
def _list_input_files(folder):
    """Return the sorted names of the regular files directly inside *folder*.

    Sub-directories are skipped because ``cv2.imread`` cannot decode them.
    Sorting makes the processing order deterministic, since ``os.listdir``
    returns names in arbitrary order.
    """
    return sorted(
        name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    )


def main():
    """Colorize every image in ``--input`` and write the results to ``--output``.

    If ``--model_name`` is an existing local path, the model is loaded from
    there. Otherwise it is loaded from ``piddnad/<model_name>`` on the
    Hugging Face Hub. Each output file keeps the name of its input file.

    Exits with an argparse usage error (status 2) when ``--input`` is not a
    folder or contains no files. This check runs before any model is loaded.
    Files that OpenCV cannot read or write are reported and skipped instead
    of aborting the whole run.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="ddcolor_modelscope")
    parser.add_argument(
        "--input",
        type=str,
        default="figure/",
        help="input test image folder",
    )
    parser.add_argument(
        "--output", type=str, default="results", help="output folder"
    )
    parser.add_argument(
        "--input_size", type=int, default=512, help="input size for model"
    )
    args = parser.parse_args()

    # Validate the input before the (possibly slow) model download/load so
    # bad arguments fail fast. parser.error raises SystemExit, unlike an
    # ``assert``, which is stripped under ``python -O``.
    if not os.path.isdir(args.input):
        parser.error(f"--input is not a folder: {args.input}")
    img_list = _list_input_files(args.input)
    if not img_list:
        parser.error(f"--input folder contains no files: {args.input}")

    # A local checkpoint path wins; otherwise resolve under the piddnad namespace.
    if os.path.exists(args.model_name):
        model_name = args.model_name
    else:
        model_name = f"piddnad/{args.model_name}"
    ddcolor_model = DDColorHF.from_pretrained(model_name)

    print(f"Output path: {args.output}")
    os.makedirs(args.output, exist_ok=True)

    colorizer = ImageColorizationPipelineHF(
        model=ddcolor_model, input_size=args.input_size
    )
    for name in tqdm(img_list):
        img = cv2.imread(os.path.join(args.input, name))
        if img is None:
            # cv2.imread signals an unreadable / non-image file by returning
            # None rather than raising; skip it instead of crashing in process().
            tqdm.write(f"Skipping unreadable image: {name}")
            continue
        image_out = colorizer.process(img)
        out_path = os.path.join(args.output, name)
        # cv2.imwrite reports some failures through its boolean return value
        # rather than raising, so check it explicitly.
        if not cv2.imwrite(out_path, image_out):
            tqdm.write(f"Failed to write: {out_path}")


if __name__ == "__main__":
    main()