|
import base64
import contextlib
import json
import mimetypes
import os
import tempfile
import time
import uuid

import gradio as gr
from google import genai
from google.genai import types
from PIL import Image, ImageDraw, ImageFont
|
|
|
def save_binary_file(file_name, data):
    """Write the bytes in ``data`` to ``file_name``, replacing any existing content."""
    with open(file_name, mode="wb") as out_file:
        out_file.write(data)
|
|
|
def generate(text, file_name, api_key, model="gemini-2.0-flash-exp"):
    """Send an image plus a text prompt to Gemini and collect the first image it returns.

    The file at ``file_name`` is uploaded through the Files API, then the
    prompt is streamed to ``model``. Streaming stops at the first chunk that
    carries inline image data; that image is written to a temporary ``.png``.

    Args:
        text: Prompt text sent alongside the uploaded file.
        file_name: Path of the local file to upload.
        api_key: Gemini API key. If empty or whitespace-only, the
            ``GEMINI_API_KEY`` environment variable is used instead.
        model: Name of the Gemini model to call.

    Returns:
        A tuple ``(image_path, text_response)``. ``image_path`` is the path of
        the temporary PNG (the caller owns and must delete it) or ``None`` if
        no image was produced. ``text_response`` holds any text chunks
        received, each followed by a newline. If the upload or request setup
        fails, ``(None, "Error: <message>")`` is returned.
    """
    explicit_key = api_key.strip() if api_key else ""
    client = genai.Client(api_key=explicit_key or os.environ.get("GEMINI_API_KEY"))

    uploaded_file = None
    temp_path = None
    image_path = None
    text_chunks = []

    try:
        print("Uploading file to Gemini API...")
        uploaded_file = client.files.upload(file=file_name)

        contents = [
            types.Content(
                role="user",
                parts=[
                    types.Part.from_uri(
                        file_uri=uploaded_file.uri,
                        mime_type=uploaded_file.mime_type,
                    ),
                    types.Part.from_text(text=text),
                ],
            ),
        ]
        generate_content_config = types.GenerateContentConfig(
            temperature=0,
            top_p=0.92,
            max_output_tokens=8192,
            response_modalities=["image", "text"],
            response_mime_type="text/plain",
            safety_settings=[
                {
                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
                    "threshold": "BLOCK_MEDIUM_AND_ABOVE",
                }
            ],
        )

        # Reserve a path only; the handle is closed before anything is written.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            temp_path = tmp.name
        print("Sending request to Gemini API...")

        start_time = time.time()
        max_wait_time = 60

        try:
            stream = client.models.generate_content_stream(
                model=model,
                contents=contents,
                config=generate_content_config,
            )

            for chunk in stream:
                # Soft timeout: only evaluated when a chunk arrives, so a
                # stalled connection is not interrupted by this check.
                if time.time() - start_time > max_wait_time:
                    print("Gemini API request timed out after", max_wait_time, "seconds")
                    break

                if (
                    not chunk.candidates
                    or not chunk.candidates[0].content
                    or not chunk.candidates[0].content.parts
                ):
                    continue

                # The image is not guaranteed to be the first part of a chunk.
                image_part = next(
                    (part for part in chunk.candidates[0].content.parts if part.inline_data),
                    None,
                )
                if image_part is not None:
                    save_binary_file(temp_path, image_part.inline_data.data)
                    print(f"Smile enhancement image generated: {temp_path}")
                    image_path = temp_path
                    break

                # chunk.text can be None for chunks without text parts.
                if chunk.text:
                    text_chunks.append(chunk.text)
                    print("Received text response from Gemini API")
        except Exception as e:
            # Best effort: keep whatever was received before the failure.
            print(f"Error during content generation: {str(e)}")

    except Exception as e:
        print(f"Error in Gemini API setup: {str(e)}")
        return None, f"Error: {str(e)}"
    finally:
        # Do not leave an empty temp file behind when no image was written.
        if image_path is None and temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)
        # Remove the uploaded copy from the Files API (best effort).
        if uploaded_file is not None:
            try:
                client.files.delete(name=uploaded_file.name)
            except Exception as cleanup_error:
                print(f"Could not delete uploaded file: {str(cleanup_error)}")

    text_response = "".join(piece + "\n" for piece in text_chunks)
    return image_path, text_response
