Spaces:
Running
Running
File size: 2,500 Bytes
8373fd9 69f7712 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 fb6c5e9 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 ab00f6b 8373fd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import os
import argparse
from pathlib import Path
from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID
def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
    """Rewrite a user prompt so it follows the format of the training captions.

    The gathered captions and the user prompt are sent in a single chat
    request (model ``MODEL_ID``) through the client returned by
    ``get_together_client()``, using ``get_system_prompt()`` as the system
    message.

    Args:
        user_prompt (str): The simple user prompt to optimize.
        captions_dir (str, optional): Directory containing caption ``.txt``
            files. Only consulted when ``captions_list`` is empty or omitted.
        captions_list (list, optional): Captions to use instead of loading
            them from files. Takes precedence over ``captions_dir``.

    Returns:
        str: The content of the first choice in the model's response, with
        surrounding whitespace stripped.

    Raises:
        ValueError: If neither source yields any captions.
    """
    all_captions = []
    if captions_list:
        all_captions = captions_list
    elif captions_dir:
        # Path.glob yields entries in arbitrary order; sort so the same
        # directory always produces the same prompt text.
        for file_path in sorted(Path(captions_dir).glob("*.txt")):
            all_captions.extend(extract_captions(file_path))

    if not all_captions:
        raise ValueError("Please provide either caption files or a list of captions!")

    # One caption per line in the request body.
    captions_text = "\n".join(all_captions)

    client = get_together_client()
    messages = [
        {"role": "system", "content": get_system_prompt()},
        {
            "role": "user",
            "content": (
                f"These are all of the captions used to train the LoRA:\n\n"
                f"{captions_text}\n\n"
                f"Now optimize this prompt to follow the caption format used in training: "
                f"{user_prompt}"
            ),
        },
    ]

    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=messages,
    )

    return response.choices[0].message.content.strip()
def main():
    """Command-line entry point: optimize one prompt against a captions directory.

    Reads ``--prompt`` and ``--captions`` (both required) from the command
    line and prints the optimized prompt to stdout.

    On failure (the captions path is not a directory, or the optimization
    raises) the error message is written to stderr and the process exits
    with status 1, so shell scripts and CI can detect the failure.
    """
    # Local import: the module's import header does not bring in ``sys``.
    import sys

    parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
    parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
    parser.add_argument('--captions', type=str, required=True, help='Directory containing caption .txt files')
    args = parser.parse_args()

    if not os.path.isdir(args.captions):
        print(f"Error: Captions directory '{args.captions}' does not exist.", file=sys.stderr)
        sys.exit(1)

    # Top-level boundary: report any failure from the optimization step as a
    # one-line error instead of a traceback.
    try:
        optimized_prompt = optimize_prompt(args.prompt, args.captions)
    except Exception as e:
        print(f"Error optimizing prompt: {e}", file=sys.stderr)
        sys.exit(1)

    print("\nOptimized Prompt:")
    print(optimized_prompt)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|