🖼️ Fine-Tuned U-Net (Flickr8k) - Stable Diffusion

This repository contains a U-Net fine-tuned from the CompVis/stable-diffusion-v1-4 Stable Diffusion pipeline on natural English captions from the Flickr8k dataset. It is intended to improve generation quality for everyday, human-centered prompts that describe actions, objects, and environmental scenes.

✅ Only the U-Net was fine-tuned. The VAE, tokenizer, and text encoder remain from the original base model.


📊 Training Details

  • Base model: CompVis/stable-diffusion-v1-4
  • Fine-tuned on: Flickr8k Kaggle Dataset
  • Components fine-tuned: unet only
  • Frozen: text encoder, VAE, and tokenizer
  • Epochs: 10
  • Learning rate: 1e-6
  • Batch size: 1 (with gradient accumulation = 16 → effective batch size = 16)
  • Image resolution: 256×256
  • Training size: 1000 image-caption pairs
  • Mixed precision: FP16
  • Gradient Accumulation Steps: 16
  • Trained on: Kaggle GPU (Tesla T4, 16GB VRAM)
  • Seed: 42
  • Checkpointing: every 200 steps
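
The settings above imply a fairly short training run. As a rough illustration, the arithmetic can be sketched in Python. The function name `training_schedule` is hypothetical, and two points are assumptions because the card does not state them: that a trailing partial accumulation group still triggers an optimizer update, and whether "every 200 steps" counts optimizer updates or individual micro-batches (both readings are computed).

```python
import math


def training_schedule(num_samples, batch_size, grad_accum_steps, epochs, checkpoint_every):
    """Derive step counts from the hyperparameters listed in the model card."""
    effective_batch_size = batch_size * grad_accum_steps
    micro_batches_per_epoch = math.ceil(num_samples / batch_size)
    # Assumption: a trailing partial accumulation group still triggers an update.
    updates_per_epoch = math.ceil(micro_batches_per_epoch / grad_accum_steps)
    total_updates = updates_per_epoch * epochs
    total_micro_batches = micro_batches_per_epoch * epochs
    return {
        "effective_batch_size": effective_batch_size,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        # Assumption: the card does not say which kind of "step" is counted,
        # so both interpretations are reported.
        "checkpoints_if_optimizer_steps": total_updates // checkpoint_every,
        "checkpoints_if_micro_batches": total_micro_batches // checkpoint_every,
    }


schedule = training_schedule(
    num_samples=1000, batch_size=1, grad_accum_steps=16, epochs=10, checkpoint_every=200
)
for name, value in schedule.items():
    print(f"{name}: {value}")
```

Under these assumptions the run performs 63 optimizer updates per epoch (the last one on a partial group of 8 samples) and 630 updates in total.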

🧠 Usage

This U-Net can be loaded into a standard Stable Diffusion pipeline to enhance image generation on descriptive prompts:

from diffusers import StableDiffusionPipeline, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel
import torch

# The float16 weights loaded below assume a CUDA GPU
device = "cuda"

# Load base components
print("Loading VAE and text encoder from base SD...")
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16).to(device)
text_encoder = CLIPTextModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="text_encoder", torch_dtype=torch.float16).to(device)

# Load fine-tuned UNet from Hugging Face
print("Loading fine-tuned UNet from Hugging Face (srishticrai/unet-flickr8k)...")
fine_tuned_unet = UNet2DConditionModel.from_pretrained(
    "srishticrai/unet-flickr8k", 
    torch_dtype=torch.float16
).to(device)

# Rebuild the pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    unet=fine_tuned_unet,
    vae=vae,
    text_encoder=text_encoder,
    torch_dtype=torch.float16
).to(device)

pipe.set_progress_bar_config(disable=False)
pipe.enable_attention_slicing()

# Ask for prompt
prompt = input("Enter a prompt to generate an image: ")

# Generate image
# The pipeline returns an output object; the generated PIL images are in .images
image = pipe(
    prompt,
    guidance_scale=10.0,
    num_inference_steps=50
).images[0]

image.show()
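
The U-Net was fine-tuned at 256×256, while the stable-diffusion-v1-4 pipeline generates 512×512 images by default. The sketch below shows one way to sample at the training resolution. The helper name `generate_at_training_resolution` and its defaults are illustrative assumptions rather than part of the original workflow, and whether 256×256 sampling actually gives better results than the default has not been verified here.

```python
def generate_at_training_resolution(pipe, prompt, size=256, steps=50,
                                    guidance_scale=10.0, generator=None):
    """Call a Stable Diffusion pipeline with height and width set to `size`.

    `pipe` is expected to behave like diffusers' StableDiffusionPipeline:
    callable with a prompt and returning an object with an `.images` list.
    """
    # The Stable Diffusion pipeline requires dimensions divisible by 8.
    if size % 8 != 0:
        raise ValueError(f"size must be divisible by 8, got {size}")
    result = pipe(
        prompt,
        height=size,
        width=size,
        num_inference_steps=steps,
        guidance_scale=guidance_scale,
        generator=generator,  # optional torch.Generator for reproducible output
    )
    return result.images[0]


# Example usage with the pipeline built above (seed 42 mirrors the training seed,
# which is an arbitrary choice for inference):
# generator = torch.Generator(device=device).manual_seed(42)
# image = generate_at_training_resolution(pipe, "a dog running on the beach",
#                                         generator=generator)
# image.save("sample.png")
```

Passing a seeded generator makes repeated runs with the same prompt comparable, which helps when judging the fine-tuned U-Net against the base one.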