#!/usr/bin/env python3
import json
import sys
import os
import io
import argparse
import uuid
import base64
import logging
import time
import copy
import cv2
import insightface
import numpy as np
from typing import List, Union
from PIL import Image
# Star import: this module relies on it for torch, ARCH_REGISTRY, check_ckpts,
# set_realesrgan and face_restoration, which are not imported anywhere else here.
from restoration import *
from flask import Flask, request, jsonify, make_response
from waitress import serve

LOG_LEVEL = logging.DEBUG
TMP_PATH = '/tmp/inswapper'

# Defaults for the optional face swap settings. The keys are exactly the
# keyword parameters of face_swap(), so a dict built from this mapping can be
# passed straight through as **kwargs.
DEFAULT_OPTIONS = {
    'source_indexes': '-1',
    'target_indexes': '-1',
    'background_enhance': True,
    'face_restore': True,
    'face_upsample': True,
    'upscale': 1,
    'codeformer_fidelity': 0.5,
    'output_format': 'JPEG',
}

script_dir = os.path.dirname(os.path.abspath(__file__))
log_path = ''

# Mac does not have permission to /var/log for example
if sys.platform == 'linux':
    log_path = '/var/log/'

logging.basicConfig(
    filename=f'{log_path}inswapper.log',
    format='%(asctime)s : %(levelname)s : %(message)s',
    level=LOG_LEVEL
)

logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))


def process_request(request_obj):
    """Run a face swap for a dict-shaped request and wrap the result in a status dict.

    ``request_obj['source_image']`` and ``request_obj['target_image']`` are
    handed to face_swap(), which opens them as image file paths (multiple
    source paths separated by ``;``). Optional keys from DEFAULT_OPTIONS are
    honoured; missing ones fall back to the same defaults as /faceswap.

    NOTE(review): nothing in this file calls this function; the HTTP endpoint
    is face_swap_api().

    Returns:
        ``{'status': 'ok', 'image': <base64 str>}`` on success, otherwise
        ``{'status': 'error', 'msg': ..., 'detail': ...}``.
    """
    try:
        logging.debug('Swapping face')
        face_swap_timer = Timer()
        options = {key: request_obj.get(key, default) for key, default in DEFAULT_OPTIONS.items()}
        result_image = face_swap(request_obj['source_image'], request_obj['target_image'], **options)
        face_swap_time = face_swap_timer.get_elapsed_time()
        logging.info(f'Time taken to swap face: {face_swap_time} seconds')

        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        logging.exception(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }

    return response


class Timer:
    """Stopwatch reporting elapsed wall-clock seconds rounded to one decimal."""

    def __init__(self):
        self.start = time.time()

    def restart(self):
        """Reset the start time to now."""
        self.start = time.time()

    def get_elapsed_time(self):
        """Return seconds since construction or the last restart(), rounded to 0.1."""
        end = time.time()
        return round(end - self.start, 1)


def get_args():
    """Parse the command line options for the server (host and port)."""
    parser = argparse.ArgumentParser(
        description='Inswapper REST API'
    )

    parser.add_argument(
        '-p', '--port',
        help='Port to listen on',
        type=int,
        default=80
    )

    parser.add_argument(
        '-H', '--host',
        help='Host to bind to',
        default='0.0.0.0'
    )

    return parser.parse_args()


def determine_file_extension(image_data):
    """Guess a file extension from the first characters of base64 image data.

    ``'/9j/'`` is how base64-encoded JPEG data begins, so it maps to
    ``'.jpg'``. Everything else, including PNG data (``'iVBORw0Kg'``) and
    non-string input, maps to ``'.png'``.
    """
    if isinstance(image_data, str) and image_data.startswith('/9j/'):
        return '.jpg'

    # PNG data, unknown formats and non-str input all default to png
    return '.png'


