LoRACaptioner / prompt.py
Rishi Desai
first patch at fixing manual entry
69f7712
import os
import argparse
from pathlib import Path
from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID
def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
    """Optimize a user prompt to follow the same format as training captions.

    The captions used to train the LoRA are sent to the chat model together
    with ``user_prompt`` so the model can rewrite the prompt in that style.

    Args:
        user_prompt (str): The simple user prompt to optimize
        captions_dir (str, optional): Directory containing caption .txt files
        captions_list (list, optional): List (or other iterable) of captions
            to use instead of loading from files. Used in preference to
            ``captions_dir`` whenever it is truthy.

    Returns:
        str: The optimized prompt returned by the model, with surrounding
        whitespace stripped.

    Raises:
        ValueError: If no captions could be collected from either source.
        RuntimeError: If the model response carries no message content.
    """
    all_captions = []
    if captions_list:
        # Copy so tuples/generators are accepted and the caller's list is
        # never aliased.
        all_captions = list(captions_list)
    elif captions_dir:
        # Sort so the prompt sent to the model (and hence the result) does not
        # depend on filesystem iteration order; skip directories that happen
        # to match the pattern.
        for file_path in sorted(Path(captions_dir).glob("*.txt")):
            if file_path.is_file():
                all_captions.extend(extract_captions(file_path))

    if not all_captions:
        raise ValueError("Please provide either caption files or a list of captions!")

    # Concatenate all captions with newlines
    captions_text = "\n".join(all_captions)

    client = get_together_client()
    messages = [
        {"role": "system", "content": get_system_prompt()},
        {
            "role": "user",
            "content": (
                "These are all of the captions used to train the LoRA:\n\n"
                f"{captions_text}\n\n"
                "Now optimize this prompt to follow the caption format used in training: "
                f"{user_prompt}"
            ),
        },
    ]

    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=messages,
    )
    content = response.choices[0].message.content
    # The chat-completions message schema allows content to be None (e.g. an
    # empty or filtered reply); fail with a clear error instead of an
    # AttributeError on .strip().
    if content is None:
        raise RuntimeError("Model returned no content for the optimized prompt.")
    return content.strip()
def main(argv=None):
    """Command-line entry point: optimize ``--prompt`` using captions in ``--captions``.

    Args:
        argv (list[str], optional): Arguments to parse instead of
            ``sys.argv[1:]``; ``None`` (the default) keeps the normal CLI
            behavior.

    On success the optimized prompt is printed to stdout. If the captions
    directory is missing, or optimization fails, a message is written to
    stderr and the process exits with a non-zero status so shell callers can
    detect the failure.
    """
    parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
    parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
    parser.add_argument('--captions', type=str, required=True, help='Directory containing caption .txt files')
    args = parser.parse_args(argv)

    if not os.path.isdir(args.captions):
        # parser.error prints usage and the message to stderr, then exits with status 2.
        parser.error(f"Captions directory '{args.captions}' does not exist.")

    try:
        optimized_prompt = optimize_prompt(args.prompt, args.captions)
    except Exception as e:  # top-level CLI boundary: report any failure, exit non-zero
        parser.exit(1, f"Error optimizing prompt: {e}\n")

    print("\nOptimized Prompt:")
    print(optimized_prompt)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()