Spaces:
Running
Running
import os | |
import argparse | |
from pathlib import Path | |
from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID | |
def optimize_prompt(user_prompt, captions_dir=None, captions_list=None): | |
"""Optimize a user prompt to follow the same format as training captions. | |
Args: | |
user_prompt (str): The simple user prompt to optimize | |
captions_dir (str, optional): Directory containing caption .txt files | |
captions_list (list, optional): List of captions to use instead of loading from files | |
""" | |
all_captions = [] | |
if captions_list: | |
all_captions = captions_list | |
elif captions_dir: | |
# Collect all captions from text files in the directory | |
captions_path = Path(captions_dir) | |
for file_path in captions_path.glob("*.txt"): | |
captions = extract_captions(file_path) | |
all_captions.extend(captions) | |
if not all_captions: | |
raise ValueError("Please provide either caption files or a list of captions!") | |
# Concatenate all captions with newlines | |
captions_text = "\n".join(all_captions) | |
client = get_together_client() | |
messages = [ | |
{"role": "system", "content": get_system_prompt()}, | |
{ | |
"role": "user", | |
"content": ( | |
f"These are all of the captions used to train the LoRA:\n\n" | |
f"{captions_text}\n\n" | |
f"Now optimize this prompt to follow the caption format used in training: " | |
f"{user_prompt}" | |
) | |
} | |
] | |
response = client.chat.completions.create( | |
model=MODEL_ID, | |
messages=messages | |
) | |
optimized_prompt = response.choices[0].message.content.strip() | |
return optimized_prompt | |
def main(): | |
parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.') | |
parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize') | |
parser.add_argument('--captions', type=str, required=True,help='Directory containing caption .txt files') | |
args = parser.parse_args() | |
if not os.path.isdir(args.captions): | |
print(f"Error: Captions directory '{args.captions}' does not exist.") | |
return | |
try: | |
optimized_prompt = optimize_prompt(args.prompt, args.captions) | |
print("\nOptimized Prompt:") | |
print(optimized_prompt) | |
except Exception as e: | |
print(f"Error optimizing prompt: {e}") | |
if __name__ == "__main__": | |
main() | |