# Page residue from the hosting web UI, kept as comments so the module parses:
# API / app.py
# ar08's picture
# Update app.py
# 1f6f4e8 verified
# raw
# history blame
# 17.9 kB
#!/usr/bin/env python3
import json
import sys
import os
import io
import argparse
import os
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
import uuid
import base64
import logging
import time
import copy
import cv2
import insightface
import numpy as np
from typing import List, Union
from PIL import Image
from restoration import *
from flask import Flask, request, jsonify, make_response
from waitress import serve
LOG_LEVEL = logging.DEBUG
TMP_PATH = '/tmp/inswapper'
script_dir = os.path.dirname(os.path.abspath(__file__))
log_path = ''
# Mac does not have permission to /var/log for example. On Linux, /var/log is
# typically root-owned as well, so only log there when it is actually writable;
# otherwise basicConfig's FileHandler would raise PermissionError at import time.
if sys.platform == 'linux' and os.access('/var/log', os.W_OK):
    log_path = '/var/log/'
logging.basicConfig(
    filename=f'{log_path}inswapper.log',
    format='%(asctime)s : %(levelname)s : %(message)s',
    level=LOG_LEVEL
)
# Also echo every log record to stdout.
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
def process_request(request_obj):
    """Run ``face_swap`` for a request mapping and wrap the outcome in a status dict.

    ``request_obj`` must contain ``source_image`` and ``target_image``. They are
    passed to ``face_swap`` as its image-path arguments, which it opens with
    ``PIL.Image.open`` (presumably file paths here — confirm against callers).
    The optional keys ``source_indexes``, ``target_indexes``,
    ``background_enhance``, ``face_restore``, ``face_upsample``, ``upscale``,
    ``codeformer_fidelity`` and ``output_format`` fall back to the same
    defaults the ``/faceswap`` endpoint uses.

    Returns:
        ``{'status': 'ok', 'image': <base64 str>}`` on success, otherwise
        ``{'status': 'error', 'msg': 'Face swap failed', 'detail': <error text>}``
        (any exception, including a missing required key, is reported this way).
    """
    # face_swap takes ten required arguments, so the optional ones must be
    # supplied here; the keys match face_swap's parameter names.
    defaults = {
        'source_indexes': '-1',
        'target_indexes': '-1',
        'background_enhance': True,
        'face_restore': True,
        'face_upsample': True,
        'upscale': 1,
        'codeformer_fidelity': 0.5,
        'output_format': 'JPEG',
    }
    try:
        options = {
            key: request_obj[key] if key in request_obj else default
            for key, default in defaults.items()
        }
        logging.debug('Swapping face')
        face_swap_timer = Timer()
        result_image = face_swap(
            request_obj['source_image'],
            request_obj['target_image'],
            **options
        )
        face_swap_time = face_swap_timer.get_elapsed_time()
        logging.info(f'Time taken to swap face: {face_swap_time} seconds')
        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        logging.error(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }
    return response
class Timer:
    """Measure elapsed time since construction or the last ``restart()``.

    Uses ``time.perf_counter()``, a monotonic clock, so measurements are not
    affected by system clock adjustments the way ``time.time()`` is.
    """

    def __init__(self):
        # Reference point for get_elapsed_time(); only meaningful as a difference.
        self.start = time.perf_counter()

    def restart(self):
        """Reset the reference point to now."""
        self.start = time.perf_counter()

    def get_elapsed_time(self):
        """Return the seconds elapsed since the reference point, rounded to 0.1 s."""
        end = time.perf_counter()
        return round(end - self.start, 1)
def get_args(argv=None):
    """Parse the command-line options of the REST server.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``port`` (int, default 80) and
        ``host`` (str, default '0.0.0.0').
    """
    parser = argparse.ArgumentParser(
        description='Inswapper REST API'
    )
    parser.add_argument(
        '-p', '--port',
        help='Port to listen on',
        type=int,
        default=80
    )
    parser.add_argument(
        '-H', '--host',
        help='Host to bind to',
        default='0.0.0.0'
    )
    return parser.parse_args(argv)
def determine_file_extension(image_data):
    """Guess an image file extension from the start of base64-encoded image data.

    Base64-encoded JPEG data begins with ``/9j/``; PNG data begins with
    ``iVBORw0Kg``. Base64 text may be given as ``str`` or as ASCII
    ``bytes``/``bytearray``. Anything unrecognised, or input that is not text at
    all, falls back to ``.png``.

    Returns:
        ``'.jpg'`` or ``'.png'``.
    """
    if isinstance(image_data, (bytes, bytearray)):
        # Base64 is pure ASCII; stray non-ASCII bytes become U+FFFD and cannot
        # accidentally match a signature.
        image_data = image_data.decode('ascii', errors='replace')
    if not isinstance(image_data, str):
        # Unknown input type: keep the historical best-effort default.
        return '.png'
    if image_data.startswith('/9j/'):
        return '.jpg'
    # PNG ('iVBORw0Kg') and unrecognised formats are both saved as .png.
    return '.png'
