File size: 9,189 Bytes
c9dac35
 
 
a446ad0
c9dac35
fb6c5e9
524c601
c9dac35
ab00f6b
5d95766
c9dac35
dc6215b
 
 
ebbd273
 
dc6215b
 
 
c9dac35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc6215b
ebbd273
c9dac35
 
 
 
 
 
 
 
 
ebbd273
c9dac35
dc6215b
ebbd273
 
 
 
dc6215b
ebbd273
c9dac35
 
ab00f6b
dc6215b
 
 
 
ab00f6b
dd4f3b1
 
c9dac35
 
 
 
 
 
dd4f3b1
 
ab00f6b
dd4f3b1
 
a446ad0
 
ab00f6b
dd4f3b1
a446ad0
ab00f6b
 
 
ebbd273
 
 
 
dd4f3b1
 
c9dac35
ab00f6b
dd4f3b1
 
 
ab00f6b
dd4f3b1
 
 
 
 
dc6215b
dd4f3b1
 
ab00f6b
dd4f3b1
 
fb6c5e9
dd4f3b1
 
ab00f6b
dd4f3b1
 
 
ab00f6b
dd4f3b1
 
ab00f6b
dd4f3b1
ebbd273
dd4f3b1
 
ab00f6b
dd4f3b1
 
ab00f6b
 
dd4f3b1
 
ab00f6b
 
 
 
dd4f3b1
 
ab00f6b
 
dd4f3b1
ab00f6b
dd4f3b1
 
 
fb6c5e9
dd4f3b1
 
 
 
ab00f6b
dd4f3b1
 
 
 
ab00f6b
dd4f3b1
 
ab00f6b
 
ebbd273
 
ab00f6b
dd4f3b1
 
ab00f6b
 
dd4f3b1
ab00f6b
dd4f3b1
c9dac35
 
ab00f6b
dd4f3b1
 
 
ebbd273
dd4f3b1
ab00f6b
dd4f3b1
 
ab00f6b
dd4f3b1
 
 
ebbd273
dd4f3b1
 
 
ab00f6b
dd4f3b1
 
ab00f6b
dd4f3b1
 
 
ab00f6b
dd4f3b1
ab00f6b
dd4f3b1
ab00f6b
dd4f3b1
 
 
ab00f6b
a446ad0
 
 
 
ebbd273
a446ad0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import base64
import io
import os
from together import Together

# Model identifier passed as `model=` to the chat-completion requests below.
MODEL_ID = "meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8"
# Marker token every caption must begin with: it is interpolated into the
# system prompt and used to recognise caption lines in model replies and files.
TRIGGER_WORD = "tr1gg3r"

def get_system_prompt():
    """Return the system prompt sent with every captioning request.

    The prompt is an f-string rebuilt on each call, with TRIGGER_WORD
    interpolated wherever the required caption prefix is mentioned.

    Returns:
        The full instruction text as a single string.
    """
    return f"""Automated Image Captioning (for LoRA Training)

Role: You are an expert AI captioning system generating precise, structured descriptions for character images optimized for LoRA model training in Stable Diffusion and Flux.1-dev.

IMPORTANT: You MUST follow these rules EXACTLY:
1. EVERY caption MUST start with the word "{TRIGGER_WORD}" (exactly like this, no variations)
2. You MUST use the exact format: {TRIGGER_WORD} [Style], [Notable Visual Features], [Clothing], [Pose], [Expression/Mood], [Background/Setting], [Lighting], [Camera Angle]
3. DO NOT use bullet points, lists, or any other formatting
4. DO NOT include any text before or after the caption
5. If you don't follow this format exactly, the caption will be rejected

General Guidelines:
1. Prioritize Consistency – Maintain uniform descriptions across all images in a dataset. Avoid introducing variation in features that should remain constant (e.g., fixed traits like eye color, hair color, or markings that are inherently part of the concept and handled during model training).
2. Concise and Structured – Only describe visible and significant visual attributes. Use a standardized format for clarity and efficiency.
3. Omit Subjective Language – Do not use evaluative or emotional descriptors like "beautiful" or "scary."
4. Focus on Key Visual Cues – Clearly describe clothing, accessories, pose, facial expression, lighting, and camera angle. Mention distinctive features only if variable (e.g., visible scar, glasses).
5. Adapt to Visual Style – Clearly state the artistic style: "anime-style," "photorealistic," "3D-rendered," etc.
6. Standard Format – Ensure captions follow a consistent structure.
7. Remain Objective – Do not reference known characters, franchises, or people, even if recognizable. Describe only what is visually present.

Avoid Describing These Unless Variable Across Dataset or Uncertain from Concept:
- Eye color
- Hair color
- Skin tone
- Tattoos or markings if core to the concept
- Known accessories that always appear (unless outfit-specific)

Caption Format (MUST FOLLOW EXACTLY):  
{TRIGGER_WORD} [Style], [Notable Visual Features], [Clothing], [Pose], [Expression/Mood], [Background/Setting], [Lighting], [Camera Angle]

Captioning Principles:
- Emphasize visual variation and context-specific details (outfit, pose, lighting, expression, camera angle).
- Exclude constant traits unless clearly deviating.
- Include visible clothing and accessories.
- Clearly define pose and facial expression (neutral, smiling, aggressive).
- Specify lighting conditions (soft lighting, harsh shadows, glowing backlight).
- Explicitly state camera angle (e.g., front view, right side profile, low-angle, high-angle, overhead).
- Avoid mentioning real or fictional identities.
- Always prefix with the trigger word "{TRIGGER_WORD}".

Examples (MUST FOLLOW THIS EXACT FORMAT):
{TRIGGER_WORD} photorealistic, combat gear, tactical vest and gloves, standing in profile, neutral, empty room, overcast lighting, side profile
{TRIGGER_WORD} 3D-rendered, digital patterns, hooded cloak, seated cross-legged, calm, meditation chamber, low ambient lighting, front view
{TRIGGER_WORD} anime-style, school uniform with blue necktie, standing with arms behind back, gentle smile, classroom, soft daylight, three-quarter view
{TRIGGER_WORD} photorealistic, long trench coat and combat boots, walking, determined, rain-soaked street, dramatic shadows, low-angle view

REMEMBER: Your response must be a single line starting with "{TRIGGER_WORD}" and following the exact format above. No additional text, formatting, or explanations are allowed.
"""


