File size: 2,500 Bytes
8373fd9
 
 
69f7712
8373fd9
 
ab00f6b
8373fd9
 
ab00f6b
8373fd9
 
 
 
 
 
 
 
 
 
 
 
ab00f6b
8373fd9
ab00f6b
 
8373fd9
 
ab00f6b
8373fd9
 
ab00f6b
8373fd9
ab00f6b
8373fd9
 
 
 
 
 
 
 
ab00f6b
8373fd9
fb6c5e9
8373fd9
 
ab00f6b
8373fd9
 
 
ab00f6b
8373fd9
 
 
ab00f6b
 
8373fd9
 
 
 
ab00f6b
8373fd9
 
 
 
ab00f6b
8373fd9
 
 
ab00f6b
8373fd9
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import os
import sys
from pathlib import Path

from caption import get_system_prompt, get_together_client, extract_captions, MODEL_ID

def optimize_prompt(user_prompt, captions_dir=None, captions_list=None):
    """Rewrite a user prompt so it follows the format of the training captions.

    The captions and ``user_prompt`` are sent in a single chat request to the
    model identified by ``MODEL_ID``; the text of the first choice is returned.

    Args:
        user_prompt (str): The simple user prompt to optimize.
        captions_dir (str, optional): Directory containing caption ``.txt``
            files. Only consulted when ``captions_list`` is empty or omitted.
        captions_list (list, optional): Captions to use instead of loading
            them from files.

    Returns:
        str: The model's reply with surrounding whitespace stripped.

    Raises:
        ValueError: If neither source yields at least one non-blank caption.
        RuntimeError: If the model's reply carries no text content.
    """
    if captions_list:
        candidates = captions_list
    elif captions_dir:
        candidates = []
        # Sort the files so the caption order (and therefore the request sent
        # to the model) does not depend on filesystem enumeration order.
        for file_path in sorted(Path(captions_dir).glob("*.txt")):
            candidates.extend(extract_captions(file_path))
    else:
        candidates = []

    # Blank captions carry no format information; dropping them also makes an
    # all-blank input fail here instead of costing an API call.
    all_captions = [
        caption for caption in candidates if caption and caption.strip()
    ]
    if not all_captions:
        raise ValueError("Please provide either caption files or a list of captions!")

    # Concatenate all captions with newlines
    captions_text = "\n".join(all_captions)

    client = get_together_client()
    messages = [
        {"role": "system", "content": get_system_prompt()},
        {
            "role": "user",
            "content": (
                f"These are all of the captions used to train the LoRA:\n\n"
                f"{captions_text}\n\n"
                f"Now optimize this prompt to follow the caption format used in training: "
                f"{user_prompt}"
            ),
        },
    ]

    response = client.chat.completions.create(
        model=MODEL_ID,
        messages=messages,
    )

    content = response.choices[0].message.content
    if content is None:
        # Fail with a clear message rather than an AttributeError on .strip().
        raise RuntimeError("The model returned a response without any text content.")
    return content.strip()


def main(argv=None):
    """Command-line entry point: optimize one prompt against a captions directory.

    Args:
        argv (list[str], optional): Argument strings to parse. When ``None``
            (the default), argparse reads them from ``sys.argv[1:]``.

    Returns:
        int: ``0`` when the optimized prompt was printed, ``1`` when the
        captions directory is missing or the optimization failed.
    """
    parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
    parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
    parser.add_argument('--captions', type=str, required=True, help='Directory containing caption .txt files')

    args = parser.parse_args(argv)
    if not os.path.isdir(args.captions):
        # Diagnostics go to stderr so stdout carries only the optimized prompt.
        print(f"Error: Captions directory '{args.captions}' does not exist.", file=sys.stderr)
        return 1

    try:
        optimized_prompt = optimize_prompt(args.prompt, args.captions)
    except Exception as e:
        # Top-level boundary: report any failure and signal it via the exit status.
        print(f"Error optimizing prompt: {e}", file=sys.stderr)
        return 1

    print("\nOptimized Prompt:")
    print(optimized_prompt)
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and scripts can detect failures.
    sys.exit(main())