"""Gradio app: generate SDXL img2img marketing variations of a product photo.

The user uploads a product image and a text prompt; the planner helpers turn
the prompt into several enriched prompts plus a negative prompt, and each
enriched prompt is rendered with Stable Diffusion XL img2img.
"""

import json
import os
import zipfile
from datetime import datetime

import gradio as gr
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from PIL import Image

from utils.planner import (
    extract_scene_plan,
    generate_negative_prompt_from_scene,
    generate_prompt_variations_from_scene,
)

# ----------------------------
# 💻 Device Configuration
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision only on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# ----------------------------
# 🧠 Load Stable Diffusion XL Img2Img Pipeline
# ----------------------------
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None,
)
if device == "cuda":
    # Model CPU offload manages device placement itself. Moving the whole
    # pipeline to the GPU first would only create a transient full-model VRAM
    # peak, which is exactly what offloading is meant to avoid.
    pipe.enable_model_cpu_offload()
else:
    pipe.to(device)
pipe.enable_attention_slicing()


# ----------------------------
# 🎨 Core Generation Function
# ----------------------------
def _append_generation_log(log_data):
    """Append ``log_data`` as one JSON line to ``logs/generation_logs.jsonl``."""
    os.makedirs("logs", exist_ok=True)
    with open("logs/generation_logs.jsonl", "a", encoding="utf-8") as log_file:
        log_file.write(json.dumps(log_data) + "\n")


def _zip_images(out_dir, image_paths):
    """Bundle ``image_paths`` into ``<out_dir>/all_images.zip`` and return its path."""
    zip_path = f"{out_dir}/all_images.zip"
    with zipfile.ZipFile(zip_path, "w") as zipf:
        for img_path in image_paths:
            # Store files flat (no session directory prefix) inside the archive.
            zipf.write(img_path, os.path.basename(img_path))
    return zip_path


def process_image(prompt, image, num_variations):
    """Generate img2img variations of ``image`` guided by ``prompt``.

    Args:
        prompt: Free-text description of the desired visual.
        image: Uploaded product image as a PIL image, or ``None`` if nothing
            was uploaded.
        num_variations: Requested number of prompt variations (slider value).

    Returns:
        A ``(images, status_message, download_path)`` tuple. On success,
        ``images`` is the list of generated PIL images and ``download_path``
        is the single PNG (one image) or a ZIP of all PNGs (several images).
        On any failure, ``images`` is a single red placeholder image, the
        message describes the error, and ``download_path`` is ``None``.
    """
    try:
        if image is None:
            raise ValueError("🚫 Please upload an image.")

        # Slider values may arrive as floats; downstream code wants a count.
        num_variations = int(num_variations)

        print("🧠 Prompt received:", prompt)
        scene_plan = extract_scene_plan(prompt, image)
        enriched_prompts = generate_prompt_variations_from_scene(
            scene_plan, prompt, num_variations
        )
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)

        image = image.resize((1024, 1024)).convert("RGB")

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = f"outputs/session_{timestamp}"
        os.makedirs(out_dir, exist_ok=True)

        outputs = []
        image_paths = []
        for index, enriched_prompt in enumerate(enriched_prompts, start=1):
            print(f"✨ Generating Image {index}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=image,
                strength=0.7,
                guidance_scale=7.5,
                num_inference_steps=30,
            )
            output_img = result.images[0]
            img_path = f"{out_dir}/generated_{index}.png"
            output_img.save(img_path)
            outputs.append(output_img)
            image_paths.append(img_path)

        if not outputs:
            # Without this guard the caller would be handed a path to a file
            # that was never written (or an empty ZIP) with a success message.
            raise RuntimeError("No prompt variations were produced.")

        # Save log. Logging is best-effort: a failure here must not discard
        # images that were already generated and saved.
        try:
            _append_generation_log(
                {
                    "timestamp": timestamp,
                    "prompt": prompt,
                    "scene_plan": scene_plan,
                    "enriched_prompts": enriched_prompts,
                    "negative_prompt": negative_prompt,
                    "device": device,
                    "num_variations": num_variations,
                }
            )
        except (OSError, TypeError, ValueError) as log_error:
            print("⚠️ Could not write generation log:", log_error)

        # Handle single or multiple image download, based on what was
        # actually generated rather than what was requested.
        if len(outputs) == 1:
            return (
                outputs,
                "✅ Generated one image. Ready for download.",
                image_paths[0],
            )

        zip_path = _zip_images(out_dir, image_paths)
        return (
            outputs,
            f"✅ Generated {len(outputs)} images. Download below.",
            zip_path,
        )

    except Exception as e:  # UI boundary: report the failure instead of crashing.
        print("❌ Generation failed:", e)
        return [Image.new("RGB", (512, 512), color="red")], f"❌ Error: {e}", None


# ----------------------------
# 🧪 Gradio Interface
# ----------------------------
with gr.Blocks(title="NewCrux Image-to-Image Generator") as demo:
    gr.Markdown(
        "### 🖼️ NewCrux: Product Lifestyle Visual Generator (SDXL + Prompt AI)\n"
        "Upload a product image and describe the visual you want. "
        "The system will generate realistic marketing images using AI."
    )

    # NOTE(review): the original indentation was lost, so which widgets sit
    # inside the Row is reconstructed (prompt + image) — confirm the layout.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            placeholder="e.g., A person running on the beach wearing the product",
        )
        input_image = gr.Image(type="pil", label="Product Image")

    num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Variations")
    generate_btn = gr.Button("🚀 Generate Image(s)")

    output_gallery = gr.Gallery(
        label="Generated Images", show_label=True, columns=[2], height="auto"
    )
    output_msg = gr.Textbox(label="Generation Status", interactive=False)
    download_zip = gr.File(
        label="⬇️ Download All Images (.zip)", interactive=False
    )

    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, num_outputs],
        outputs=[output_gallery, output_msg, download_zip],
    )

# ----------------------------
# 🚀 Launch App
# ----------------------------
if __name__ == "__main__":
    demo.launch()