File size: 11,369 Bytes
3d274c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b7d28
 
 
 
 
 
 
3d274c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70b7d28
 
3d274c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
import argparse
import os
import asyncio
import fal_client
import base64
import io
from PIL import Image
import requests
import shutil
from together import Together

# Root directory (relative to the current working directory) under which run
# results are stored. It is created at import time if it does not exist yet.
OUTPUT_DIR = "output"
os.makedirs(OUTPUT_DIR, exist_ok=True)

def get_next_dir_number():
    """Return the next unused numeric directory name inside OUTPUT_DIR.

    Only sub-directories of OUTPUT_DIR whose names consist entirely of
    decimal digits are taken into account; everything else (files, "temp",
    etc.) is ignored.

    Returns:
        int: 1 when no numbered directory exists, otherwise the highest
        existing number plus one.
    """
    # str.isdecimal() is used instead of str.isdigit(): isdigit() also accepts
    # characters such as superscript digits ("²") that int() cannot parse,
    # which would make a single oddly named directory crash this function.
    numbers = [
        int(name)
        for name in os.listdir(OUTPUT_DIR)
        if name.isdecimal() and os.path.isdir(os.path.join(OUTPUT_DIR, name))
    ]
    return max(numbers, default=0) + 1

def save_results(input_path, generated_image_path, video_path, user_prompt, optimized_prompt, output_dir=None):
    """
    Save all generation results in one directory.

    The three artifacts are copied under fixed file names
    ("input_image.png", "generated_image.png", "generated_video.mp4") and the
    two prompts are written next to them as UTF-8 text files
    ("input_prompt.txt", "opt_prompt.txt").

    Args:
        input_path: Path to the input reference image
        generated_image_path: Path to the generated image
        video_path: Path to the generated video
        user_prompt: The original text prompt used for generation
        optimized_prompt: The optimized prompt used for generation
        output_dir: Optional custom output directory. When None, a new
            numbered directory is created inside OUTPUT_DIR.

    Returns:
        Tuple of (result_dir, saved_video_path)
    """
    if output_dir is None:
        result_dir = os.path.join(OUTPUT_DIR, str(get_next_dir_number()))
    else:
        result_dir = output_dir

    os.makedirs(result_dir, exist_ok=True)

    def _copy_into_result_dir(source, filename):
        """Copy ``source`` to ``result_dir/filename`` and return the destination path."""
        destination = os.path.join(result_dir, filename)
        # When the artifact already lives at its destination (e.g. the caller
        # generated it directly inside output_dir) shutil.copy2 would raise
        # SameFileError, so the copy is skipped in that case.
        already_in_place = os.path.exists(destination) and os.path.samefile(source, destination)
        if not already_in_place:
            shutil.copy2(source, destination)
        return destination

    _copy_into_result_dir(input_path, "input_image.png")
    _copy_into_result_dir(generated_image_path, "generated_image.png")
    saved_video_path = _copy_into_result_dir(video_path, "generated_video.mp4")

    # Prompts are free-form user/LLM text; write them as UTF-8 explicitly so
    # non-ASCII characters do not fail under a non-UTF-8 platform default.
    prompt_files = (
        ("input_prompt.txt", user_prompt),
        ("opt_prompt.txt", optimized_prompt),
    )
    for filename, text in prompt_files:
        with open(os.path.join(result_dir, filename), "w", encoding="utf-8") as f:
            f.write(text)

    print(f"All results saved to directory: {result_dir}")
    return result_dir, saved_video_path

async def generate_image(ref_image, prompt):
    """Submit an image-generation request to "fal-ai/flux-pulid" and return its result.

    Args:
        ref_image: Value sent to the service as ``reference_image_url``.
        prompt: Text prompt sent to the service as ``prompt``.

    Returns:
        Whatever the request handle's ``get()`` resolves to.
    """
    print("Generating image")

    request_arguments = {
        "prompt": prompt,
        "reference_image_url": ref_image,
    }
    job = await fal_client.submit_async("fal-ai/flux-pulid", arguments=request_arguments)

    # Consume every progress event without reporting it; the loop ends when
    # the handle stops yielding events.
    async for _event in job.iter_events():
        continue

    return await job.get()

