File size: 5,800 Bytes
a4075b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
from typing import List, Optional
import openai

from uno.utils.prompt_router import classify_prompt_intent
from uno.utils.image_describer import describe_uploaded_images

# === OpenAI Client ===
# Module-level client, constructed once when this module is imported and used
# by enhance_prompt_with_chatgpt for its chat-completion request.
# The API key is taken from the OPENAI_API_KEY environment variable;
# os.getenv yields None when that variable is not set.
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# === Constants ===

# Fixed opening and closing phrases every generated prompt is asked to carry.
PROMPT_PREFIX = "It's very important that"
PROMPT_SUFFIX = (
    "and take all the time you needed as it is very important "
    "to achieve the best possible result"
)

# Slot layout shown to the model as the required sentence structure.
TEMPLATE_FORMAT = (
    "prefix [art medium] [main object or objective] [attribute] "
    "[expression] [key light] [detailing] suffix"
)

# How many prompts are requested when the caller does not say.
DEFAULT_NUM_PROMPTS = 5

# A returned line must contain MORE than this many words to be accepted.
MIN_WORD_COUNT = 8

# Upper bound quoted to the model in the system message (expressed there in words).
MAX_PROMPT_LENGTH = 120

# Elements the system message tells the model to include in each prompt.
REQUIRED_ELEMENTS = [
    "art medium",
    "main object",
    "attribute",
    "expression",
    "key light",
    "detailing",
]

# Per-intent focus sentence inserted into the system message.
INTENT_SYSTEM_INSTRUCTIONS = {
    "product_ad": (
        "Focus: visually advertise a product using professional "
        "commercial photography."
    ),
    "service_promotion": (
        "Focus: promote a service by illustrating its usage, setting, "
        "or emotional impact."
    ),
    "public_awareness": (
        "Focus: support a cause, campaign, or message with narrative "
        "visual storytelling."
    ),
    "brand_storytelling": (
        "Focus: express a brand's tone or identity using a visual "
        "lifestyle story."
    ),
    "creative_social_post": (
        "Focus: generate stylistic or creative content suitable for "
        "social media that maintains core subject clarity."
    ),
    "fallback": (
        "Fallback: default to showcasing the main product in a visually "
        "compelling commercial format."
    ),
}

def build_system_message(intent: str, num_prompts: int, style_hint: Optional[str] = "") -> str:
    """Build the system message sent to the chat model.

    Args:
        intent: Intent label for the user's request. Labels present in
            INTENT_SYSTEM_INSTRUCTIONS use their predefined focus sentence;
            an empty or None label uses the "fallback" entry; any other
            label gets a focus sentence generated from the label itself.
        num_prompts: Number of prompts the model is asked to produce.
        style_hint: Optional visual-style description. When truthy, a clause
            asking the model to match that style is included.

    Returns:
        The assembled system message with surrounding whitespace stripped.
    """
    style_clause = (
        f"The visual tone, lighting, and environment must match this style: {style_hint}.\n"
        "Only override this if the user explicitly requests a different visual style."
        if style_hint else ""
    )

    if not intent:
        # No usable label at all: use the predefined default rather than
        # telling the model to target a '' / 'None' scenario.
        instruction = INTENT_SYSTEM_INSTRUCTIONS["fallback"]
    elif intent in INTENT_SYSTEM_INSTRUCTIONS:
        instruction = INTENT_SYSTEM_INSTRUCTIONS[intent]
    else:
        # Unknown but non-empty label: derive a focus sentence from it.
        print(f"🧠 [DEBUG] Unrecognized intent '{intent}', using dynamic fallback...")
        instruction = (
            f"Focus: Generate prompts suitable for a '{intent}' scenario using descriptive, high-quality visual storytelling. "
            "The core subject must remain clear and central."
        )

    return f"""
You are a prompt enhancement assistant for Flux Pro.

Your task is to transform a short user input into {num_prompts} full-sentence, professional image generation prompts.

Each prompt must follow this structure:
{TEMPLATE_FORMAT}

Prefix: '{PROMPT_PREFIX}'
Suffix: '{PROMPT_SUFFIX}'

Each prompt must:
- Be under {MAX_PROMPT_LENGTH} words
- Include: {", ".join(REQUIRED_ELEMENTS)}
- Be a single descriptive sentence
- Never use lists, examples, or bullet formatting
- Avoid specific color names unless inferred from uploaded images
- Do not wrap prompts in quotes or number them

All image elements must follow natural physical proportions. Ensure objects intended for interaction appear in realistic size and position relative to the subject and scene. Avoid exaggerated scaling unless the user explicitly asks for surrealism or stylization.

{style_clause}

{instruction}

Do not explain. Only return the prompts.
Generate exactly {num_prompts} unique prompts, one per line.
""".strip()