def write_base64_to_disk(file_b64: str, file_path: str):
    """Decode base64 text and write the resulting bytes to *file_path*.

    The data is decoded before the file is opened, so malformed input (e.g.
    incorrect padding, which raises ``binascii.Error``) neither creates nor
    truncates *file_path*.
    """
    file_bytes = base64.b64decode(file_b64)
    with open(file_path, 'wb') as out_file:
        out_file.write(file_bytes)
def get_face_swap_model(model_path: str):
    """Load and return the face swap model at *model_path* via insightface's model zoo."""
    return insightface.model_zoo.get_model(model_path)
def get_face_analyser(model_path: str,
                      det_size=(320, 320)):
    """Create and prepare an insightface ``FaceAnalysis`` using the "buffalo_l" models.

    Args:
        model_path: Accepted for call compatibility; it is not used here.
        det_size: Detection input size handed to ``prepare``.

    Note:
        The models are looked up under the relative path ``./checkpoints``,
        presumably resolved against the current working directory rather than
        this file's directory — verify if the server is started elsewhere.
    """
    analyser = insightface.app.FaceAnalysis(root="./checkpoints", name="buffalo_l")
    analyser.prepare(det_size=det_size, ctx_id=0)
    return analyser
def get_one_face(face_analyser,
                 frame: np.ndarray):
    """Return the leftmost face detected in *frame* (smallest ``bbox[0]``).

    Returns None when ``min`` raises ValueError, i.e. when no face was detected.
    """
    detections = face_analyser.get(frame)
    try:
        leftmost = min(detections, key=lambda det: det.bbox[0])
    except ValueError:
        leftmost = None
    return leftmost
def get_many_faces(face_analyser,
                   frame: np.ndarray):
    """Detect every face in *frame* and return them ordered left to right.

    Faces are sorted by ``bbox[0]``. A frame without faces yields an empty
    list; None is returned only if an IndexError is raised while detecting
    or reading ``bbox[0]``.
    """
    try:
        detections = face_analyser.get(frame)
        ordered = sorted(detections, key=lambda det: det.bbox[0])
    except IndexError:
        return None
    return ordered
def swap_face(face_swapper,
              source_faces,
              target_faces,
              source_index,
              target_index,
              temp_frame):
    """Paste ``source_faces[source_index]`` over ``target_faces[target_index]``.

    Delegates to ``face_swapper.get(temp_frame, target, source, paste_back=True)``
    and returns its result (presumably the frame with the face pasted back in).
    """
    donor = source_faces[source_index]
    recipient = target_faces[target_index]
    return face_swapper.get(temp_frame, recipient, donor, paste_back=True)
def _pil_to_bgr(image):
    """Convert an RGB PIL image into the BGR numpy array this module feeds to the face analyser."""
    return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)


def _swap_by_image_order(face_analyser, face_swapper, source_imgs, target_faces, frame):
    """Paste the leftmost face of ``source_imgs[i]`` onto ``target_faces[i]`` for every i."""
    logging.debug('Replacing the faces in the target image from left to right by order')
    for i, source_image in enumerate(source_imgs):
        source_faces = get_many_faces(face_analyser, _pil_to_bgr(source_image))
        if not source_faces:
            raise Exception('No source faces found!')
        # Each source image contributes one face (its leftmost). Indexing face i
        # of image i would raise IndexError for the usual one-face-per-image input.
        frame = swap_face(face_swapper, source_faces, target_faces, 0, i, frame)
    return frame


