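"""Generate 16x16 images from text prompts with a trained CVAE.

Loads one or more CVAE checkpoints produced by train.py, encodes each prompt
with BERT, decodes a random latent sample conditioned on that encoding, and
optionally blends in an input image (img2img) and binarizes low-opacity pixels.
"""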
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from transformers import BertTokenizer, BertModel
import argparse
import numpy as np
import os
import time

# Import the model architecture from train.py
from train import CVAE, TextEncoder, LATENT_DIM, HIDDEN_DIM

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def clean_image(image, threshold=0.75):
    """
    Binarize the alpha channel: pixels with opacity at or below `threshold`
    become fully transparent, and pixels above it fully opaque.
    """
    np_image = np.array(image.convert("RGBA"))  # Ensure an alpha channel exists
    alpha_channel = np_image[:, :, 3]
    cutoff = int(threshold * 255)
    alpha_channel[:] = np.where(alpha_channel > cutoff, 255, 0)
    return Image.fromarray(np_image)
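# For example, with the default threshold of 0.75 the alpha cutoff is
# int(0.75 * 255) = 191: an alpha of 150 (~59% opacity) is zeroed out,
# while 200 (~78%) becomes fully opaque (255).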

def generate_image(model, text_prompt, device, input_image=None, img_control=0.5):
    # Encode text prompt using BERT tokenizer
    encoded_input = tokenizer(text_prompt, padding=True, truncation=True, return_tensors="pt")
    input_ids = encoded_input['input_ids'].to(device)
    attention_mask = encoded_input['attention_mask'].to(device)
    
    # Generate text encoding
    with torch.no_grad():
        text_encoding = model.text_encoder(input_ids, attention_mask)
    
    # Sample from the latent space
    z = torch.randn(1, LATENT_DIM).to(device)
    
    # Generate image
    with torch.no_grad():
        generated_image = model.decode(z, text_encoding)
    
    if input_image is not None:
        # Match the decoder's native 16x16 resolution before blending
        input_image = input_image.convert("RGBA").resize((16, 16), resample=Image.NEAREST)
        input_image = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
        input_image = input_image * 2 - 1  # Rescale from [0, 1] to the decoder's [-1, 1] range
        generated_image = img_control * input_image + (1 - img_control) * generated_image
    
    # Convert the generated tensor to a PIL Image
    generated_image = generated_image.squeeze(0).cpu()
    generated_image = (generated_image + 1) / 2  # Rescale from [-1, 1] to [0, 1]
    generated_image = generated_image.clamp(0, 1)
    generated_image = transforms.ToPILImage()(generated_image)
    
    return generated_image
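# Note: each call to generate_image draws a fresh latent vector z, so the same
# prompt yields a different image on every run; call torch.manual_seed(...)
# beforehand if reproducible output is needed.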

def main():
    parser = argparse.ArgumentParser(description="Generate an image from a text prompt using the trained CVAE model(s).")
    parser.add_argument("--prompt", type=str, help="Text prompt for image generation")
    parser.add_argument("--prompt_file", type=str, help="File containing prompts, one per line")
    parser.add_argument("--output", type=str, default="generated_images", help="Output directory or file for generated images")
    parser.add_argument("--model_paths", type=str, nargs='*', help="Paths to the trained model(s)")
    parser.add_argument("--model_path", type=str, help="Path to a single trained model")
    parser.add_argument("--clean", action="store_true", help="Clean up the image by removing low opacity pixels")
    parser.add_argument("--size", type=int, default=16, help="Size of the generated image")
    parser.add_argument("--input_image", type=str, help="Path to the input image for img2img generation")
    parser.add_argument("--img_control", type=float, default=0.5, help="Control how much the input image influences the output (0 to 1)")
    args = parser.parse_args()

    if not args.prompt and not args.prompt_file:
        parser.error("Either --prompt or --prompt_file must be provided")
    
    if args.model_paths and args.model_path:
        parser.error("Specify either --model_paths or --model_path, not both")

    if not args.model_paths and not args.model_path:
        parser.error("Either --model_paths or --model_path must be provided")

    model_paths = args.model_paths if args.model_paths else [args.model_path]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Treat --output as a directory unless it looks like a file path;
    # os.path.isdir alone misclassifies directories that do not exist yet
    is_folder_output = os.path.isdir(args.output) or not os.path.splitext(args.output)[1]

    if is_folder_output:
        # Create the output directory if it does not already exist
        os.makedirs(args.output, exist_ok=True)

    # Load input image if provided
    input_image = None
    if args.input_image:
        input_image = Image.open(args.input_image).convert("RGBA")

    # Process single prompt or batch of prompts
    if args.prompt:
        prompts = [args.prompt]
    else:
        with open(args.prompt_file, 'r') as f:
            prompts = [line.strip() for line in f if line.strip()]

    for model_path in model_paths:
        # Initialize model
        text_encoder = TextEncoder(hidden_size=HIDDEN_DIM, output_size=HIDDEN_DIM)
        model = CVAE(text_encoder).to(device)
        
        # Load the trained model
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        model_name = os.path.splitext(os.path.basename(model_path))[0]

        for i, prompt in enumerate(prompts):
            start_time = time.time()  # Start timing the generation
            
            # Generate image from prompt
            generated_image = generate_image(model, prompt, device, input_image, args.img_control)
            
            generation_time = time.time() - start_time

            # Clean up the image if the flag is set
            if args.clean:
                generated_image = clean_image(generated_image)
            
            # Resize the generated image
            generated_image = generated_image.resize((args.size, args.size), resample=Image.NEAREST)
            
            if not is_folder_output:
                # Save to the specified file; with multiple prompts or models,
                # each generation overwrites the previous one
                output_file = args.output
            else:
                # Save into the output directory, sanitizing the prompt so it
                # is safe to use in a filename
                safe_prompt = "".join(c if c.isalnum() or c in " _-" else "_" for c in prompt).strip()
                output_file = os.path.join(args.output, f"{model_name}_{safe_prompt}_{i:03d}.png")
            
            generated_image.save(output_file)
            print(f"Generated image for prompt '{prompt}' using model '{model_name}' saved as {output_file}")
            print(f"Generation time: {generation_time:.3f} seconds")

if __name__ == "__main__":
    main()
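
# Example invocations (illustrative: paths and the script name generate.py are
# assumptions; substitute your own checkpoint produced by train.py):
#
#   python generate.py --prompt "a red potion" --model_path models/cvae.pth \
#       --output potion.png --clean --size 64
#
#   python generate.py --prompt_file prompts.txt \
#       --model_paths models/cvae_a.pth models/cvae_b.pth \
#       --output generated_images --input_image sketch.png --img_control 0.3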