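"""Generate images from text prompts with a trained CVAE (see train.py).

Supports single and batch prompts, optional img2img blending via --input_image,
alpha-channel cleanup via --clean, and per-image generation timing.
"""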
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
import argparse
import numpy as np
import os
import time  # For timing each generation
import re    # For sanitizing prompts used in output filenames
# Import the model architecture from train.py
from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM
# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
def clean_image(image, threshold=0.75):
    """
    Clean up the image by setting pixels with opacity <= threshold to 0% opacity
    and pixels above the threshold to 100% opacity.
    """
    np_image = np.array(image)
    alpha_channel = np_image[:, :, 3]
    cutoff = int(threshold * 255)
    alpha_channel[alpha_channel <= cutoff] = 0    # Fully transparent
    alpha_channel[alpha_channel > cutoff] = 255   # Fully opaque
    return Image.fromarray(np_image)
def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
    # Encode the text prompt using the BERT tokenizer
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)
    # Generate the text encoding
    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
    # Sample from the latent space
    z = torch.randn(1, LATENT_DIM).to(device)
    # Decode the latent sample conditioned on the text encoding
    with torch.no_grad():
        generated_image = model.decode(z, text_encoding)
    if input_image is not None:
        input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        input_image = input_image * 2 - 1  # Rescale from [0, 1] to [-1, 1] to match the decoder output
        generated_image = img_control * input_image + (1 - img_control) * generated_image
    # Convert the generated tensor to a PIL Image
    generated_image = generated_image.squeeze(0).cpu()
    generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
    generated_image = generated_image.clamp(0, 1)
    generated_image = transforms.ToPILImage()(generated_image)
    return generated_image
def main():
    parser = argparse.ArgumentParser(description="Generate an image from a text prompt using the trained CVAE model(s).")
    parser.add_argument("--prompt", type=str, help="Text prompt for image generation")
    parser.add_argument("--prompt_file", type=str, help="File containing prompts, one per line")
    parser.add_argument("--output", type=str, default="generated_images", help="Output directory or file for generated images")
    parser.add_argument("--model_paths", type=str, nargs='*', help="Paths to the trained model(s)")
    parser.add_argument("--model_path", type=str, help="Path to a single trained model")
    parser.add_argument("--clean", action="store_true", help="Clean up the image by removing low-opacity pixels")
    parser.add_argument("--size", type=int, default=16, help="Size of the generated image")
    parser.add_argument("--input_image", type=str, help="Path to the input image for img2img generation")
    parser.add_argument("--img_control", type=float, default=0.5, help="Control how much the input image influences the output (0 to 1)")
    args = parser.parse_args()
    if not args.prompt and not args.prompt_file:
        parser.error("Either --prompt or --prompt_file must be provided")
    if args.model_paths and args.model_path:
        parser.error("Specify either --model_paths or --model_path, not both")
    if not args.model_paths and not args.model_path:
        parser.error("A model must be provided via --model_paths or --model_path")
    model_paths = args.model_paths if args.model_paths else [args.model_path]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Treat --output as a directory unless it has a file extension, so a
    # not-yet-existing directory (e.g. the default) is still created
    is_folder_output = os.path.isdir(args.output) or not os.path.splitext(args.output)[1]
    if is_folder_output:
        os.makedirs(args.output, exist_ok=True)
    # Load the input image if provided
    input_image = None
    if args.input_image:
        input_image = Image.open(args.input_image).convert("RGBA")
    # Collect a single prompt or a batch of prompts
    if args.prompt:
        prompts = [args.prompt]
    else:
        with open(args.prompt_file, 'r') as f:
            prompts = [line.strip() for line in f if line.strip()]
    for model_path in model_paths:
        # Initialize the model
        text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
        model = CVAE(text_encoder).to(device)
        # Load the trained weights
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
        model_name = os.path.splitext(os.path.basename(model_path))[0]
        for i, prompt in enumerate(prompts):
            start_time = time.time()  # Start timing the generation
            # Generate an image from the prompt
            generated_image = generate_image(model, prompt, device, input_image, args.img_control)
            generation_time = time.time() - start_time
            # Clean up the image if the flag is set
            if args.clean:
                generated_image = clean_image(generated_image)
            # Resize the generated image
            generated_image = generated_image.resize((args.size, args.size), resample=Image.NEAREST)
            if is_folder_output:
                # Sanitize the prompt so it is safe to use in a filename
                safe_prompt = re.sub(r'[^\w\-]+', '_', prompt)
                output_file = os.path.join(args.output, f"{model_name}_{safe_prompt}_{i:03d}.png")
            else:
                # Save the generated image to the specified file
                output_file = args.output
            generated_image.save(output_file)
            print(f"Generated image for prompt '{prompt}' using model '{model_name}' saved as {output_file}")
            print(f"Generation time: {generation_time:.3f} seconds")
if __name__ == "__main__":
    main()
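# Example invocations (assuming this script is saved as generate.py and a
# checkpoint produced by train.py exists at checkpoints/cvae.pth; both names
# are placeholders):
#
#   python generate.py --model_path checkpoints/cvae.pth \
#       --prompt "a red potion" --output potion.png --clean --size 32
#
#   python generate.py --model_paths checkpoints/*.pth \
#       --prompt_file prompts.txt --output generated_images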