import gradio as gr
from PIL import Image
import torch
import os
import json
import zipfile
from datetime import datetime
from diffusers import StableDiffusionXLImg2ImgPipeline
from utils.planner import (
    extract_scene_plan,
    generate_prompt_variations_from_scene,
    generate_negative_prompt_from_scene
)

# ----------------------------
# 💻 Device Configuration
# ----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
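# Note: float16 above roughly halves GPU memory use; the CPU path stays in
# float32 because half precision is poorly supported on most CPUs.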

# ----------------------------
# 🔧 Load Stable Diffusion XL Img2Img Pipeline
# ----------------------------
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16" if device == "cuda" else None,
)
if device == "cuda":
    # enable_model_cpu_offload() handles device placement itself, so the
    # pipeline is not also moved to the GPU; doing both defeats the memory savings.
    pipe.enable_model_cpu_offload()
    pipe.enable_attention_slicing()
else:
    pipe.to(device)
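
# Optional extras (assumption, not part of the original app): diffusers exposes
# further memory savers on this pipeline; uncomment them on smaller GPUs.
# pipe.enable_vae_slicing()
# pipe.enable_vae_tiling()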

# ----------------------------
# 🎨 Core Generation Function
# ----------------------------
def process_image(prompt, image, num_variations):
    try:
        if image is None:
            raise ValueError("🚫 Please upload an image.")

        num_variations = int(num_variations)  # Gradio sliders may deliver floats
        print("🧠 Prompt received:", prompt)
        scene_plan = extract_scene_plan(prompt, image)
        enriched_prompts = generate_prompt_variations_from_scene(scene_plan, prompt, num_variations)
        negative_prompt = generate_negative_prompt_from_scene(scene_plan)

        image = image.resize((1024, 1024)).convert("RGB")
        outputs = []
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        out_dir = f"outputs/session_{timestamp}"
        os.makedirs(out_dir, exist_ok=True)
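
        # Img2img settings used below: strength=0.7 keeps the rough composition
        # of the uploaded photo while allowing heavy restyling (0.0 would return
        # the input unchanged, 1.0 behaves like pure text-to-image);
        # guidance_scale=7.5 and 30 steps are stock SDXL defaults, not tuned values.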
        for i, enriched_prompt in enumerate(enriched_prompts):
            print(f"✨ Generating Image {i + 1}...")
            result = pipe(
                prompt=enriched_prompt,
                negative_prompt=negative_prompt,
                image=image,
                strength=0.7,
                guidance_scale=7.5,
                num_inference_steps=30,
            )
            output_img = result.images[0]
            output_img.save(f"{out_dir}/generated_{i + 1}.png")
            outputs.append(output_img)

        # Append a JSONL log entry (one JSON object per line) for this run
        log_data = {
            "timestamp": timestamp,
            "prompt": prompt,
            "scene_plan": scene_plan,
            "enriched_prompts": enriched_prompts,
            "negative_prompt": negative_prompt,
            "device": device,
            "num_variations": num_variations,
        }
        os.makedirs("logs", exist_ok=True)
        with open("logs/generation_logs.jsonl", "a") as log_file:
            log_file.write(json.dumps(log_data) + "\n")

        # Return a single PNG for one variation, or bundle everything into a ZIP
        if num_variations == 1:
            single_img_path = f"{out_dir}/generated_1.png"
            return outputs, "✅ Generated one image. Ready for download.", single_img_path
        else:
            zip_path = f"{out_dir}/all_images.zip"
            with zipfile.ZipFile(zip_path, "w") as zipf:
                for i in range(len(outputs)):
                    img_path = f"{out_dir}/generated_{i + 1}.png"
                    zipf.write(img_path, os.path.basename(img_path))
            return outputs, f"✅ Generated {num_variations} images. Download below.", zip_path

    except Exception as e:
        print("❌ Generation failed:", e)
        return [Image.new("RGB", (512, 512), color="red")], f"❌ Error: {str(e)}", None

# ----------------------------
# 🧪 Gradio Interface
# ----------------------------
with gr.Blocks(title="NewCrux Image-to-Image Generator") as demo:
    gr.Markdown(
        "### 🖼️ NewCrux: Product Lifestyle Visual Generator (SDXL + Prompt AI)\n"
        "Upload a product image and describe the visual you want. "
        "The system will generate realistic marketing images using AI."
    )

    with gr.Row():
        prompt = gr.Textbox(label="Prompt", placeholder="e.g., A person running on the beach wearing the product")
        input_image = gr.Image(type="pil", label="Product Image")

    num_outputs = gr.Slider(1, 5, value=3, step=1, label="Number of Variations")
    generate_btn = gr.Button("🚀 Generate Image(s)")

    output_gallery = gr.Gallery(label="Generated Images", show_label=True, columns=[2], height="auto")
    output_msg = gr.Textbox(label="Generation Status", interactive=False)
    download_zip = gr.File(label="⬇️ Download Image(s) (.png or .zip)", interactive=False)

    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, num_outputs],
        outputs=[output_gallery, output_msg, download_zip]
    )
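
    # Assumption (not in the original app): for long SDXL runs on shared hardware,
    # calling demo.queue() before demo.launch() lets Gradio queue concurrent
    # requests instead of timing them out.
    # demo.queue()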

# ----------------------------
# 🚀 Launch App
# ----------------------------
if __name__ == "__main__":
    demo.launch()