Jen Ben Arye committed on
Commit
09e9f82
·
1 Parent(s): 36b0fc6

debugged kto dataset processor

Browse files
Files changed (1) hide show
  1. ml/kto_dataset_processor.py +53 -58
ml/kto_dataset_processor.py CHANGED
@@ -3,18 +3,8 @@ import pandas as pd
3
  from sklearn.model_selection import train_test_split
4
  import json
5
  from ipdb import set_trace as st
6
- import tiktoken
7
  from transformers import AutoTokenizer
8
 
9
- def count_tokens(text: str, model_name: str) -> int:
10
- """Count tokens in text using model's tokenizer"""
11
- tokenizer = AutoTokenizer.from_pretrained(model_name)
12
- return len(tokenizer.encode(text))
13
-
14
- def format_conversation(messages: list, model_name: str) -> str:
15
- """Format messages using model's chat template"""
16
- tokenizer = AutoTokenizer.from_pretrained(model_name)
17
- return tokenizer.apply_chat_template(messages, tokenize=False)
18
 
19
  def transform_conversation(
20
  entry: dict,
@@ -25,37 +15,59 @@ def transform_conversation(
25
  """Transform conversation into KTO format with history"""
26
  data_points = []
27
  conversation = entry["conversation"]
 
28
 
29
  for i, message in enumerate(conversation):
30
- # Only process assistant messages with ratings
31
  if message["role"] != "assistant" or message["rating"] not in [1, -1]:
32
  continue
33
 
34
  # Get previous messages up to limits
35
- history = []
 
36
  tokens = 0
37
- turns = 0
38
-
39
- # Start from i-1 instead of going through all previous messages
40
- for prev in reversed(conversation[max(0, i-1):i]):
41
- if turns >= max_history_turns:
42
- break
43
-
44
- history.insert(0, prev)
45
- formatted = format_conversation(history, model_name)
46
- tokens = count_tokens(formatted, model_name)
47
-
48
- if tokens > max_history_tokens:
49
- history.pop(0)
50
- break
51
-
52
- turns += 1
53
-
54
- # Format prompt with just the immediate previous message
55
- prompt = format_conversation([conversation[i-1]], model_name) if i > 0 else ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  data_points.append({
58
- "prompt": prompt.strip(),
59
  "completion": message["content"].strip(),
60
  "label": message["rating"] == 1,
61
  "timestamp": entry["timestamp"],
@@ -66,7 +78,7 @@ def transform_conversation(
66
  return data_points
67
 
68
  def process_feel_dataset(
69
- model_name: str = "HuggingFaceH4/zephyr-7b-beta",
70
  max_history_turns: int = 10,
71
  max_history_tokens: int = 4000
72
  ):
@@ -145,28 +157,11 @@ if __name__ == "__main__":
145
  print("\nSample entries from processed KTO dataset:")
146
  print("\n" + "="*80 + "\nTRAIN SET SAMPLES\n" + "="*80)
147
 
148
- # for i, example in enumerate(datasets['train'].select(range(min(3, len(datasets['train']))))):
149
- # print(f"\nEntry #{i+1}:")
150
- # print("-" * 40)
151
- # for field, value in example.items():
152
- # print(f"\n{field}:")
153
- # if isinstance(value, str):
154
- # # Print strings with line breaks for better readability
155
- # print(f"{value}")
156
- # else:
157
- # print(f"{value}")
158
- # print("\n" + "-"*80)
159
-
160
- # print("\n" + "="*80 + "\nTEST SET SAMPLES\n" + "="*80)
161
-
162
- # for i, example in enumerate(datasets['test'].select(range(min(3, len(datasets['test']))))):
163
- # print(f"\nEntry #{i+1}:")
164
- # print("-" * 40)
165
- # for field, value in example.items():
166
- # print(f"\n{field}:")
167
- # if isinstance(value, str):
168
- # # Print strings with line breaks for better readability
169
- # print(f"{value}")
170
- # else:
171
- # print(f"{value}")
172
- # print("\n" + "-"*80)
 
3
  from sklearn.model_selection import train_test_split
4
  import json
5
  from ipdb import set_trace as st
 
6
  from transformers import AutoTokenizer
7
 
 
 
 
 
 
 
 
 
 
8
 
9
  def transform_conversation(
10
  entry: dict,
 
15
  """Transform conversation into KTO format with history"""
16
  data_points = []
17
  conversation = entry["conversation"]
18
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
19
 
20
  for i, message in enumerate(conversation):
21
+ # Only create data points for assistant messages that have ratings
22
  if message["role"] != "assistant" or message["rating"] not in [1, -1]:
23
  continue
24
 
25
  # Get previous messages up to limits
26
+ formatted_history = []
27
+ formatted_prompt = ""
28
  tokens = 0
29
+ pairs = 0 # Count complete user/assistant pairs
30
+
31
+ # Start from the current message and work backwards
32
+ current_idx = i - 1
33
+ while current_idx >= 0 and pairs < max_history_turns:
34
+ # We need both user and assistant messages to form a pair
35
+ if current_idx > 0 and conversation[current_idx]["role"] == "user" and conversation[current_idx-1]["role"] == "assistant":
36
+ # Add the pair to history
37
+ formatted_history.insert(0, conversation[current_idx-1]) # assistant
38
+ formatted_history.insert(1, conversation[current_idx]) # user
39
+
40
+ # Check token limit
41
+ try:
42
+ current_formatted = tokenizer.apply_chat_template(formatted_history, tokenize=False)
43
+ current_tokens = len(tokenizer.encode(current_formatted))
44
+
45
+ if current_tokens > max_history_tokens:
46
+ formatted_history = formatted_history[2:] # Remove the oldest pair
47
+ break
48
+
49
+ formatted_prompt = current_formatted
50
+ tokens = current_tokens
51
+ pairs += 1
52
+ current_idx -= 2
53
+ except Exception:
54
+ # If template application fails, remove the last added pair
55
+ formatted_history = formatted_history[2:]
56
+ break
57
+ else:
58
+ current_idx -= 1
59
+
60
+ # Add the final user message that prompted the rated response
61
+ if i > 0 and conversation[i-1]["role"] == "user":
62
+ last_history = formatted_history + [conversation[i-1]]
63
+ try:
64
+ formatted_prompt = tokenizer.apply_chat_template(last_history, tokenize=False)
65
+ except Exception:
66
+ # If template application fails, use the previous valid prompt
67
+ pass
68
 
69
  data_points.append({
70
+ "prompt": formatted_prompt.strip(),
71
  "completion": message["content"].strip(),
72
  "label": message["rating"] == 1,
73
  "timestamp": entry["timestamp"],
 
78
  return data_points
79
 
80
  def process_feel_dataset(
81
+ model_name: str = "CohereForAI/aya-expanse-8b",
82
  max_history_turns: int = 10,
83
  max_history_tokens: int = 4000
84
  ):
 
157
  print("\nSample entries from processed KTO dataset:")
158
  print("\n" + "="*80 + "\nTRAIN SET SAMPLES\n" + "="*80)
159
 
160
+ # Export datasets to CSV
161
+ train_df = datasets['train'].to_pandas()
162
+ test_df = datasets['test'].to_pandas()
163
+
164
+ train_df.to_csv('kto_train_dataset.csv', index=False)
165
+ test_df.to_csv('kto_test_dataset.csv', index=False)
166
+
167
+ print("\nDatasets exported to 'kto_train_dataset.csv' and 'kto_test_dataset.csv'")