async def generate_video(image_path, prompt):
    """Submit an image-to-video request to "fal-ai/wan-i2v" and return its result.

    The image file is read from disk and sent inline as a base64 ``data:`` URL
    (declared as ``image/png``), together with the prompt and a fixed set of
    generation settings.

    Args:
        image_path: Path of the local image file to animate.
        prompt: Text prompt sent to the service as ``prompt``.

    Returns:
        Whatever the request handle's ``get()`` resolves to.
    """
    print("Generating video from image...")

    # Read the image file and embed it in the request as a base64 data URL
    with open(image_path, 'rb') as image_file:
        base64_image = base64.b64encode(image_file.read()).decode('utf-8')
    image_data_url = f"data:image/png;base64,{base64_image}"

    handler = await fal_client.submit_async(
        "fal-ai/wan-i2v",
        arguments={
            "prompt": prompt,
            "image_url": image_data_url,
            "resolution": "480p",
            "guide_scale": 6.5,
            "shift": 4.5,
            "enable_prompt_expansion": True,
            "acceleration": "regular",
            "aspect_ratio": "auto"
        },
    )

    # Wait for completion silently
    async for _ in handler.iter_events():
        pass

    # Fetch the result through the async handle, exactly as generate_image
    # does. The module-level fal_client.result() is a synchronous call and
    # would block the event loop when invoked from inside this coroutine.
    result = await handler.get()
    return result

async def optimize_prompt(ref_image_path, user_prompt):
    """Build an image-aware generation prompt with two chat-completion calls.

    The first call sends the reference image (as a base64 data URL) and asks
    for a detailed description; the second call sends that description plus
    the user's prompt and asks for a single combined prompt.

    Args:
        ref_image_path: Path of the local reference image file.
        user_prompt: The user's original text prompt.

    Returns:
        The content of the second completion, stripped of surrounding whitespace.

    NOTE(review): the client calls below are not awaited, so they appear to
    be synchronous even though this function is a coroutine; confirm.
    """
    print("Optimizing prompt...")

    llm = Together()
    model_name = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"

    # Load the reference image and base64-encode it for the data URL
    with open(ref_image_path, 'rb') as fh:
        raw_bytes = fh.read()
    encoded_image = base64.b64encode(raw_bytes).decode('utf-8')

    # Step 1: ask the model for a detailed caption of the reference image
    caption_system = {
        "role": "system",
        "content": "You are an expert at describing images in detail, focusing on clothing, accessories, poses, and visual attributes.",
    }
    caption_user = {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64," + encoded_image}},
            {"type": "text", "text": "Describe this image in detail, focusing on the clothing, accessories, pose, and any distinctive visual features."},
        ],
    }
    caption_reply = llm.chat.completions.create(
        model=model_name,
        messages=[caption_system, caption_user],
        max_tokens=500,
    )
    image_description = caption_reply.choices[0].message.content

    # Step 2: merge the caption with the user's request into one prompt
    merge_system = {
        "role": "system",
        "content": "You are an expert at combining user prompts with detailed image descriptions to create optimal prompts for image generation. Focus on maintaining visual consistency while incorporating the user's desired changes. IMPORTANT: Return ONLY the optimized prompt without any explanations or additional text.",
    }
    merge_request = (
        "Here is a detailed description of the reference image:\n"
        f"{image_description}\n"
        "\n"
        "And here is what the user wants to do with it:\n"
        f"{user_prompt}\n"
        "\n"
        "Create an optimal prompt that maintains the visual details (especially clothing and accessories) while incorporating the user's desired changes. The prompt should be direct and descriptive. Return ONLY the prompt without any explanations."
    )
    merge_reply = llm.chat.completions.create(
        model=model_name,
        messages=[merge_system, {"role": "user", "content": merge_request}],
        max_tokens=500,
    )

    optimized_prompt = merge_reply.choices[0].message.content.strip()
    print(f"Original prompt: {user_prompt}")
    print(f"Optimized prompt: {optimized_prompt}")

    return optimized_prompt

