Jen Ben Arye committed on
Commit 36b0fc6 · 1 Parent(s): 44089d9

added history context to prompt

Files changed (1):
  1. ml/kto_dataset_processor.py +137 -102
ml/kto_dataset_processor.py CHANGED
@@ -3,57 +3,95 @@ import pandas as pd
 from sklearn.model_selection import train_test_split
 import json
 from ipdb import set_trace as st
-
-
-
-def process_feel_dataset():
+from transformers import AutoTokenizer
+
+def count_tokens(text: str, model_name: str) -> int:
+    """Count tokens in text using the model's tokenizer."""
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    return len(tokenizer.encode(text))
+
+def format_conversation(messages: list, model_name: str) -> str:
+    """Format messages using the model's chat template."""
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    return tokenizer.apply_chat_template(messages, tokenize=False)
+
+def transform_conversation(
+    entry: dict,
+    model_name: str,
+    max_history_turns: int = 10,
+    max_history_tokens: int = 4000
+) -> list:
+    """Transform a conversation into KTO format, with history context in the prompt."""
+    data_points = []
+    conversation = entry["conversation"]
+
+    for i, message in enumerate(conversation):
+        # Only process assistant messages with ratings
+        if message["role"] != "assistant" or message["rating"] not in [1, -1]:
+            continue
+
+        # Get previous messages up to the turn and token limits
+        history = []
+        tokens = 0
+        turns = 0
+
+        # Walk backwards through the conversation, newest message first
+        for prev in reversed(conversation[:i]):
+            if turns >= max_history_turns:
+                break
+
+            history.insert(0, prev)
+            formatted = format_conversation(history, model_name)
+            tokens = count_tokens(formatted, model_name)
+
+            if tokens > max_history_tokens:
+                history.pop(0)  # drop the message that pushed the history over the budget
+                break
+
+            turns += 1
+
+        # Format the prompt from the accumulated history
+        prompt = format_conversation(history, model_name) if history else ""
+
+        data_points.append({
+            "prompt": prompt.strip(),
+            "completion": message["content"].strip(),
+            "label": message["rating"] == 1,
+            "timestamp": entry["timestamp"],
+            "session_id": entry["session_id"],
+            "conversation_id": entry["conversation_id"]
+        })
+
+    return data_points
+
+def process_feel_dataset(
+    model_name: str = "HuggingFaceH4/zephyr-7b-beta",
+    max_history_turns: int = 10,
+    max_history_tokens: int = 4000
+):
     """
     Processes the feel dataset into a format suitable for KTO training using TRL.
 
     Args:
-        data (list): A list of dictionaries containing conversation data.
+        model_name: Name of the model whose chat template is used for formatting
+        max_history_turns: Maximum number of previous turns to include in the history
+        max_history_tokens: Maximum number of tokens allowed in the history
 
     Returns:
-        dict: A dictionary containing the 'train' and 'test' splits of the dataset in KTO format, as Hugging Face Dataset objects.
+        dict: A dictionary containing the 'train' and 'test' splits of the dataset in KTO format
     """
-
-    # Load feel dataset
-    # Load the JSON file
-    file_path = "../data/example_data.json"
-    with open(file_path, "r") as file:
-        feel_dataset = json.load(file)
-
-
+    # Load feel dataset from HuggingFace
+    feel_dataset = load_dataset("feel-fl/feel-feedback")["train"]
     kto_data = []
 
-    # Function to transform a single conversation into KTO format
-    def transform_conversation(entry):
-        conversation = entry["conversation"]
-        data_points = []
-        user_timestamp = None
-
-        for i in range(len(conversation)):
-            message = conversation[i]
-            if message["role"] == "user":
-                user_timestamp = entry["timestamp"]
-            if (
-                message["role"] == "assistant" and
-                message["rating"] in [1, -1]  # Only process feedback with positive or negative ratings
-            ):
-                user_content = conversation[i - 1]["content"] if i > 0 and conversation[i - 1]["role"] == "user" else ""
-                data_points.append({
-                    "prompt": user_content.strip(),
-                    "completion": message["content"].strip(),
-                    "label": message["rating"] == 1,  # True for positive feedback, False for negative (KTO Trainer format)
-                    "timestamp": user_timestamp,
-                    "session_id": entry["session_id"],
-                    "conversation_id": entry["conversation_id"]
-                })
-        return data_points
-
     # Process all conversations in the dataset
     for entry in feel_dataset:
-        kto_data.extend(transform_conversation(entry))
+        kto_data.extend(transform_conversation(
+            entry,
+            model_name,
+            max_history_turns,
+            max_history_tokens
+        ))
 
     # Convert to DataFrame
     kto_df = pd.DataFrame(kto_data)
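
The history window in the hunk above walks backwards from each rated assistant message, re-formats the candidate history with the model's chat template, and stops once max_history_turns or max_history_tokens would be exceeded. Below is a minimal, self-contained sketch of that loop; the toy_format and toy_count_tokens helpers are hypothetical stand-ins for apply_chat_template and the tokenizer-based count, so the example runs offline.

# Offline sketch of the history-truncation loop in transform_conversation.
# toy_format / toy_count_tokens stand in for the chat template and tokenizer.

def toy_format(messages: list) -> str:
    return "\n".join(f"{m['role']}: {m['content']}" for m in messages)

def toy_count_tokens(text: str) -> int:
    return len(text.split())  # crude whitespace "tokens"

def build_history(conversation: list, i: int,
                  max_history_turns: int = 10,
                  max_history_tokens: int = 8) -> list:
    history = []
    turns = 0
    # Walk backwards from the message just before index i, newest first
    for prev in reversed(conversation[:i]):
        if turns >= max_history_turns:
            break
        history.insert(0, prev)
        if toy_count_tokens(toy_format(history)) > max_history_tokens:
            history.pop(0)  # this message pushed the history over the budget
            break
        turns += 1
    return history

conversation = [
    {"role": "user", "content": "Hi there"},
    {"role": "assistant", "content": "Hello! How can I help?"},
    {"role": "user", "content": "Explain KTO briefly."},
]

# With a budget of 8 toy tokens, only the newest message survives:
print(toy_format(build_history(conversation, 3)))

Note that truncation drops the oldest candidate message first, so the prompt always ends with the turns closest to the rated reply.
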
@@ -71,66 +109,63 @@ def process_feel_dataset():
 
     return {"train": train_dataset, "test": test_dataset}
 
-
-
-
-# def process_dataset_ultrafeedback():
-#     """
-#     Processes the 'train_prefs' and 'test_prefs' splits of the 'HuggingFaceH4/ultrafeedback_binarized' dataset
-#     into a unified format for preference modeling.
-
-#     Returns:
-#         dict: A dictionary containing the unified 'train' and 'test' splits of the dataset in the KTO format.
-#               Each split is a Hugging Face Dataset object.
-#     """
-#     # Load the relevant splits of the dataset
-#     dataset_name = "HuggingFaceH4/ultrafeedback_binarized"
-#     train_prefs = load_dataset(dataset_name, split="train_prefs")
-#     test_prefs = load_dataset(dataset_name, split="test_prefs")
-
-#     # Function to transform a single example into the desired schema
-#     def transform_data(example):
-#         data_points = []
-#         # Chosen completion
-#         chosen_completion = example["chosen"][1]["content"]
-#         if chosen_completion.strip():  # Check for non-empty completions
-#             data_points.append({
-#                 "prompt": example["prompt"],
-#                 "completion": chosen_completion.strip(),
-#                 "label": True
-#             })
-#         # Rejected completion
-#         rejected_completion = example["rejected"][1]["content"]
-#         if rejected_completion.strip():  # Check for non-empty completions
-#             data_points.append({
-#                 "prompt": example["prompt"],
-#                 "completion": rejected_completion.strip(),
-#                 "label": False
-#             })
-#         return data_points
-
-#     # Process train and test splits
-#     train_data = []
-#     test_data = []
-
-#     for example in train_prefs:
-#         train_data.extend(transform_data(example))
-
-#     for example in test_prefs:
-#         test_data.extend(transform_data(example))
-
-#     # Convert unified data to DataFrames
-#     train_df = pd.DataFrame(train_data)
-#     test_df = pd.DataFrame(test_data)
-
-
-#     # Convert to Hugging Face Dataset
-#     unified_train = Dataset.from_pandas(train_df)
-#     unified_test = Dataset.from_pandas(test_df)
-
-#     return {"train": unified_train, "test": unified_test}
-
-
 if __name__ == "__main__":
-    kto_dataset = process_feel_dataset()
-    st()
+    # Process the dataset
+    datasets = process_feel_dataset()
+
+    # Print basic statistics
+    print("\nDataset Statistics:")
+    print(f"Train set size: {len(datasets['train'])}")
+    print(f"Test set size: {len(datasets['test'])}")
+
+    # Print distribution of positive/negative labels
+    train_labels = datasets['train']['label']
+    test_labels = datasets['test']['label']
+
+    print("\nLabel Distribution:")
+    print("Train set:")
+    print(f"Positive feedback: {sum(train_labels)}")
+    print(f"Negative feedback: {len(train_labels) - sum(train_labels)}")
+    print(f"Positive ratio: {sum(train_labels)/len(train_labels):.2%}")
+
+    print("\nTest set:")
+    print(f"Positive feedback: {sum(test_labels)}")
+    print(f"Negative feedback: {len(test_labels) - sum(test_labels)}")
+    print(f"Positive ratio: {sum(test_labels)/len(test_labels):.2%}")
+
+    # Load the original FEEL dataset
+    feel_dataset = load_dataset("feel-fl/feel-feedback", split="train")
+
+    # Print one original conversation
+    print("\nOriginal conversation from FEEL dataset:")
+    print(json.dumps(feel_dataset[0], indent=2))
+
+    # Print sample entries from the processed dataset
+    print("\nSample entries from processed KTO dataset:")
+    print("\n" + "="*80 + "\nTRAIN SET SAMPLES\n" + "="*80)
+
+    # for i, example in enumerate(datasets['train'].select(range(min(3, len(datasets['train']))))):
+    #     print(f"\nEntry #{i+1}:")
+    #     print("-" * 40)
+    #     for field, value in example.items():
+    #         print(f"\n{field}:")
+    #         if isinstance(value, str):
+    #             # Print strings with line breaks for better readability
+    #             print(f"{value}")
+    #         else:
+    #             print(f"{value}")
+    #     print("\n" + "-"*80)
+
+    # print("\n" + "="*80 + "\nTEST SET SAMPLES\n" + "="*80)
+
+    # for i, example in enumerate(datasets['test'].select(range(min(3, len(datasets['test']))))):
+    #     print(f"\nEntry #{i+1}:")
+    #     print("-" * 40)
+    #     for field, value in example.items():
+    #         print(f"\n{field}:")
+    #         if isinstance(value, str):
+    #             # Print strings with line breaks for better readability
+    #             print(f"{value}")
+    #         else:
+    #             print(f"{value}")
+    #     print("\n" + "-"*80)