def write_base64_to_disk(file_b64: str, file_path: str):
    """Decode *file_b64* and write the raw bytes to *file_path*."""
    with open(file_path, 'wb') as file:
        file.write(base64.b64decode(file_b64))


def get_face_swap_model(model_path: str):
    """Load the face swap model at *model_path* via insightface's model zoo."""
    model = insightface.model_zoo.get_model(model_path)
    return model


def get_face_analyser(model_path: str, det_size=(320, 320)):
    """Create and prepare an insightface ``buffalo_l`` FaceAnalysis instance.

    NOTE(review): *model_path* is not used. The analysis models are looked up
    under ``./checkpoints``, which is relative to the current working
    directory, not to ``script_dir``.
    """
    face_analyser = insightface.app.FaceAnalysis(name="buffalo_l", root="./checkpoints")
    face_analyser.prepare(ctx_id=0, det_size=det_size)
    return face_analyser


def get_one_face(face_analyser, frame: np.ndarray):
    """Return the leftmost detected face (smallest ``bbox[0]``), or None if none are found."""
    face = face_analyser.get(frame)
    try:
        return min(face, key=lambda x: x.bbox[0])
    except ValueError:
        # min() of an empty sequence: no faces detected
        return None


def get_many_faces(face_analyser, frame: np.ndarray):
    """
    get faces from left to right by order

    Returns the faces detected in *frame* sorted by the left edge of their
    bounding box (``bbox[0]``). The list is empty when nothing is detected.
    ``None`` is returned only if the analyser raises ``IndexError``.
    """
    try:
        face = face_analyser.get(frame)
        return sorted(face, key=lambda x: x.bbox[0])
    except IndexError:
        return None


def swap_face(face_swapper, source_faces, target_faces, source_index, target_index, temp_frame):
    """
    paste source_face on target image

    Swaps ``source_faces[source_index]`` onto ``target_faces[target_index]``
    in *temp_frame* and returns the swapper's output (called with
    ``paste_back=True``).
    """
    source_face = source_faces[source_index]
    target_face = target_faces[target_index]

    return face_swapper.get(temp_frame, target_face, source_face, paste_back=True)


def _to_bgr(img):
    """Convert an RGB PIL image to the BGR ndarray layout OpenCV/insightface use."""
    return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)


def _check_index(index, num_faces, role):
    """Raise ValueError if *index* is past the last detected face."""
    if index > num_faces - 1:
        raise ValueError(f'{role} index {index} is higher than the number of faces in the {role.lower()} image')


def _swap_one_face_per_source_image(face_analyser, face_swapper, source_images, target_faces, temp_frame):
    """Swap the leftmost face of source image *i* onto target face *i*, for every *i*."""
    logging.debug('Replacing the faces in the target image from left to right by order')

    for target_index, source_image in enumerate(source_images):
        source_faces = get_many_faces(face_analyser, _to_bgr(source_image))

        if not source_faces:
            raise Exception('No source faces found!')

        # Each source image supplies one face: its leftmost one.
        temp_frame = swap_face(
            face_swapper,
            source_faces,
            target_faces,
            0,
            target_index,
            temp_frame
        )

    return temp_frame