def _local_image_to_data_url(path):
    """Encode the image file at ``path`` as a base64 ``data:`` URL."""
    import mimetypes  # local import: not among the file's top-level imports

    mime_type, _ = mimetypes.guess_type(path)
    if not mime_type or not mime_type.startswith('image/'):
        # Unknown extension: fall back to the type used elsewhere in this file.
        mime_type = 'image/png'
    with open(path, 'rb') as image_file:
        encoded = base64.b64encode(image_file.read()).decode('utf-8')
    return f"data:{mime_type};base64,{encoded}"


async def process_async(ref, prompt, output):
    """Run the full pipeline: optimize prompt, generate image, generate video.

    Steps:
      1. Materialise the reference image as a local file (downloading an
         http(s) URL or decoding a ``data:image`` URL into ``output``).
      2. Optimize the prompt with ``optimize_prompt``.
      3. Generate an image with ``generate_image`` and save it as
         ``generated_image.png`` in ``output``.
      4. Generate a video from that image with ``generate_video`` (using the
         original prompt) and save it as ``generated_video.mp4`` in ``output``.
      5. Unless ``output`` already is the most recent numbered directory in
         OUTPUT_DIR, copy everything into a new numbered directory via
         ``save_results``.

    Args:
        ref: Reference image given as an http(s) URL, a ``data:image`` URL,
            or a local file path.
        prompt: The user's text prompt.
        output: Directory for intermediate and generated files; it is created
            when missing.

    Returns:
        None when image generation fails or the image has an unexpected
        format; otherwise a tuple ``(result, image_path, video_path)`` where
        ``video_path`` is None if the video could not be obtained.

    Raises:
        requests.HTTPError: When downloading the reference image or the
            generated image returns an error status.
    """
    print(f"Processing image:")

    # Per-request timeout so a stalled server cannot hang the pipeline forever.
    http_timeout = 60

    if output:
        os.makedirs(output, exist_ok=True)
    temp_image_path = os.path.join(output, 'temp_ref_image.png')

    # ref_path: local file handed to the prompt optimizer.
    # image_ref: what is sent to the remote image generator.
    image_ref = ref
    if ref.lower().startswith(('http://', 'https://')):
        response = requests.get(ref, timeout=http_timeout)
        # Fail here rather than saving an HTML error page as the "image".
        response.raise_for_status()
        with open(temp_image_path, 'wb') as f:
            f.write(response.content)
        ref_path = temp_image_path
    elif ref.startswith('data:image'):
        image_bytes = base64.b64decode(ref.split(',', 1)[1])
        with open(temp_image_path, 'wb') as f:
            f.write(image_bytes)
        ref_path = temp_image_path
    else:
        ref_path = ref
        # A local filesystem path cannot be fetched by the remote service,
        # so send the file contents inline as a data URL instead.
        image_ref = _local_image_to_data_url(ref_path)

    # Optimize the prompt using Together AI
    optimized_prompt = await optimize_prompt(ref_path, prompt)

    # Generate image using text+image with optimized prompt
    result = await generate_image(image_ref, optimized_prompt)

    if not (result and 'images' in result and len(result['images']) > 0):
        print("Error: Failed to generate image")
        return None

    # Only the first returned image is used
    image_data = result['images'][0]

    if isinstance(image_data, str) and image_data.startswith('data:image'):
        # Base64-encoded image embedded in the response
        image_bytes = base64.b64decode(image_data.split(',', 1)[1])
        image = Image.open(io.BytesIO(image_bytes))
    elif isinstance(image_data, dict) and 'url' in image_data:
        # Image referenced by URL
        response = requests.get(image_data['url'], timeout=http_timeout)
        response.raise_for_status()
        image = Image.open(io.BytesIO(response.content))
    else:
        print(f"Unexpected image format in response: {type(image_data)}")
        return None

    output_filename = os.path.join(output, 'generated_image.png')
    image.save(output_filename)
    print(f"Generated image saved to: {output_filename}")

    # Generate video from the saved image using the original prompt
    video_result = await generate_video(output_filename, prompt)

    if not (video_result and isinstance(video_result, dict) and 'video' in video_result):
        print("Error: No video URL in response")
        return result, output_filename, None

    video_url = video_result['video']['url']
    video_response = requests.get(video_url, timeout=http_timeout)
    if video_response.status_code != 200:
        print(f"Failed to download video. Status code: {video_response.status_code}")
        return result, output_filename, None

    video_filename = os.path.join(output, 'generated_video.mp4')
    with open(video_filename, 'wb') as f:
        f.write(video_response.content)
    print(f"Generated video saved to: {video_filename}")

    # Save the results to a numbered directory if output is not already the
    # latest numbered directory. Paths are normalised so that spellings such
    # as "output/3/" and "output/3" compare equal.
    latest_numbered_dir = os.path.join(OUTPUT_DIR, str(get_next_dir_number() - 1))
    if os.path.normpath(output) != os.path.normpath(latest_numbered_dir):
        result_dir, saved_video_path = save_results(
            ref_path, output_filename, video_filename, prompt, optimized_prompt
        )
        return result, output_filename, saved_video_path

    return result, output_filename, video_filename

