Fine-Tuned U-Net (Flickr8k) - Stable Diffusion
This model contains a fine-tuned U-Net from the CompVis/stable-diffusion-v1-4 Stable Diffusion pipeline, trained on natural English captions from the Flickr8k dataset. It is intended to improve generation quality for everyday, human-centered scenes such as actions, objects, and environments.
Note: Only the U-Net was fine-tuned. The VAE, tokenizer, and text encoder are unchanged from the original base model.
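As a rough illustration of what "only the U-Net was fine-tuned" means in code, the sketch below freezes two components and passes only the third to an optimizer. The tiny `nn.Linear` modules are hypothetical stand-ins for the real U-Net, VAE, and text encoder, so the snippet runs without downloading any weights. The choice of AdamW is an assumption, since this card does not state which optimizer was used.

```python
import torch
from torch import nn

# Hypothetical stand-ins for the real components (UNet2DConditionModel,
# AutoencoderKL, CLIPTextModel); used here only to keep the sketch small.
unet = nn.Linear(8, 8)
vae = nn.Linear(8, 8)
text_encoder = nn.Linear(8, 8)

# Freeze the components that should not be updated.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
vae.eval()
text_encoder.eval()

# Only the U-Net parameters are handed to the optimizer.
# AdamW is an assumption; lr=1e-6 matches the learning rate listed in this card.
unet.train()
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-6)

trainable = sum(p.numel() for p in unet.parameters() if p.requires_grad)
frozen = sum(p.numel() for m in (vae, text_encoder) for p in m.parameters())
print(f"trainable parameters: {trainable}, frozen parameters: {frozen}")
```

The tokenizer has no trainable weights, so it needs no explicit freezing.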
Training Details
- Base model: CompVis/stable-diffusion-v1-4
- Fine-tuned on: Flickr8k Kaggle Dataset
- Components fine-tuned: unet only
- Frozen: text encoder, VAE, and tokenizer
- Epochs: 10
- Learning rate: 1e-6
- Batch size: 1 (with gradient accumulation = 16, so the effective batch size is 16)
- Image resolution: 256×256
- Training size: 1000 image-caption pairs
- Mixed precision: FP16
- Gradient Accumulation Steps: 16
- Trained on: Kaggle GPU (Tesla T4, 16GB VRAM)
- Seed: 42
- Checkpointing: every 200 steps
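The settings above imply a fairly short training schedule. The sketch below works through the arithmetic under two assumptions this card does not confirm: that a partial accumulation window at the end of an epoch still triggers an optimizer update, and that "steps" for checkpointing could mean either forward/backward passes or optimizer updates. All variable names are illustrative.

```python
import math

# Values taken from the training details listed above.
num_pairs = 1000
batch_size = 1
grad_accum_steps = 16
epochs = 10
checkpoint_every = 200

effective_batch = batch_size * grad_accum_steps

# Forward/backward passes per epoch (one per batch).
micro_steps_per_epoch = math.ceil(num_pairs / batch_size)

# Optimizer updates per epoch, assuming a leftover partial window still updates.
updates_per_epoch = math.ceil(micro_steps_per_epoch / grad_accum_steps)

total_micro_steps = micro_steps_per_epoch * epochs
total_updates = updates_per_epoch * epochs

# The card does not say which kind of step the checkpoint interval counts,
# so both readings are computed.
checkpoints_if_micro_steps = total_micro_steps // checkpoint_every
checkpoints_if_updates = total_updates // checkpoint_every

print(f"effective batch size: {effective_batch}")
print(f"optimizer updates: {updates_per_epoch} per epoch, {total_updates} total")
print(f"checkpoints: {checkpoints_if_micro_steps} (per pass) or {checkpoints_if_updates} (per update)")
```

Under these assumptions the U-Net receives roughly 630 optimizer updates in total, which is a light fine-tune at a learning rate of 1e-6.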
Usage
This U-Net can be loaded into a standard Stable Diffusion pipeline to enhance image generation on descriptive prompts:
import torch
from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
from transformers import CLIPTextModel

# Pick a device; fall back to float32 on CPU, where half precision is poorly supported
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load base components
print("Loading VAE and text encoder from base SD...")
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=dtype
).to(device)
text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14", torch_dtype=dtype
).to(device)

# Load fine-tuned UNet from Hugging Face
print("Loading fine-tuned UNet from Hugging Face (srishticrai/unet-flickr8k)...")
fine_tuned_unet = UNet2DConditionModel.from_pretrained(
    "srishticrai/unet-flickr8k",
    torch_dtype=dtype,
).to(device)

# Rebuild the pipeline around the fine-tuned UNet
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    unet=fine_tuned_unet,
    vae=vae,
    text_encoder=text_encoder,
    torch_dtype=dtype,
).to(device)
pipe.set_progress_bar_config(disable=False)
pipe.enable_attention_slicing()  # lowers peak memory use at some cost in speed

# Ask for prompt
prompt = input("Enter a prompt to generate an image: ")

# Generate image; the pipeline returns an output object whose .images is a list of PIL images
result = pipe(
    prompt,
    guidance_scale=10.0,
    num_inference_steps=50,
)
image = result.images[0]
image.show()