def _swap_from_single_source_image(face_analyser, face_swapper, source_image, target_faces,
                                   source_indexes, target_indexes, temp_frame):
    """Swap faces from one source image into *temp_frame* as selected by the index strings.

    The index strings are ``'-1'`` (meaning "all") or comma-separated,
    0-based face indexes counted from left to right.
    """
    num_target_faces = len(target_faces)
    source_faces = get_many_faces(face_analyser, _to_bgr(source_image))

    if not source_faces:
        raise Exception('No source faces found!')

    num_source_faces = len(source_faces)
    logging.debug(f'Source faces: {num_source_faces}')
    logging.debug(f'Target faces: {num_target_faces}')

    if target_indexes == "-1":
        # source_indexes is not consulted in this case.
        if num_source_faces == 1:
            logging.debug('Replacing all faces in target image with the same face from the source image')
            num_iterations = num_target_faces
        elif num_source_faces < num_target_faces:
            logging.debug('There are less faces in the source image than the target image, replacing as many as we can')
            num_iterations = num_source_faces
        elif num_target_faces < num_source_faces:
            logging.debug('There are less faces in the target image than the source image, replacing as many as we can')
            num_iterations = num_target_faces
        else:
            logging.debug('Replacing all faces in the target image with the faces from the source image')
            num_iterations = num_target_faces

        for i in range(num_iterations):
            source_index = 0 if num_source_faces == 1 else i
            temp_frame = swap_face(
                face_swapper,
                source_faces,
                target_faces,
                source_index,
                i,
                temp_frame
            )
        return temp_frame

    if source_indexes == '-1' and num_source_faces == 1:
        # A single source face is pasted onto every requested target face.
        logging.debug('Replacing specific face(s) in the target image with the face from the source image')
        for raw_target_index in target_indexes.split(','):
            target_index = int(raw_target_index)
            _check_index(target_index, num_target_faces, 'Target')
            temp_frame = swap_face(
                face_swapper,
                source_faces,
                target_faces,
                0,
                target_index,
                temp_frame
            )
        return temp_frame

    logging.debug('Replacing specific face(s) in the target image with specific face(s) from the source image')

    if source_indexes == "-1":
        source_indexes = ','.join(str(i) for i in range(num_source_faces))

    source_index_list = source_indexes.split(',')
    target_index_list = target_indexes.split(',')

    if len(source_index_list) > num_source_faces:
        raise Exception('Number of source indexes is greater than the number of faces in the source image')

    if len(target_index_list) > num_target_faces:
        raise Exception('Number of target indexes is greater than the number of faces in the target image')

    if len(source_index_list) != len(target_index_list):
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')

    for raw_source_index, raw_target_index in zip(source_index_list, target_index_list):
        source_index = int(raw_source_index)
        target_index = int(raw_target_index)
        _check_index(source_index, num_source_faces, 'Source')
        _check_index(target_index, num_target_faces, 'Target')

        temp_frame = swap_face(
            face_swapper,
            source_faces,
            target_faces,
            source_index,
            target_index,
            temp_frame
        )

    return temp_frame


def process(source_img: Union[Image.Image, List], target_img: Image.Image, source_indexes: str, target_indexes: str, model: str):
    """Swap faces from the source image(s) into *target_img* and return the result.

    Args:
        source_img: a PIL image or a list of them.
        target_img: PIL image whose faces are replaced.
        source_indexes / target_indexes: ``'-1'`` or comma-separated, 0-based
            face indexes (faces are ordered left to right).
        model: path to the face swap model file.

    Behaviour:
        * If the number of source images equals the number of target faces,
          source image *i*'s leftmost face replaces target face *i*. The
          index strings are ignored in this case.
        * Otherwise, with exactly one source image, the index strings choose
          which faces are swapped.
        * Any other combination raises ``Exception('Unsupported face configuration')``.

    Returns:
        The swapped image as an RGB PIL image.

    Raises:
        Exception: no faces found in the target/source, or an unsupported
            face/index configuration.
        ValueError: a face index is out of range or is not an integer.
    """
    if isinstance(source_img, Image.Image):
        source_img = [source_img]
    else:
        source_img = list(source_img)

    # load face_analyser
    face_analyser = get_face_analyser(model)

    # load face_swapper
    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = get_face_swap_model(model_path)

    # read target image
    target_frame = _to_bgr(target_img)

    # detect faces that will be replaced in the target image
    target_faces = get_many_faces(face_analyser, target_frame)

    if not target_faces:
        logging.error('No target faces found')
        raise Exception('No target faces found!')

    num_target_faces = len(target_faces)
    num_source_images = len(source_img)
    temp_frame = target_frame.copy()

    if num_source_images == num_target_faces:
        temp_frame = _swap_one_face_per_source_image(
            face_analyser, face_swapper, source_img, target_faces, temp_frame
        )
    elif num_source_images == 1:
        temp_frame = _swap_from_single_source_image(
            face_analyser, face_swapper, source_img[0], target_faces,
            source_indexes, target_indexes, temp_frame
        )
    else:
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')

    return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))