def _swap_from_single_image(face_swapper, source_faces, target_faces,
                            source_indexes, target_indexes, frame):
    """Swap faces detected in one source image into the target frame.

    ``source_indexes`` / ``target_indexes`` are comma-separated positions in the
    left-to-right face order, or "-1" meaning "all faces".

    Raises:
        Exception: if more indexes are requested than faces exist, or the
            source/target index counts cannot be paired.
        ValueError: if an index is not an integer or is beyond the last face.
    """
    num_source_faces = len(source_faces)
    num_target_faces = len(target_faces)

    if target_indexes == '-1':
        if num_source_faces == 1:
            logging.debug('Replacing all faces in target image with the same face from the source image')
        elif num_source_faces < num_target_faces:
            logging.debug('There are less faces in the source image than the target image, replacing as many as we can')
        elif num_target_faces < num_source_faces:
            logging.debug('There are less faces in the target image than the source image, replacing as many as we can')
        else:
            logging.debug('Replacing all faces in the target image with the faces from the source image')
        # A lone source face is reused for every target face; otherwise faces
        # are paired left to right until either image runs out.
        if num_source_faces == 1:
            num_iterations = num_target_faces
        else:
            num_iterations = min(num_source_faces, num_target_faces)
        for i in range(num_iterations):
            source_index = 0 if num_source_faces == 1 else i
            frame = swap_face(face_swapper, source_faces, target_faces, source_index, i, frame)
        return frame

    target_list = target_indexes.split(',')

    if source_indexes == '-1' and num_source_faces == 1:
        # The only source face is pasted onto each requested target face,
        # mirroring the "all targets" behaviour above for a single source face.
        logging.debug('Replacing specific face(s) in the target image with the face from the source image')
        if len(target_list) > num_target_faces:
            raise Exception('Number of target indexes is greater than the number of faces in the target image')
        for raw_target in target_list:
            target_index = int(raw_target)
            if target_index > num_target_faces - 1:
                raise ValueError(f'Target index {target_index} is higher than the number of faces in the target image')
            frame = swap_face(face_swapper, source_faces, target_faces, 0, target_index, frame)
        return frame

    logging.debug('Replacing specific face(s) in the target image with specific face(s) from the source image')
    if source_indexes == '-1':
        source_indexes = ','.join(str(x) for x in range(num_source_faces))
    source_list = source_indexes.split(',')
    if len(source_list) > num_source_faces:
        raise Exception('Number of source indexes is greater than the number of faces in the source image')
    if len(target_list) > num_target_faces:
        raise Exception('Number of target indexes is greater than the number of faces in the target image')
    if len(source_list) != len(target_list):
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')
    for raw_source, raw_target in zip(source_list, target_list):
        source_index = int(raw_source)
        target_index = int(raw_target)
        if source_index > num_source_faces - 1:
            raise ValueError(f'Source index {source_index} is higher than the number of faces in the source image')
        if target_index > num_target_faces - 1:
            raise ValueError(f'Target index {target_index} is higher than the number of faces in the target image')
        frame = swap_face(face_swapper, source_faces, target_faces, source_index, target_index, frame)
    return frame


def process(source_img: Union[Image.Image, List],
            target_img: Image.Image,
            source_indexes: str,
            target_indexes: str,
            model: str):
    """Swap faces from the source image(s) into *target_img*.

    Args:
        source_img: A PIL image, or a list of PIL images, supplying the faces.
            When the list length equals the number of target faces, image i
            provides the face for target face i (left-to-right order).
            Otherwise a single source image is required.
        target_img: PIL image whose faces are replaced.
        source_indexes: Comma-separated source face indexes, or "-1" for all.
        target_indexes: Comma-separated target face indexes, or "-1" for all.
        model: Face swap model path, joined onto this file's directory (an
            absolute path is used unchanged); also passed to get_face_analyser.

    Returns:
        The swapped result as an RGB PIL image.

    Raises:
        Exception: if no target/source faces are detected or the face
            configuration is unsupported.
        ValueError: for non-integer or out-of-range indexes.
    """
    # A single image is accepted as well as a list, as the type hint advertises.
    if not isinstance(source_img, list):
        source_img = [source_img]

    # load face_analyser
    face_analyser = get_face_analyser(model)
    # load face_swapper
    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = get_face_swap_model(model_path)

    # detect faces that will be replaced in target_img
    target_bgr = _pil_to_bgr(target_img)
    target_faces = get_many_faces(face_analyser, target_bgr)
    # get_many_faces returns [] (not None) when nothing is detected.
    if not target_faces:
        logging.error('No target faces found')
        raise Exception('No target faces found!')

    num_target_faces = len(target_faces)
    num_source_images = len(source_img)
    temp_frame = target_bgr.copy()

    if num_source_images == num_target_faces:
        temp_frame = _swap_by_image_order(
            face_analyser, face_swapper, source_img, target_faces, temp_frame
        )
    elif num_source_images == 1:
        # detect source faces that will be replaced into the target image
        source_faces = get_many_faces(face_analyser, _pil_to_bgr(source_img[0]))
        num_source_faces = len(source_faces) if source_faces else 0
        logging.debug(f'Source faces: {num_source_faces}')
        logging.debug(f'Target faces: {num_target_faces}')
        if not source_faces:
            raise Exception('No source faces found!')
        temp_frame = _swap_from_single_image(
            face_swapper, source_faces, target_faces,
            source_indexes, target_indexes, temp_frame
        )
    else:
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')

    return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
