Spaces:
Paused
Paused
import os | |
import re | |
from typing import List, Optional | |
import openai | |
from uno.utils.prompt_router import classify_prompt_intent | |
from uno.utils.image_describer import describe_uploaded_images | |
# === OpenAI Client ===
# One client object shared by every call in this module, constructed at
# import time. The key is read from the OPENAI_API_KEY environment variable
# (os.environ.get yields None when that variable is unset).
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# === Constants ===

# Opening phrase every generated prompt is expected to start with; it is
# quoted to the model in the system message and used to filter its output.
PROMPT_PREFIX = "It's very important that"

# Closing phrase quoted to the model as the required prompt suffix.
PROMPT_SUFFIX = (
    "and take all the time you needed as it is very important "
    "to achieve the best possible result"
)

# Slot order each generated prompt should follow ("prefix"/"suffix" stand in
# for PROMPT_PREFIX / PROMPT_SUFFIX).
TEMPLATE_FORMAT = (
    "prefix [art medium] [main object or objective] [attribute] "
    "[expression] [key light] [detailing] suffix"
)

DEFAULT_NUM_PROMPTS = 5   # prompts requested when the caller does not say
MIN_WORD_COUNT = 8        # word-count threshold for accepting a generated prompt
MAX_PROMPT_LENGTH = 120   # word limit stated to the model in the system message

# Elements the model is told every prompt must include.
REQUIRED_ELEMENTS = [
    "art medium", "main object", "attribute",
    "expression", "key light", "detailing",
]

# Per-intent focus line appended to the system message, keyed by intent label.
INTENT_SYSTEM_INSTRUCTIONS = dict(
    product_ad="Focus: visually advertise a product using professional commercial photography.",
    service_promotion="Focus: promote a service by illustrating its usage, setting, or emotional impact.",
    public_awareness="Focus: support a cause, campaign, or message with narrative visual storytelling.",
    brand_storytelling="Focus: express a brand's tone or identity using a visual lifestyle story.",
    creative_social_post=(
        "Focus: generate stylistic or creative content suitable for social "
        "media that maintains core subject clarity."
    ),
    fallback="Fallback: default to showcasing the main product in a visually compelling commercial format.",
)
def build_system_message(intent: Optional[str], num_prompts: int, style_hint: Optional[str] = "") -> str:
    """Build the system prompt that tells the chat model how to write image prompts.

    Args:
        intent: Intent label (presumably produced by ``classify_prompt_intent``).
            A label present in ``INTENT_SYSTEM_INSTRUCTIONS`` selects its canned
            focus line; an empty or ``None`` intent uses the predefined
            ``"fallback"`` entry; any other label gets a generic focus line that
            embeds the label.
        num_prompts: Number of prompts the model is asked to generate.
        style_hint: Optional style description (e.g. derived from reference
            images). When non-empty, the model is told to match it unless the
            user explicitly asks for a different style.

    Returns:
        The system message text with leading/trailing whitespace stripped.
    """
    style_clause = (
        f"The visual tone, lighting, and environment must match this style: {style_hint}.\n"
        "Only override this if the user explicitly requests a different visual style."
        if style_hint else ""
    )

    if not intent:
        # Without a label the dynamic branch below would ask for prompts for
        # "a '' scenario" (or "a 'None' scenario"); use the canned fallback.
        instruction = INTENT_SYSTEM_INSTRUCTIONS["fallback"]
    elif intent in INTENT_SYSTEM_INSTRUCTIONS:
        instruction = INTENT_SYSTEM_INSTRUCTIONS[intent]
    else:
        print(f"π§ [DEBUG] Unrecognized intent '{intent}', using dynamic fallback...")
        instruction = (
            f"Focus: Generate prompts suitable for a '{intent}' scenario using descriptive, high-quality visual storytelling. "
            "The core subject must remain clear and central."
        )

    required = ', '.join(REQUIRED_ELEMENTS)
    return f"""
You are a prompt enhancement assistant for Flux Pro.
Your task is to transform a short user input into {num_prompts} full-sentence, professional image generation prompts.
Each prompt must follow this structure:
{TEMPLATE_FORMAT}
Prefix: '{PROMPT_PREFIX}'
Suffix: '{PROMPT_SUFFIX}'
Each prompt must:
- Be under {MAX_PROMPT_LENGTH} words
- Include: {required}
- Be a single descriptive sentence
- Never use lists, examples, or bullet formatting
- Avoid specific color names unless inferred from uploaded images
- Do not wrap prompts in quotes or number them
All image elements must follow natural physical proportions. Ensure objects intended for interaction appear in realistic size and position relative to the subject and scene. Avoid exaggerated scaling unless the user explicitly asks for surrealism or stylization.
{style_clause}
{instruction}
Do not explain. Only return the prompts.
Generate exactly {num_prompts} unique prompts, one per line.
""".strip()
def _clean_candidate(line: str) -> str:
    """Strip list markers and wrapping quotes the model may add despite being told not to."""
    # "1. It's ...", "2) It's ...", "- It's ...", "* It's ...", "• It's ..." -> "It's ..."
    cleaned = re.sub(r"^(?:\d+[.)]|[-*\u2022])\s*", "", line.strip())
    # Remove one pair of wrapping quotes, but only when the line opens with a
    # quote, so apostrophes inside the text (e.g. "It's") are left alone.
    quote_pairs = {'"': '"', "'": "'", "\u201c": "\u201d"}
    closing = quote_pairs.get(cleaned[:1])
    if closing is not None:
        cleaned = cleaned[1:]
        if cleaned.endswith(closing):
            cleaned = cleaned[:-1]
    return cleaned.strip()