def _open_rgb_image(path):
    """Load the image at *path* fully into memory as RGB and release the file handle."""
    with Image.open(path) as img:
        return img.convert('RGB')


def _restore_faces(image, background_enhance, face_upsample, upscale, codeformer_fidelity):
    """Run CodeFormer face restoration (with a RealESRGAN upsampler) on an RGB PIL image."""
    # make sure the ckpts downloaded successfully; only needed on this path
    check_ckpts()

    # https://huggingface.co/spaces/sczhou/CodeFormer
    logging.debug('Setting upsampler to RealESRGAN_x2plus')
    upsampler = set_realesrgan()

    torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    logging.debug(f'Torch device: {torch_device.upper()}')
    device = torch.device(torch_device)

    codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=['32', '64', '128', '256'],
    ).to(device)

    ckpt_path = os.path.join(script_dir, 'CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth')
    logging.debug(f'Loading CodeFormer model: {ckpt_path}')
    # map_location lets a checkpoint that was saved from a GPU load on a CPU-only host
    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
    codeformer_net.load_state_dict(checkpoint)
    codeformer_net.eval()

    bgr_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    logging.debug('Performing face restoration using CodeFormer')

    restored = face_restoration(
        bgr_image,
        background_enhance,
        face_upsample,
        upscale,
        codeformer_fidelity,
        upsampler,
        codeformer_net,
        device
    )

    logging.debug('CodeFormer face restoration completed successfully')
    # NOTE(review): passed to Image.fromarray unchanged, so this assumes
    # face_restoration returns RGB channel order; confirm in restoration.py.
    return Image.fromarray(restored)


def face_swap(src_img_path, target_img_path, source_indexes, target_indexes, background_enhance, face_restore,
              face_upsample, upscale, codeformer_fidelity, output_format):
    """Swap faces between image files and return the result as a base64 string.

    Args:
        src_img_path: path of the source image; several paths may be joined with ``;``.
        target_img_path: path of the target image.
        source_indexes / target_indexes: face selectors, see process().
        background_enhance, face_upsample, upscale, codeformer_fidelity:
            forwarded to CodeFormer restoration.
        face_restore: if truthy, run CodeFormer restoration after the swap.
        output_format: PIL format name used to encode the result (e.g. ``'JPEG'``).

    Returns:
        The encoded result image, base64 as a UTF-8 ``str``.
    """
    source_img = [_open_rgb_image(img_path) for img_path in src_img_path.split(';')]
    target_img = _open_rgb_image(target_img_path)

    # download from https://huggingface.co/ashleykleynhans/inswapper/tree/main
    model = os.path.join(script_dir, 'checkpoints/inswapper_128.onnx')
    logging.debug(f'Face swap model: {model}')

    logging.debug('Performing face swap')
    result_image = process(
        source_img,
        target_img,
        source_indexes,
        target_indexes,
        model
    )
    logging.debug('Face swap complete')

    if face_restore:
        result_image = _restore_faces(
            result_image,
            background_enhance,
            face_upsample,
            upscale,
            codeformer_fidelity
        )

    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format=output_format)
    return base64.b64encode(output_buffer.getvalue()).decode('utf-8')


app = Flask(__name__)


def _error_response(msg, detail, status_code):
    """Build the JSON error envelope shared by every error response."""
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': msg,
            'detail': detail
        }
    ), status_code)


@app.errorhandler(400)
def bad_request(error):
    """Return a JSON body for 400 Bad Request errors."""
    return _error_response('Bad Request', str(error), 400)