def _build_codeformer_net(device):
    """Instantiate CodeFormer on *device*, load its weights from script_dir and set eval mode."""
    codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
        dim_embd=512,
        codebook_size=1024,
        n_head=8,
        n_layers=9,
        connect_list=['32', '64', '128', '256'],
    ).to(device)
    ckpt_path = os.path.join(script_dir, 'CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth')
    logging.debug(f'Loading CodeFormer model: {ckpt_path}')
    # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
    checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
    codeformer_net.load_state_dict(checkpoint)
    codeformer_net.eval()
    return codeformer_net


def face_swap(src_img_path,
              target_img_path,
              source_indexes,
              target_indexes,
              background_enhance,
              face_restore,
              face_upsample,
              upscale,
              codeformer_fidelity,
              output_format):
    """Swap faces between image files, optionally restore them, and return base64 output.

    Args:
        src_img_path: One or more source image paths separated by ';'.
        target_img_path: Path of the image whose faces are replaced.
        source_indexes, target_indexes: Face index strings forwarded to ``process``.
        background_enhance, face_upsample, upscale, codeformer_fidelity:
            Forwarded to ``face_restoration``; only used when *face_restore*
            is truthy.
        face_restore: Run CodeFormer face restoration after the swap.
        output_format: PIL format name used to encode the result (e.g. 'JPEG').

    Returns:
        The encoded result image as a base64 string.

    Raises:
        Any exception from image loading, ``process`` or restoration propagates.
    """
    # Images are converted to RGB so grayscale, palette or CMYK files still give
    # the 3-channel arrays the cv2 RGB<->BGR conversions expect; convert() loads
    # the pixels, so the files can be closed right away.
    source_img = []
    for img_path in src_img_path.split(';'):
        with Image.open(img_path) as opened:
            source_img.append(opened.convert('RGB'))
    with Image.open(target_img_path) as opened:
        target_img = opened.convert('RGB')

    # download from https://huggingface.co/ashleykleynhans/inswapper/tree/main
    model = os.path.join(script_dir, 'checkpoints/inswapper_128.onnx')
    logging.debug(f'Face swap model: {model}')
    logging.debug('Performing face swap')
    result_image = process(
        source_img,
        target_img,
        source_indexes,
        target_indexes,
        model
    )
    logging.debug('Face swap complete')

    # make sure the ckpts downloaded successfully
    check_ckpts()

    if face_restore:
        # https://huggingface.co/spaces/sczhou/CodeFormer
        logging.debug('Setting upsampler to RealESRGAN_x2plus')
        upsampler = set_realesrgan()
        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.debug(f'Torch device: {torch_device.upper()}')
        device = torch.device(torch_device)
        codeformer_net = _build_codeformer_net(device)
        result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
        logging.debug('Performing face restoration using CodeFormer')
        result_image = face_restoration(
            result_image,
            background_enhance,
            face_upsample,
            upscale,
            codeformer_fidelity,
            upsampler,
            codeformer_net,
            device
        )
        logging.debug('CodeFormer face restoration completed successfully')
        result_image = Image.fromarray(result_image)

    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format=output_format)
    return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
# Flask application exposing the REST endpoints registered below; it is served
# with waitress when this file is run as a script.
app = Flask(__name__)
@app.errorhandler(400)
def bad_request(error):
    """Render 400 Bad Request errors as this API's JSON error envelope.

    Named ``bad_request`` so it no longer shadows the 404 handler ``not_found``.
    """
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': 'Bad Request',
            'detail': str(error)
        }
    ), 400)
@app.errorhandler(404)
def not_found(error):
    """Render 404 errors as a JSON error envelope naming the requested URL."""
    body = {
        'status': 'error',
        'msg': f'{request.url} not found',
        'detail': str(error),
    }
    return make_response(jsonify(body), 404)
@app.errorhandler(500)
def internal_server_error(error):
    """Render 500 errors as a JSON error envelope carrying the error text."""
    body = {
        'status': 'error',
        'msg': 'Internal Server Error',
        'detail': str(error),
    }
    return make_response(jsonify(body), 500)
@app.route('/', methods=['GET'])
def ping():
    """Health check: always answers 200 with ``{"status": "ok"}``."""
    return make_response(jsonify(status='ok'), 200)