class CaptioningError(Exception):
    """Signal that no usable caption could be obtained from a model reply."""


def images_to_base64(images):
    """Serialise each image to PNG and return the base64 text of each result.

    Args:
        images: Iterable of image objects exposing ``save(fp, format=...)``
            (PIL images, per the callers' usage).

    Returns:
        A list of base64 strings, one per image, in input order.
    """
    def _png_as_base64(picture):
        # Render into memory so nothing is written to disk.
        sink = io.BytesIO()
        picture.save(sink, format="PNG")
        return base64.b64encode(sink.getvalue()).decode("utf-8")

    return [_png_as_base64(picture) for picture in images]


def get_together_client():
    """Build a Together client using the key in the TOGETHER_API_KEY variable.

    Returns:
        A ``Together`` instance constructed with the key from the environment.

    Raises:
        ValueError: If TOGETHER_API_KEY is unset or empty.
    """
    key = os.getenv("TOGETHER_API_KEY")
    if key:
        return Together(api_key=key)
    raise ValueError("TOGETHER_API_KEY not set!")


def extract_caption(line):
    """Return the portion of ``line`` that begins at the trigger word.

    Anything before the first occurrence of TRIGGER_WORD is discarded; a line
    that already starts with it comes back unchanged. When the trigger word
    does not occur at all, an empty string is returned.
    """
    start = line.find(TRIGGER_WORD)
    if start == -1:
        return ""
    return line[start:]


def caption_single_image(client, img_str):
    """Request a caption for one base64-encoded PNG and return it.

    Args:
        client: API client whose ``chat.completions.create`` is called
            (as returned by ``get_together_client``).
        img_str: Base64 text of a PNG image; embedded in a ``data:`` URL.

    Returns:
        The first reply line containing TRIGGER_WORD, as trimmed by
        ``extract_caption``.

    Raises:
        CaptioningError: If no reply line contains the trigger word. This
            includes a reply whose message carries no text content.
    """
    messages = [
        {"role": "system", "content": get_system_prompt()},
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_str}"}},
                {"type": "text", "text": "Caption this image."}
            ]
        }
    ]

    # Request caption for the image using Llama 4 Maverick
    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=messages
    )

    # A reply may carry no text (content is None); treat it as empty so the
    # failure surfaces as CaptioningError rather than AttributeError.
    full_response = (response.choices[0].message.content or "").strip()

    # Take the first line that yields a non-empty caption.
    caption = next(
        (found for found in map(extract_caption, full_response.splitlines()) if found),
        "",
    )

    if not caption:
        error_msg = f"Failed to extract a valid caption (containing '{TRIGGER_WORD}') from the response"
        error_msg += f"\n\nActual response:\n{full_response}"
        raise CaptioningError(error_msg)

    return caption