# Leading list markers the model may emit despite being told not to:
# "1.", "2)", "3:", "4 -", "[5]", "-", "*", "•".
_LIST_MARKER_RE = re.compile(r"^\s*(?:\[\d+\]|\d+\s*[.):\-]|[-*\u2022])\s*")

# Characters treated as wrapping quotes around a whole prompt line.
_WRAPPING_QUOTES = "\"'\u201c\u201d"


def _normalize_candidate(line: str) -> str:
    """Remove a leading list marker and a pair of wrapping quotes from one output line."""
    cleaned = _LIST_MARKER_RE.sub("", line.strip(), count=1).strip()
    if (
        len(cleaned) >= 2
        and cleaned[0] in _WRAPPING_QUOTES
        and cleaned[-1] in _WRAPPING_QUOTES
    ):
        cleaned = cleaned[1:-1].strip()
    return cleaned


def _extract_enhanced_prompts(raw_output: str) -> List[str]:
    """Return the lines of raw_output that qualify as enhanced prompts, in order.

    A line qualifies when, after normalisation, it has more than
    MIN_WORD_COUNT words and starts with PROMPT_PREFIX (case-insensitive,
    typographic apostrophes treated like straight ones).
    """
    wanted_prefix = PROMPT_PREFIX.lower()
    accepted = []
    for line in raw_output.split("\n"):
        candidate = _normalize_candidate(line)
        if not candidate:
            continue
        # Compare on a copy so the returned prompt text itself is untouched.
        comparable = candidate.replace("\u2019", "'").lower()
        if len(candidate.split()) > MIN_WORD_COUNT and comparable.startswith(wanted_prefix):
            accepted.append(candidate)
    return accepted


def enhance_prompt_with_chatgpt(
    user_prompt: str,
    num_prompts: int = DEFAULT_NUM_PROMPTS,
    reference_images: Optional[List] = None
) -> List[str]:
    """Expand a short user prompt into image-generation prompts via the chat model.

    The prompt's intent is classified, any reference images are described to
    obtain a style hint, and the model is asked for num_prompts prompts. The
    reply is split into lines and filtered; missing entries are padded with
    the original user prompt and surplus entries are dropped.

    Args:
        user_prompt: The user's original prompt text.
        num_prompts: Number of prompts to return. Non-positive values yield
            an empty list without calling the API.
        reference_images: Optional images passed to describe_uploaded_images;
            skipped when falsy.

    Returns:
        A list of exactly num_prompts strings (empty for non-positive
        num_prompts). If the API request fails, every entry is user_prompt.
    """
    if num_prompts <= 0:
        # Nothing to generate; also avoids a negative slice silently
        # trimming items from the end of the result.
        return []

    intent = classify_prompt_intent(user_prompt)

    # `or {}` also covers the describer handing back None or another falsy value.
    blip_data = (describe_uploaded_images(reference_images) if reference_images else None) or {}
    full_caption = blip_data.get("full_caption", "")
    style_hint = blip_data.get("style_description", "")

    print(f"\n📥 [DEBUG] User prompt: {user_prompt}")
    if full_caption:
        print(f"🖼️ [DEBUG] BLIP Caption: {full_caption}")
    if style_hint:
        print(f"🎨 [DEBUG] Style Description from Image: {style_hint}")

    user_msg = f"Original prompt: {user_prompt}"
    if style_hint:
        user_msg += f"\nVisual reference style: {style_hint}"

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": build_system_message(intent, num_prompts, style_hint)},
                {"role": "user", "content": user_msg}
            ],
            temperature=0.7,
            max_tokens=1800
        )
        # content can be None; treat that as empty output.
        raw_output = (response.choices[0].message.content or "").strip()
    except Exception as e:
        # Deliberate best-effort boundary: any failure of the request falls
        # back to the unmodified user prompt so callers still get a result.
        print(f"❌ [ERROR] Failed to enhance prompt: {e}")
        return [user_prompt] * num_prompts

    print("\n📝 [DEBUG] Raw GPT Output:")
    print(raw_output)

    enhanced_prompts = _extract_enhanced_prompts(raw_output)

    if len(enhanced_prompts) < num_prompts:
        print(f"⚠️ Only {len(enhanced_prompts)} prompts returned. Padding with user prompt...")
        enhanced_prompts += [user_prompt] * (num_prompts - len(enhanced_prompts))
    else:
        enhanced_prompts = enhanced_prompts[:num_prompts]

    print("\n🔧 [DEBUG] ChatGPT Enhanced Prompts:")
    for idx, p in enumerate(enhanced_prompts, start=1):
        print(f"[{idx}] {p}")
    print("--------------------------------------------------\n")

    return enhanced_prompts