# Optional /faceswap request fields and the values used when a request omits them.
_FACE_SWAP_DEFAULTS = {
    'source_indexes': '-1',
    'target_indexes': '-1',
    'background_enhance': True,
    'face_restore': True,
    'face_upsample': True,
    'upscale': 1,
    'codeformer_fidelity': 0.5,
    'output_format': 'JPEG',
}


def _save_payload_image(image_data, prefix, unique_id):
    """Decode base64 *image_data*, write it under TMP_PATH and return the file path.

    Decoding happens before the file is opened, so bad input leaves nothing on disk.

    Raises:
        ValueError: (``binascii.Error``) if *image_data* is incorrectly padded base64.
        TypeError: if *image_data* is not a str or bytes-like object.
    """
    image_bytes = base64.b64decode(image_data)
    extension = determine_file_extension(image_data)
    image_path = f'{TMP_PATH}/{prefix}_{unique_id}{extension}'
    with open(image_path, 'wb') as image_file:
        image_file.write(image_bytes)
    return image_path


@app.route('/faceswap', methods=['POST'])
def face_swap_api():
    """Swap faces between two base64-encoded images posted as JSON.

    Required JSON fields are ``source_image`` and ``target_image`` (base64
    image data). Optional fields and their defaults are in ``_FACE_SWAP_DEFAULTS``.

    Returns:
        200 with ``{'status': 'ok', 'image': <base64>}`` on success, 400 for a
        malformed request body or image data, and 500 with an error envelope
        if the face swap itself fails.
    """
    total_timer = Timer()
    logging.debug('Received face swap API request')
    # silent=True: a missing or unparsable JSON body yields None and is
    # reported as a 400 below instead of escaping as an exception.
    payload = request.get_json(silent=True)
    if (not isinstance(payload, dict)
            or 'source_image' not in payload
            or 'target_image' not in payload):
        return make_response(jsonify({
            'status': 'error',
            'msg': 'Bad Request',
            'detail': 'A JSON body containing source_image and target_image is required'
        }), 400)

    if not os.path.exists(TMP_PATH):
        logging.debug(f'Creating temporary directory: {TMP_PATH}')
        # exist_ok: concurrent requests may race to create the directory.
        os.makedirs(TMP_PATH, exist_ok=True)

    unique_id = uuid.uuid4()
    temp_paths = []
    try:
        try:
            source_image_path = _save_payload_image(payload['source_image'], 'source', unique_id)
            temp_paths.append(source_image_path)
            target_image_path = _save_payload_image(payload['target_image'], 'target', unique_id)
            temp_paths.append(target_image_path)
        except (TypeError, ValueError) as e:
            logging.error(e)
            return make_response(jsonify({
                'status': 'error',
                'msg': 'Bad Request',
                'detail': f'Invalid base64 image data: {e}'
            }), 400)

        # Set defaults if they are not specified in the payload
        for key, default in _FACE_SWAP_DEFAULTS.items():
            payload.setdefault(key, default)

        try:
            logging.debug(f'Source indexes: {payload["source_indexes"]}')
            logging.debug(f'Target indexes: {payload["target_indexes"]}')
            logging.debug(f'Background enhance: {payload["background_enhance"]}')
            logging.debug(f'Face Restoration: {payload["face_restore"]}')
            logging.debug(f'Face Upsampling: {payload["face_upsample"]}')
            logging.debug(f'Upscale: {payload["upscale"]}')
            logging.debug(f'Codeformer Fidelity: {payload["codeformer_fidelity"]}')
            logging.debug(f'Output Format: {payload["output_format"]}')
            result_image = face_swap(
                source_image_path,
                target_image_path,
                payload['source_indexes'],
                payload['target_indexes'],
                payload['background_enhance'],
                payload['face_restore'],
                payload['face_upsample'],
                payload['upscale'],
                payload['codeformer_fidelity'],
                payload['output_format']
            )
            status_code = 200
            response = {
                'status': 'ok',
                'image': result_image
            }
        except Exception as e:
            logging.exception(e)
            response = {
                'status': 'error',
                'msg': 'Face swap failed',
                'detail': str(e)
            }
            status_code = 500
    finally:
        # Clean up temporary images, including when decoding or swapping failed.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError as e:
                logging.warning(f'Could not remove temporary file {path}: {e}')

    total_time = total_timer.get_elapsed_time()
    logging.info(f'Total time taken for face swap API call {total_time} seconds')
    return make_response(jsonify(response), status_code)
if __name__ == '__main__':
    # Read --host/--port from the command line and serve the app with waitress.
    cli_args = get_args()
    serve(app, host=cli_args.host, port=cli_args.port)