"""Gradio demo app for FaceEnhance: enhance a face in a target image using a reference face."""
import contextlib
import hashlib
import io
import os
import pickle
import sys
import tempfile

from install import install

if "HF_DEMO" in os.environ:
    # Only when deploying on a Hugging Face Space: run the dependency installer
    # before the third-party imports below. (The previous INSTALLED flag was
    # always False at this point, so it never skipped anything.)
    install()

import gradio as gr  # noqa: E402  (must come after install())
# NOTE: `test` is the project's local test.py; it shadows the stdlib `test`
# package because the script directory comes first on sys.path.
from test import process_face  # noqa: E402
from PIL import Image  # noqa: E402

INPUT_CACHE_DIR = "./cache"
os.makedirs(INPUT_CACHE_DIR, exist_ok=True)


def get_image_hash(img):
    """Return the MD5 hex digest of *img* re-encoded as PNG.

    Used only as a cache key, not for any security purpose.
    """
    img_bytes = io.BytesIO()
    img.save(img_bytes, format="PNG")
    return hashlib.md5(img_bytes.getvalue()).hexdigest()


def _new_temp_png_path():
    """Create an empty, closed temporary .png file and return its path (caller deletes it)."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        return tmp.name


def _load_cached_result(cache_path):
    """Return the cached PIL image stored at *cache_path*, or None on a cache miss."""
    if not os.path.exists(cache_path):
        return None
    try:
        with open(cache_path, "rb") as f:
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # acceptable only because INPUT_CACHE_DIR is written solely by this
            # app; never point it at files from an untrusted source.
            result_img = pickle.load(f)
    except Exception as e:  # noqa: BLE001 - best-effort cache: any failure is a miss
        # Truncated/corrupt files raise EOFError, UnpicklingError, etc. Drop the
        # bad entry so the next request does not hit it again.
        print(f"Error loading from cache: {e}")
        with contextlib.suppress(OSError):
            os.remove(cache_path)
        return None
    if not isinstance(result_img, Image.Image):
        print(f"Error loading from cache: unexpected object in {cache_path}")
        return None
    return result_img


def _save_cached_result(cache_path, result_img):
    """Atomically pickle *result_img* to *cache_path*; return True on success (best effort)."""
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(
            dir=os.path.dirname(cache_path) or ".", suffix=".tmp", delete=False
        ) as tmp:
            tmp_path = tmp.name
            pickle.dump(result_img, tmp)
        # Rename into place so a crash mid-write never leaves a truncated .pkl.
        os.replace(tmp_path, cache_path)
        return True
    except (pickle.PickleError, OSError) as e:
        print(f"Error caching result: {e}")
        if tmp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
        return False


def enhance_face_gradio(input_image, ref_image):
    """
    Wrapper function for process_face that works with Gradio.

    Results are cached on disk under INPUT_CACHE_DIR, keyed by the content
    hashes of both images, so repeated requests return immediately.

    Args:
        input_image: Input (target) PIL image from Gradio, or None if not uploaded
        ref_image: Reference face PIL image from Gradio, or None if not uploaded

    Returns:
        PIL Image: Enhanced image

    Raises:
        gr.Error: If an image is missing or processing fails; Gradio shows
            the message to the user.
    """
    if input_image is None or ref_image is None:
        raise gr.Error("Please upload both a target image and a reference face image.")

    combined_hash = f"{get_image_hash(input_image)}_{get_image_hash(ref_image)}"
    cache_path = os.path.join(INPUT_CACHE_DIR, f"{combined_hash}.pkl")

    cached_img = _load_cached_result(cache_path)
    if cached_img is not None:
        print(f"Returning cached result for images with hash {combined_hash}")
        return cached_img

    # process_face works on file paths, so round-trip through closed temp files
    # (writing into still-open temp files fails on Windows).
    input_path = _new_temp_png_path()
    ref_path = _new_temp_png_path()
    output_path = _new_temp_png_path()
    try:
        input_image.save(input_path)
        ref_image.save(ref_path)
        try:
            process_face(
                input_path=input_path,
                ref_path=ref_path,
                output_path=output_path,
            )
            with Image.open(output_path) as opened:
                # copy() forces a full load so the file can be deleted below.
                result_img = opened.copy()
        except Exception as e:
            # Log the details server-side; show a generic message in the UI.
            # (Returning a string to a gr.Image output would not display it.)
            print(f"Error processing face: {e}")
            raise gr.Error(
                "An error occurred while processing the face. Please try again."
            ) from e
    finally:
        # Remove all temporary files, including the output, on every path.
        for path in (input_path, ref_path, output_path):
            with contextlib.suppress(OSError):
                os.unlink(path)

    if _save_cached_result(cache_path, result_img):
        print(f"Cached result for images with hash {combined_hash}")

    return result_img


def create_gradio_interface():
    """Build the Face Enhancement Blocks UI and launch it (blocks until the server stops).

    Exits the process with status 1 if the server fails to start.
    """
    with gr.Blocks(title="Face Enhancement") as demo:
        gr.Markdown("""
        # Face Enhance
        ### Instructions
        1. Upload the target image you want to enhance
        2. Upload a high-quality face image
        3. Click 'Enhance Face'

        Processing takes around 30 seconds.
        """, elem_id="instructions")

        gr.Markdown("---")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Target Image", type="pil")
                ref_image = gr.Image(label="Reference Face", type="pil")
                enhance_button = gr.Button("Enhance Face")

            with gr.Column():
                output_image = gr.Image(label="Enhanced Result")

        enhance_button.click(
            fn=enhance_face_gradio,
            inputs=[input_image, ref_image],
            outputs=output_image,
            queue=True,  # Enable queue for sequential processing
        )

        gr.Markdown("""
        ## Examples
        Click on an example to load the images into the interface.
        """)

        example_inps = [
            ["examples/dany_gpt_1.png", "examples/dany_face.jpg"],
            ["examples/dany_gpt_2.png", "examples/dany_face.jpg"],
            ["examples/tim_gpt_1.png", "examples/tim_face.jpg"],
            ["examples/tim_gpt_2.png", "examples/tim_face.jpg"],
            ["examples/elon_gpt.png", "examples/elon_face.png"],
        ]
        gr.Examples(examples=example_inps, inputs=[input_image, ref_image], outputs=output_image)

        gr.Markdown("""
        ## Notes
        Check out the code [here](https://github.com/RishiDesai/FaceEnhance) and see my [blog post](https://rishidesai.github.io/posts/face-enhancement-techniques/) for more information.

        Due to the constraints of this demo, face cropping and upscaling are not applied to the reference image.
        """)

    # Launch the Gradio app with queue
    demo.queue(max_size=99)
    try:
        demo.launch()
    except OSError as e:
        print(f"Error starting server: {e}")
        sys.exit(1)


if __name__ == "__main__":
    create_gradio_interface()