# sculpt / app.py (author: ds1david)
# Commit 2718083: "Trying fix variants" (2.71 kB)
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from PIL import Image, ImageEnhance, ImageOps
# Module setup: runs once at import time. Loads the SDXL base pipeline,
# attaches the bas-relief LoRA weights to it, and loads the DPT depth model.
# The functions below use `pipe`, `feature_extractor`, `depth_model` and `device`.
device = "cpu"  # or "cuda" if you have a GPU
torch_dtype = torch.float32  # dtype for the SDXL weights only; DPT below is loaded without an explicit dtype
print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype
).to(device)
print("Loading bas-relief LoRA weights with PEFT...")
# Attaches the LoRA to `pipe` in place (the result is not reassigned).
# Prompts use the token "BAS-RELIEF" to trigger this style.
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",  # Hugging Face Hub repo that hosts BAS-RELIEF.safetensors
    weight_name="BAS-RELIEF.safetensors",  # which file in that repo to load
    peft_backend="peft"  # NOTE(review): not a documented `load_lora_weights` parameter; confirm it has any effect
)
print("Loading DPT Depth Model...")
# Preprocessor and depth-estimation model from the same "Intel/dpt-large" checkpoint.
# NOTE(review): newer transformers releases deprecate DPTFeatureExtractor in favor of
# DPTImageProcessor; confirm against the pinned version.
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
def enhance_depth_map(depth_arr: np.ndarray, sharpness: float = 2.0) -> Image.Image:
    """Turn a raw depth prediction into a contrast-boosted 8-bit grayscale image.

    The values are min-max stretched to 0-255 and auto-contrasted with
    ``ImageOps.autocontrast``. The result is then sharpened with
    ``ImageEnhance.Sharpness``.

    Args:
        depth_arr: Depth values to visualize. Assumed to be the 2-D (H, W)
            depth prediction produced in this file; TODO confirm for any
            other caller.
        sharpness: Enhancement factor passed to ``ImageEnhance.Sharpness``.
            The default ``2.0`` matches the previous hard-coded value.

    Returns:
        The enhanced depth map as a PIL image.

    Raises:
        ValueError: If ``depth_arr`` is empty.
    """
    depth = np.asarray(depth_arr, dtype=np.float64)
    if depth.size == 0:
        raise ValueError("depth_arr must not be empty")

    d_min, d_max = depth.min(), depth.max()
    # Min-max stretch into [0, 1]. The epsilon keeps a constant map from
    # dividing by zero; such a map becomes all zeros.
    normalized = (depth - d_min) / (d_max - d_min + 1e-8)
    # Round instead of truncating. The epsilon keeps `normalized` just below
    # 1.0, so plain truncation mapped the maximum depth to 254 instead of 255.
    depth_u8 = np.clip(np.rint(normalized * 255.0), 0, 255).astype(np.uint8)

    depth_pil = ImageOps.autocontrast(Image.fromarray(depth_u8))
    return ImageEnhance.Sharpness(depth_pil).enhance(sharpness)
def generate_bas_relief_and_depth(
    imagem,
    prompt: str = "",
    *,
    num_inference_steps: int = 15,
    guidance_scale: float = 7.5,
    size: int = 512,
    strength: float = 0.75,
):
    """Render ``imagem`` as a bas-relief with the LoRA-styled SDXL pipeline and estimate its depth.

    Args:
        imagem: Input picture as delivered by ``gr.Image(type="pil")``, or
            ``None`` when nothing was uploaded. In that case the relief is
            generated from the prompt alone (text-to-image).
        prompt: Extra text appended after the ``BAS-RELIEF`` trigger token.
            Previously this name was referenced without being defined, so
            every call raised ``NameError``.
        num_inference_steps: Diffusion steps; reduce if generation is too slow.
        guidance_scale: Classifier-free guidance scale.
        size: Side length, in pixels, of the square output; reduce on timeouts.
        strength: Image-to-image strength in [0, 1]. Higher values depart
            further from ``imagem``. Unused when ``imagem`` is ``None``.

    Returns:
        A tuple ``(image, depth_map)``: the generated PIL image and its
        enhanced depth map, at the same size.
    """
    # The LoRA style is triggered by the token "BAS-RELIEF" in the prompt.
    full_prompt = f"BAS-RELIEF {prompt}".strip()

    print("Generating image with LoRA style...")
    image = _generate_relief_image(
        imagem,
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        size=size,
        strength=strength,
    )

    print("Running DPT Depth Estimation...")
    depth_map_pil = enhance_depth_map(_estimate_depth(image))
    return image, depth_map_pil


def _generate_relief_image(imagem, full_prompt, *, num_inference_steps, guidance_scale, size, strength):
    """Run SDXL on ``full_prompt`` (image-to-image when ``imagem`` is given) and return the first output image."""
    common_kwargs = {
        "prompt": full_prompt,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    if imagem is None:
        # No upload: plain text-to-image generation.
        return pipe(**common_kwargs, height=size, width=size).images[0]

    # `pipe` is a text-to-image StableDiffusionXLPipeline, which takes no
    # `image` argument, so the upload must go through an img2img pipeline.
    # Building it from `pipe.components` reuses the already-loaded (and
    # LoRA-patched) modules instead of loading the weights a second time.
    from diffusers import StableDiffusionXLImg2ImgPipeline

    img2img = StableDiffusionXLImg2ImgPipeline(**pipe.components)
    # img2img generates at the input image's resolution, so resize it to the
    # requested size to keep generation time bounded.
    init_image = imagem.convert("RGB").resize((size, size))
    return img2img(**common_kwargs, image=init_image, strength=strength).images[0]


def _estimate_depth(image):
    """Return DPT's depth prediction for the PIL ``image``, upsampled to its (H, W), as a NumPy array."""
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        predicted_depth = depth_model(**inputs).predicted_depth
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],  # PIL size is (W, H); interpolate expects (H, W)
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    return prediction.cpu().numpy()
# Gradio UI: one PIL image in; the bas-relief rendering and its depth map out.
# The user-facing text is in Portuguese ("Resultado" = result,
# "Profundidade" = depth).
interface = gr.Interface(
    fn=generate_bas_relief_and_depth,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Resultado"),
        gr.Image(label="Profundidade"),
    ],
    title="Conversor para Baixo-relevo",
    description="Transforme imagens em baixo-relevo com mapa de profundidade",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()