|
|
|
def assess_image_quality(original_image, enhanced_image, *, min_dimension=100,
                         max_dimension_diff=20):
    """Run basic sanity checks on an enhanced image against its original.

    Both images only need ``size`` (``(width, height)``) and ``mode``
    attributes, as PIL images provide. Checks run in this order and the first
    failure is reported:

    1. an enhanced image exists,
    2. neither of its dimensions is below ``min_dimension``,
    3. width and height each differ from the original by at most
       ``max_dimension_diff`` pixels,
    4. its colour mode is ``'RGB'``.

    Args:
        original_image: The source image.
        enhanced_image: The generated image, or ``None`` if none was produced.
        min_dimension: Smallest accepted width/height in pixels.
        max_dimension_diff: Largest accepted per-axis size difference in pixels.

    Returns:
        A tuple ``(is_acceptable, feedback_message)``. Any exception raised
        during the checks is reported as ``(False, "Assessment error: ...")``
        rather than propagated.
    """
    try:
        if enhanced_image is None:
            return False, "No enhanced image generated"

        enhanced_width, enhanced_height = enhanced_image.size[0], enhanced_image.size[1]
        if enhanced_width < min_dimension or enhanced_height < min_dimension:
            return False, "Enhanced image appears to be too small or improperly sized"

        width_diff = abs(original_image.size[0] - enhanced_width)
        height_diff = abs(original_image.size[1] - enhanced_height)
        if width_diff > max_dimension_diff or height_diff > max_dimension_diff:
            return False, "Enhanced image dimensions differ significantly from original, suggesting facial proportions may have changed"

        if enhanced_image.mode != 'RGB':
            return False, "Enhanced image does not have the correct color mode"

        return True, "Image passes quality assessment criteria"
    except Exception as e:
        print(f"Error in quality assessment: {str(e)}")
        return False, f"Assessment error: {str(e)}"
|
|
|
def compare_image_results(results_list):
    """Pick one image out of several generation results.

    ``None`` entries are ignored. No actual quality comparison is performed:
    the most recent (last) non-None entry is returned. Returns ``None`` when
    the list is empty or contains no usable entry.
    """
    if not results_list:
        return None

    usable = [candidate for candidate in results_list if candidate is not None]

    # The log line is only emitted when there is a real choice to make.
    if len(usable) > 1:
        print(f"Comparing {len(usable)} valid results and selecting best one")

    return usable[-1] if usable else None
|
|
|
# Base instruction sent to the model on every attempt.
_SMILE_ENHANCEMENT_PROMPT = """
Create a naturally enhanced smile that focuses primarily on the overall facial expression rather than perfect teeth. Make the following personalized improvements:

- Focus on enhancing the OVERALL SMILE EXPRESSION with natural eye crinkles, cheeks, and subtle facial changes
- Create authentic "Duchenne smile" characteristics with proper eye corner crinkles (crow's feet) appropriate for this person's age
- Enhance the natural lifting of cheeks that occurs in genuine smiles WITHOUT widening the face
- Add the characteristic slight narrowing of the eyes that happens in genuine smiles
- Create subtle dimples ONLY if they already exist in the original image
- Boost the overall joyful expression while maintaining the person's unique facial structure
- Maintain natural-looking nasolabial folds (smile lines) consistent with the smile intensity
- Subtly complement existing teeth - they should remain natural looking with their original character

IMPORTANT GUIDELINES:
- FOCUS ON THE SMILE AS A COMPLETE FACIAL EXPRESSION - not just teeth
- The most important aspects are eye crinkles, cheek raising, and natural facial expressions
- Teeth should be subtly complemented but NOT the main focus of the enhancement
- PRESERVE THE PERSON'S NATURAL DENTAL CHARACTERISTICS - teeth should look like THEIR teeth
- Keep teeth coloration natural and appropriate for the person - avoid any artificial whitening
- Maintain all natural imperfections in tooth alignment that give character to the smile
- Create a genuine, authentic-looking smile that affects the entire face naturally
- ABSOLUTELY CRITICAL: DO NOT widen the face or change face width/shape at all
- Preserve the person's identity completely (extremely important)
- Preserve exact facial proportions of the original image
- Maintain natural-looking results appropriate for the person's age and face structure
- Keep the original background, lighting, and image quality intact
- Ensure the enhanced smile looks natural, genuine, and believable
- Create a smile that looks like a moment of true happiness for THIS specific person
"""


def _build_smile_prompt(feedback_history):
    """Return the base prompt, extended with earlier feedback when there is any."""
    if not feedback_history:
        return _SMILE_ENHANCEMENT_PROMPT
    return (
        _SMILE_ENHANCEMENT_PROMPT
        + "\n\nIMPORTANT FEEDBACK FROM PREVIOUS ATTEMPT:\n"
        + " ".join(feedback_history)
        + "\nPlease address these issues in this new attempt.\n"
    )


def _load_generated_image(image_path):
    """Load a generated image fully into memory, then delete its temp file."""
    try:
        # Image.open is lazy; convert()/copy() force the pixel data to be read
        # so the file can be removed safely afterwards.
        with Image.open(image_path) as raw:
            return raw.convert("RGB") if raw.mode == "RGBA" else raw.copy()
    finally:
        with contextlib.suppress(OSError):
            os.remove(image_path)