@app.errorhandler(404)
def not_found(error):
    """Return a JSON body naming the URL that was not found."""
    return _error_response(f'{request.url} not found', str(error), 404)


@app.errorhandler(500)
def internal_server_error(error):
    """Return a JSON body for unhandled server errors."""
    return _error_response('Internal Server Error', str(error), 500)


@app.route('/', methods=['GET'])
def ping():
    """Health check endpoint."""
    return make_response(jsonify(
        {
            'status': 'ok'
        }
    ), 200)


def _remove_if_exists(path):
    """Delete *path*, ignoring the case where it was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


@app.route('/faceswap', methods=['POST'])
def face_swap_api():
    """Handle POST /faceswap: swap faces between two base64-encoded images.

    Expects a JSON object with ``source_image`` and ``target_image`` (base64
    strings) plus any optional keys from DEFAULT_OPTIONS. It responds with:

    * ``{'status': 'ok', 'image': <base64>}`` and status 200 on success;
    * a 400 error for a non-object body, missing fields or invalid base64;
    * ``{'status': 'error', ...}`` and status 500 if the swap itself fails.

    The decoded images are written to TMP_PATH and always removed afterwards.
    """
    total_timer = Timer()
    logging.debug('Received face swap API request')
    payload = request.get_json()

    if not isinstance(payload, dict):
        return _error_response('Bad Request', 'Request body must be a JSON object', 400)

    missing = [key for key in ('source_image', 'target_image') if key not in payload]
    if missing:
        return _error_response('Bad Request', f'Missing required field(s): {", ".join(missing)}', 400)

    source_image_data = payload['source_image']
    target_image_data = payload['target_image']

    try:
        source_image = base64.b64decode(source_image_data)
        target_image = base64.b64decode(target_image_data)
    except (TypeError, ValueError) as e:
        # binascii.Error (e.g. bad padding) is a ValueError subclass;
        # TypeError covers non-string values such as null or numbers.
        return _error_response('Bad Request', f'Invalid base64 image data: {e}', 400)

    if not os.path.exists(TMP_PATH):
        logging.debug(f'Creating temporary directory: {TMP_PATH}')
    os.makedirs(TMP_PATH, exist_ok=True)

    unique_id = uuid.uuid4()
    source_image_path = f'{TMP_PATH}/source_{unique_id}{determine_file_extension(source_image_data)}'
    target_image_path = f'{TMP_PATH}/target_{unique_id}{determine_file_extension(target_image_data)}'

    # Set defaults for anything not specified in the payload
    options = {key: payload.get(key, default) for key, default in DEFAULT_OPTIONS.items()}

    try:
        with open(source_image_path, 'wb') as source_file:
            source_file.write(source_image)

        with open(target_image_path, 'wb') as target_file:
            target_file.write(target_image)

        logging.debug(f'Source indexes: {options["source_indexes"]}')
        logging.debug(f'Target indexes: {options["target_indexes"]}')
        logging.debug(f'Background enhance: {options["background_enhance"]}')
        logging.debug(f'Face Restoration: {options["face_restore"]}')
        logging.debug(f'Face Upsampling: {options["face_upsample"]}')
        logging.debug(f'Upscale: {options["upscale"]}')
        logging.debug(f'Codeformer Fidelity: {options["codeformer_fidelity"]}')
        logging.debug(f'Output Format: {options["output_format"]}')

        result_image = face_swap(source_image_path, target_image_path, **options)

        status_code = 200
        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        logging.exception(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }
        status_code = 500
    finally:
        # Clean up temporary images, even when writing or swapping failed part-way
        _remove_if_exists(source_image_path)
        _remove_if_exists(target_image_path)

    total_time = total_timer.get_elapsed_time()
    logging.info(f'Total time taken for face swap API call {total_time} seconds')

    return make_response(jsonify(response), status_code)


if __name__ == '__main__':
    args = get_args()

    serve(
        app,
        host=args.host,
        port=args.port
    )