def caption_image_batch(client, image_strings, category):
    """Send all images in one chat request and return the parsed captions.

    Args:
        client: API client whose ``chat.completions.create`` is called.
        image_strings: Base64 PNG strings, one per image.
        category: Label interpolated into the introductory user text.

    Returns:
        Whatever ``process_batch_response`` returns for the reply.
    """
    intro = {
        "type": "text",
        "text": f"Here is the batch of images for {category}. "
                f"Caption each image on a separate line.",
    }

    # Each image is followed by a 1-based "Image N" text marker.
    parts = [intro]
    for number, encoded in enumerate(image_strings, start=1):
        parts.extend((
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded}"}},
            {"type": "text", "text": f"Image {number}"},
        ))

    conversation = [
        {"role": "system", "content": get_system_prompt()},
        {"role": "user", "content": parts},
    ]
    reply = client.chat.completions.create(model=MODEL_ID, messages=conversation)
    return process_batch_response(reply, image_strings)


def process_batch_response(response, image_strings):
    """Extract one caption per image from a batch reply.

    Reply lines containing TRIGGER_WORD are matched to images by position:
    the first such line goes to the first image, and so on. Surplus lines are
    ignored; images without a matching line get an empty caption, which
    ``validate_batch_captions`` then rejects.

    Args:
        response: Chat-completion reply; ``choices[0].message.content`` is read.
        image_strings: The images that were sent; only their count is used.

    Returns:
        A list of captions with the same length as ``image_strings``.

    Raises:
        CaptioningError: Propagated from ``validate_batch_captions`` when any
            caption is missing. This includes a reply with no text content.
    """
    # A reply may carry no text (content is None); treat it as empty so the
    # failure surfaces as CaptioningError rather than AttributeError.
    full_response = (response.choices[0].message.content or "").strip()
    image_count = len(image_strings)

    # Keep only the lines that can hold a caption.
    caption_lines = [line for line in full_response.splitlines() if TRIGGER_WORD in line]

    # Pair lines with images in order, then pad with "" for unmatched images.
    captions = [extract_caption(line) for line in caption_lines[:image_count]]
    captions.extend([""] * (image_count - len(captions)))

    validate_batch_captions(captions, image_count, full_response)
    return captions


def validate_batch_captions(captions, image_count, full_response):
    """Raise if the batch did not yield a valid caption for every image.

    A caption is valid when it is non-empty and contains TRIGGER_WORD.

    Args:
        captions: Captions extracted from the batch reply, in image order.
        image_count: Number of images that were submitted.
        full_response: Raw reply text, echoed when nothing could be parsed.

    Raises:
        CaptioningError: If no caption is valid, or if fewer than
            ``image_count`` captions are valid.
    """
    def _is_valid(text):
        return bool(text) and TRIGGER_WORD in text

    good_total = sum(1 for text in captions if _is_valid(text))

    # Nothing usable at all: show the raw reply to aid debugging.
    if good_total == 0:
        raise CaptioningError(
            "Failed to parse any valid captions from batch response."
            f"\n\nActual response:\n{full_response}"
        )

    if good_total >= image_count:
        return

    # Partial failure: list every offending caption with its 1-based position.
    report = [
        f"Failed to parse captions for {image_count - good_total} of {image_count} images in batch mode",
        "\n\nMalformed captions:",
    ]
    report.extend(
        f"\nImage {position + 1}: '{text}'"
        for position, text in enumerate(captions)
        if not _is_valid(text)
    )
    raise CaptioningError("".join(report))


def caption_images(images, category=None, batch_mode=False):
    """Caption every image, in one batch request or one request per image.

    Batch mode is used only when ``batch_mode`` is truthy and a ``category``
    is supplied; otherwise each image is captioned individually.

    Args:
        images: Images accepted by ``images_to_base64``.
        category: Label forwarded to the batch request.
        batch_mode: Request a single combined call when a category is given.

    Returns:
        A list of captions, one per image.
    """
    encoded = images_to_base64(images)
    api = get_together_client()

    use_batch = batch_mode and category
    if not use_batch:
        return [caption_single_image(api, item) for item in encoded]
    return caption_image_batch(api, encoded, category)


def extract_captions(file_path):
    """Collect the caption lines stored in a text file.

    Only lines that begin with TRIGGER_WORD are kept; the check is made on the
    raw line, so leading whitespace disqualifies it. Kept lines are stripped
    of surrounding whitespace.

    Args:
        file_path: Path of the text file to read.

    Returns:
        A list of caption strings in file order.
    """
    with open(file_path, 'r') as handle:
        return [row.strip() for row in handle if row.startswith(TRIGGER_WORD)]