def process_smile_enhancement(input_image, max_attempts=3):
    """Generate a smile-enhanced version of ``input_image`` via Gemini.

    Up to ``max_attempts`` generation requests are made within an overall
    budget of 150 seconds. Feedback from earlier attempts is appended to the
    prompt of later ones. Results that pass ``assess_image_quality`` are
    preferred; results that fail it are only used when nothing passed.

    The API key is read from the ``GEMINI_API_KEY`` environment variable.
    NOTE: a key used to be hard-coded here; treat that key as compromised and
    rotate it.

    Args:
        input_image: PIL image supplied by the UI, or ``None``.
        max_attempts: Maximum number of generation attempts.

    Returns:
        A tuple ``(gallery_images, output_text, status_message)``.
        ``gallery_images`` is ``None`` when no input was given, otherwise a
        one-element list holding the chosen result or, as a fallback, the
        original image. ``output_text`` is always an empty string.
    """
    input_path = None
    try:
        if input_image is None:
            return None, "", ""

        gemini_api_key = (os.environ.get("GEMINI_API_KEY") or "").strip()
        if not gemini_api_key:
            print("Error: GEMINI_API_KEY not found in environment variables")
            return [input_image], "", "API key not configured"

        # Reserve a path, close the handle, then let PIL write to it.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            input_path = tmp.name
        input_image.save(input_path)
        print(f"Input image saved to {input_path}")

        acceptable_images = []  # passed the quality check
        fallback_images = []    # failed the check; used only if nothing passed
        feedback_history = []
        max_processing_time = 150
        start_processing_time = time.time()

        for current_attempt in range(1, max_attempts + 1):
            if time.time() - start_processing_time > max_processing_time:
                print(f"Overall processing time exceeded {max_processing_time} seconds")
                break

            print(f"Starting processing attempt {current_attempt}/{max_attempts}...")
            prompt = _build_smile_prompt(feedback_history)
            print(f"Processing attempt {current_attempt}/{max_attempts}...")

            api_call_timeout = time.time() + 45

            try:
                image_path, text_response = generate(
                    text=prompt, file_name=input_path, api_key=gemini_api_key
                )

                # Results that arrive after the deadline are discarded.
                if time.time() > api_call_timeout:
                    print("API call timeout occurred")
                    feedback_history.append("API call timed out, trying again with simplified request.")
                    if image_path:
                        with contextlib.suppress(OSError):
                            os.remove(image_path)
                    continue

                print(f"API response received: Image path: {image_path is not None}, Text length: {len(text_response)}")

                if not image_path:
                    print("No image was generated, only text response")
                    feedback_history.append("No image was generated in the previous attempt.")
                    continue

                try:
                    current_result = _load_generated_image(image_path)
                    print("Successfully loaded generated image for attempt " + str(current_attempt))

                    is_acceptable, assessment_feedback = assess_image_quality(input_image, current_result)
                    print(f"Image quality assessment: {is_acceptable}, {assessment_feedback}")

                    if is_acceptable:
                        acceptable_images.append(current_result)
                        print(f"Added acceptable result from attempt {current_attempt} to results list")
                        if current_attempt < max_attempts:
                            feedback_history.append("Previous attempt successful, trying to further improve...")
                    else:
                        feedback_history.append(assessment_feedback)
                        fallback_images.append(current_result)
                except Exception as img_error:
                    print(f"Error processing the generated image: {str(img_error)}")
                    feedback_history.append(f"Error with image: {str(img_error)}")
            except Exception as gen_error:
                print(f"Error during generation attempt {current_attempt}: {str(gen_error)}")
                feedback_history.append(f"Error during processing: {str(gen_error)}")

        # A rejected image must never win over one that passed the checks.
        candidates = acceptable_images or fallback_images
        print(f"All attempts completed. Comparing {len(candidates)} results")

        if candidates:
            best_result = compare_image_results(candidates)
            if best_result is not None:
                print("Returning best result from multiple attempts")
                success_message = "Enhancement completed after multiple attempts to find the best result"
                return [best_result], "", success_message

        print("Returning original image as fallback - no valid results generated")
        return [input_image], "", "No satisfactory enhancements could be generated"
    except Exception as e:
        print(f"Overall error in process_smile_enhancement: {str(e)}")
        # Surface the failure in the status box instead of returning silence.
        return [input_image], "", f"Processing error: {str(e)}"
    finally:
        if input_path is not None:
            with contextlib.suppress(OSError):
                os.remove(input_path)
|
|
|
|
|
# Gradio front end: upload on the left, results on the right.
# NOTE(review): indentation was lost in the source, so the nesting of the two
# textboxes inside the second column is assumed - confirm against the
# deployed layout.
with gr.Blocks(
    title="Smile Enhancement",
    css="footer {visibility: hidden} .gradio-container {min-height: 0 !important}",
) as demo:
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(
                type="pil",
                label=None,
                image_mode="RGB",
                elem_classes="upload-box",
            )
            submit_btn = gr.Button(
                "Enhance Smile with Natural Expressions",
                elem_classes="generate-btn",
            )

        with gr.Column():
            output_gallery = gr.Gallery(label=None)
            # Status line shown to the user.
            feedback_text = gr.Textbox(
                label=None,
                visible=True,
                elem_classes="status-box",
            )
            # Hidden sink for the second value returned by the handler.
            output_text = gr.Textbox(visible=False)

    # Handler returns (gallery images, text, status) in this order.
    submit_btn.click(
        fn=process_smile_enhancement,
        inputs=[image_input],
        outputs=[output_gallery, output_text, feedback_text],
    )

# Queue up to 50 pending requests and listen on all interfaces.
demo.queue(max_size=50).launch(
    show_api=False,
    share=False,
    show_error=True,
    server_name="0.0.0.0",
)