def _is_acceptable_prompt(candidate: str) -> bool:
    """Return True if *candidate* has at least MIN_WORD_COUNT words and opens with PROMPT_PREFIX."""
    # Compare case-insensitively and treat a typographic apostrophe (’) as "'"
    # so "It’s very important that ..." is not rejected.
    opening = candidate.replace("\u2019", "'").lower()
    return (
        len(candidate.split()) >= MIN_WORD_COUNT
        and opening.startswith(PROMPT_PREFIX.lower())
    )


def enhance_prompt_with_chatgpt(
    user_prompt: str,
    num_prompts: int = DEFAULT_NUM_PROMPTS,
    reference_images: Optional[List] = None
) -> List[str]:
    """Expand a short user prompt into ``num_prompts`` detailed image prompts via GPT-4.

    The prompt's intent is classified, any reference images are described, and
    the chat model is asked for template-shaped prompts (see
    ``build_system_message``). Output lines are cleaned of list markers and
    wrapping quotes, then kept only if they are long enough and start with
    ``PROMPT_PREFIX``.

    Args:
        user_prompt: The raw prompt entered by the user.
        num_prompts: How many prompts to return. A value <= 0 returns ``[]``
            without contacting the API.
        reference_images: Optional images passed to ``describe_uploaded_images``;
            only its ``"style_description"`` is forwarded to the model.

    Returns:
        Exactly ``num_prompts`` prompts (for ``num_prompts > 0``). A shortfall in
        usable model output is padded with ``user_prompt``; any exception during
        the API call or parsing yields ``[user_prompt] * num_prompts``.
    """
    if num_prompts <= 0:
        # Nothing requested. Previously a negative count reached the slice
        # `[:num_prompts]` and silently dropped trailing prompts instead.
        return []

    intent = classify_prompt_intent(user_prompt)
    blip_data = describe_uploaded_images(reference_images) if reference_images else {}
    full_caption = blip_data.get("full_caption", "")
    style_hint = blip_data.get("style_description", "")
    print(f"\nπ₯ [DEBUG] User prompt: {user_prompt}")
    if full_caption:
        print(f"πΌοΈ [DEBUG] BLIP Caption: {full_caption}")
    if style_hint:
        print(f"π¨ [DEBUG] Style Description from Image: {style_hint}")

    # NOTE(review): full_caption is only logged. It is never added to user_msg,
    # so the model sees the image style but not its caption — confirm intended.
    user_msg = f"Original prompt: {user_prompt}"
    if style_hint:
        user_msg += f"\nVisual reference style: {style_hint}"

    try:
        response = client.chat.completions.create(
            model="gpt-4",
            messages=[
                {"role": "system", "content": build_system_message(intent, num_prompts, style_hint)},
                {"role": "user", "content": user_msg}
            ],
            temperature=0.7,
            max_tokens=1800
        )
        # `content` can be None (e.g. a refusal); treat that as empty output.
        raw_output = (response.choices[0].message.content or "").strip()
        print("\nπ [DEBUG] Raw GPT Output:")
        print(raw_output)

        candidates = [_clean_candidate(line) for line in raw_output.splitlines()]
        enhanced_prompts = [p for p in candidates if p and _is_acceptable_prompt(p)]

        shortfall = num_prompts - len(enhanced_prompts)
        if shortfall > 0:
            print(f"β οΈ Only {len(enhanced_prompts)} prompts returned. Padding with user prompt...")
            enhanced_prompts += [user_prompt] * shortfall
        else:
            enhanced_prompts = enhanced_prompts[:num_prompts]

        print("\nπ§ [DEBUG] ChatGPT Enhanced Prompts:")
        for idx, p in enumerate(enhanced_prompts, start=1):
            print(f"[{idx}] {p}")
        print("--------------------------------------------------\n")
        return enhanced_prompts
    except Exception as e:
        # Deliberate best-effort boundary: any API/parsing failure degrades to
        # reusing the user's own prompt rather than propagating.
        print(f"β [ERROR] Failed to enhance prompt: {e}")
        return [user_prompt] * num_prompts