Rishi Desai committed on
Commit
8373fd9
·
1 Parent(s): dd4f3b1

adding prompt opt

Browse files
Files changed (1) hide show
  1. prompt.py +82 -0
prompt.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import argparse
3
+ from pathlib import Path
4
+ from caption import get_prompt, get_together_client, extract_captions
5
+
6
def optimize_prompt(
    user_prompt,
    captions_dir=None,
    captions_list=None,
    model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
):
    """
    Optimize a user prompt to follow the same format as training captions.

    The training captions are joined with newlines and sent, together with
    ``user_prompt``, as the user message of a chat completion; the system
    message is whatever ``get_prompt()`` returns.

    Args:
        user_prompt (str): The simple user prompt to optimize (e.g., "woman riding a bike")
        captions_dir (str, optional): Directory containing caption .txt files.
            Only consulted when ``captions_list`` is empty or not given.
        captions_list (list, optional): List of captions to use instead of loading from files
        model (str, optional): Model identifier passed to the chat completion
            call. Defaults to the model this function has always used.

    Returns:
        str: The optimized prompt following the training format

    Raises:
        ValueError: If no captions could be gathered from either source.
        RuntimeError: If the completion response carries no text content.
    """
    # Get captions either from the provided list or from the directory;
    # a non-empty list takes precedence over the directory.
    if captions_list:
        all_captions = list(captions_list)
    elif captions_dir:
        all_captions = []
        # Sort the files: glob order follows the filesystem and is not
        # guaranteed, which would make the request differ between runs.
        for file_path in sorted(Path(captions_dir).glob("*.txt")):
            all_captions.extend(extract_captions(file_path))
    else:
        all_captions = []

    if not all_captions:
        raise ValueError("No captions found. Please provide either caption files or a list of captions.")

    # Concatenate all captions with newlines
    captions_text = "\n".join(all_captions)

    client = get_together_client()

    messages = [
        {"role": "system", "content": get_prompt()},
        {
            "role": "user",
            "content": (
                "These are all of the captions used to train the LoRA:\n\n"
                f"{captions_text}\n\n"
                "Now optimize this prompt to follow the caption format used in training: "
                f"{user_prompt}"
            ),
        },
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
    )

    # Fail with a clear message instead of an IndexError/AttributeError when
    # the response has no choices or the message content is None.
    content = response.choices[0].message.content if response.choices else None
    if content is None:
        raise RuntimeError("The model returned no content for the optimized prompt.")

    return content.strip()
58
+
59
def main(argv=None):
    """
    Command-line entry point: optimize a prompt against a directory of captions.

    Args:
        argv (list, optional): Argument strings to parse. When None, argparse
            reads the process command line, which is the original behavior.

    On success the optimized prompt is printed to stdout. Every failure
    terminates with a non-zero exit status and a message on stderr:
    status 2 for bad arguments (missing option, captions directory not
    found) and status 1 when the optimization itself fails.
    """
    parser = argparse.ArgumentParser(description='Optimize prompts based on existing captions.')
    parser.add_argument('--prompt', type=str, required=True, help='User prompt to optimize')
    # Let argparse enforce the option so a missing value is reported like any
    # other usage error instead of a plain print followed by exit status 0.
    parser.add_argument('--captions', type=str, required=True, help='Directory containing caption .txt files')

    args = parser.parse_args(argv)

    if not os.path.isdir(args.captions):
        # parser.error writes usage plus the message to stderr and exits with 2.
        parser.error(f"Captions directory '{args.captions}' does not exist.")

    # Keep the try body to the single call that can fail; this is the
    # top-level boundary, so any exception becomes a clean error exit.
    try:
        optimized_prompt = optimize_prompt(args.prompt, args.captions)
    except Exception as e:
        parser.exit(1, f"Error optimizing prompt: {e}\n")

    print("\nOptimized Prompt:")
    print(optimized_prompt)
80
+
81
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()