def process(ref, prompt, output):
    """Synchronous wrapper: run ``process_async`` to completion and return its result.

    Args:
        ref: Reference image, forwarded unchanged to ``process_async``.
        prompt: Text prompt, forwarded unchanged.
        output: Output directory, forwarded unchanged.

    Returns:
        Whatever ``process_async`` returns.
    """
    pipeline = process_async(ref, prompt, output)
    return asyncio.run(pipeline)

def main():
    """Command-line entry point: parse arguments, run the pipeline, report status.

    Command-line arguments:
        --ref:    URL or path of the reference image (required).
        --prompt: Text prompt (required).
        --output: Optional working directory. When omitted, the "temp"
                  directory inside OUTPUT_DIR is used.

    Returns:
        int: 0 when a result, an image path and a video path were all
        produced, 1 otherwise.
    """
    # Set up command line argument parsing
    parser = argparse.ArgumentParser(description='Process an image with a text prompt and generate a video')
    parser.add_argument('--ref', type=str, required=True, help='URL or path to the reference image')
    parser.add_argument('--prompt', type=str, required=True, help='Text prompt')
    parser.add_argument('--output', type=str, default=None, help='Optional custom output directory. If not provided, a numbered directory will be created.')

    args = parser.parse_args()

    # Determine output directory
    if args.output:
        output_dir = args.output
        os.makedirs(output_dir, exist_ok=True)
        print(f"Using custom output directory: {output_dir}")
    else:
        # Use a temporary processing directory inside OUTPUT_DIR
        output_dir = os.path.join(OUTPUT_DIR, "temp")
        os.makedirs(output_dir, exist_ok=True)

    # Print the provided arguments
    print(f"Reference image: {args.ref}")
    print(f"Text prompt: {args.prompt}")

    # Process the image and generate video
    outcome = process(args.ref, args.prompt, output_dir)

    # process() returns a bare None (not a 3-tuple) when image generation
    # fails, so it must be checked before unpacking to avoid a TypeError.
    if outcome is None:
        print("Processing failed")
        return 1

    result, image_path, video_path = outcome
    if result and image_path and video_path:
        print("Processing complete")
        return 0

    print("Processing failed")
    return 1

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status. SystemExit
    # is raised directly because the exit() builtin is injected by the site
    # module for interactive use and is not guaranteed to exist in every
    # runtime (e.g. `python -S` or frozen applications).
    raise SystemExit(main())