WillHeld committed on
Commit b31b98c · verified · 1 Parent(s): d951e6a

Update app.py

Files changed (1)
  1. app.py +53 -163
app.py CHANGED
@@ -5,7 +5,7 @@ from threading import Thread
 import os
 import json
 import uuid
-from datasets import Dataset, load_dataset
+from datasets import Dataset
 from huggingface_hub import HfApi, login
 import time
 
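Note: the narrowed import pairs with the next hunk, which deletes the startup download-and-merge sync, so `load_dataset` is no longer needed and feedback only flows outward. The body of `push_feedback_to_hub` is not shown in this diff, so the following is only a hedged sketch of what the remaining push-only flow might look like; the helper name `push_jsonl_records` and the `private=True` choice are illustrative assumptions, not the app's actual code.

import json
from datasets import Dataset

def push_jsonl_records(local_file, repo_id, token=None):
    """Read the local JSONL feedback file and push it to the Hub one-way."""
    records = []
    with open(local_file, "r") as f:
        for line in f:
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError:
                continue  # skip malformed rows, as the removed sync code did
    # One-way push: with no download/merge step, load_dataset is unnecessary
    Dataset.from_list(records).push_to_hub(repo_id, token=token, private=True)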
@@ -28,63 +28,6 @@ DATASET_FILENAME = "feedback.jsonl" # Filename for feedback data
 # Ensure feedback directory exists
 os.makedirs(DATASET_PATH, exist_ok=True)
 
-# Sync existing dataset from Hub if available
-def sync_dataset_from_hub():
-    """Download existing dataset from Hub and merge with local data"""
-    try:
-        # Try to get token from environment variable
-        hf_token = os.environ.get("HF_TOKEN")
-        if hf_token:
-            login(token=hf_token)
-
-        # Check if the dataset exists on Hub
-        api = HfApi()
-        try:
-            dataset_info = api.dataset_info(DATASET_REPO)
-            # Dataset exists, download it
-            print(f"Syncing existing dataset from {DATASET_REPO}")
-            remote_dataset = load_dataset(DATASET_REPO)
-
-            # Convert to list of dictionaries
-            remote_data = [item for item in remote_dataset['train']]
-
-            # Check if local file exists
-            local_file = os.path.join(DATASET_PATH, DATASET_FILENAME)
-            local_data = []
-
-            if os.path.exists(local_file):
-                # Read local data
-                with open(local_file, 'r') as f:
-                    for line in f:
-                        try:
-                            local_data.append(json.loads(line))
-                        except json.JSONDecodeError:
-                            continue
-
-            # Merge data (using IDs to avoid duplicates)
-            all_items = {}
-            for item in remote_data + local_data:
-                all_items[item['id']] = item
-
-            # Write back merged data
-            with open(local_file, 'w') as f:
-                for item in all_items.values():
-                    f.write(json.dumps(item) + '\n')
-
-            print(f"Synced {len(all_items)} feedback items")
-            return True
-
-        except Exception as e:
-            print(f"Dataset {DATASET_REPO} does not exist yet or could not be accessed: {e}")
-            return False
-
-    except Exception as e:
-        print(f"Error syncing dataset: {e}")
-        return False
-
-# Call sync on startup
-sync_dataset_from_hub()
-
 # Feedback storage functions
 def save_feedback_locally(conversation, satisfaction, feedback_text):
     """Save feedback to a local JSONL file"""
@@ -150,17 +93,49 @@ def push_feedback_to_hub(hf_token=None):
     print(f"Error pushing feedback data to Hub: {e}")
     return False
 
+# Modified predict function to update conversation state
+@spaces.GPU(duration=120)
+def predict(message, history, state, temperature, top_p):
+    # Update history with user message
+    history.append({"role": "user", "content": message})
+
+
+    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+
+    # Create a streamer
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Set up generation parameters
+    generation_kwargs = {
+        "input_ids": inputs,
+        "max_new_tokens": 1024,
+        "temperature": float(temperature),
+        "top_p": float(top_p),
+        "do_sample": True,
+        "streamer": streamer,
+    }
+
+    # Run generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    # Yield from the streamer as tokens are generated
+    partial_text = ""
+    for new_text in streamer:
+        partial_text += new_text
+        yield partial_text, state
+
+    # After full generation, update state with assistant's response
+    history.append({"role": "assistant", "content": partial_text})
+    state = history.copy()
+    return partial_text, state
+
 # Function to handle the research feedback submission
-def submit_research_feedback(conv_history, satisfaction, feedback_text):
+def submit_research_feedback(conversation_state, satisfaction, feedback_text):
     """Save user feedback both locally and to HuggingFace Hub"""
-    # Print debug information
-    print(f"Saving feedback with conversation history containing {len(conv_history)} messages")
-    if conv_history and len(conv_history) > 0:
-        print(f"First message: {conv_history[0]['role']}: {conv_history[0]['content'][:30]}...")
-        print(f"Last message: {conv_history[-1]['role']}: {conv_history[-1]['content'][:30]}...")
-
     # Save locally first
-    feedback_id = save_feedback_locally(conv_history, satisfaction, feedback_text)
+    feedback_id = save_feedback_locally(conversation_state, satisfaction, feedback_text)
 
     # Get token from environment variable
     env_token = os.environ.get("HF_TOKEN")
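The new `predict` above streams via transformers' `TextIteratorStreamer`: `model.generate` blocks until generation finishes, so it runs on a worker thread while the caller iterates the streamer. (The `@spaces.GPU(duration=120)` decorator is the ZeroGPU hook from the `spaces` package on Hugging Face Spaces, requesting a GPU slice for up to 120 seconds per call.) A self-contained sketch of just that streaming pattern, assuming `model`, `tokenizer`, and `device` are already set up as in the app:

from threading import Thread
from transformers import TextIteratorStreamer

def stream_generate(model, tokenizer, prompt, device, **gen_kwargs):
    """Yield progressively longer decoded text while generation runs."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # generate() blocks, so push it onto a worker thread; the streamer
    # then yields decoded chunks on this thread as they are produced
    Thread(target=model.generate, kwargs={**inputs, "streamer": streamer, **gen_kwargs}).start()
    text = ""
    for chunk in streamer:
        text += chunk
        yield text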
@@ -175,114 +150,29 @@ def submit_research_feedback(conv_history, satisfaction, feedback_text):
 
     return status_msg
 
-# Initial state - set up at app start
-def initialize_state():
-    """Initialize the conversation state - this could load previous sessions or start fresh"""
-    return [] # Start with empty conversation history
-
 # Create the Gradio blocks interface
 with gr.Blocks() as demo:
-    # Create state to store full conversation history with proper initialization
-    conv_state = gr.State(initialize_state)
+    # State to track conversation history
+    conversation_state = gr.State([])
 
     with gr.Row():
         with gr.Column(scale=3):
-            # Create a custom predict function that updates our state
-            def enhanced_predict(message, history, temperature, top_p, state):
-                # Initialize state if needed
-                if state is None:
-                    state = []
-                    print("Initializing empty state")
-
-                # Copy history to state if state is empty but history exists
-                if len(state) == 0 and len(history) > 0:
-                    state = history.copy()
-                    print(f"Copied {len(history)} messages from history to state")
-
-                # Add user message to state
-                state.append({"role": "user", "content": message})
-
-                # Process with the model (this doesn't modify the original history)
-                input_text = tokenizer.apply_chat_template(state, tokenize=False, add_generation_prompt=True)
-                inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
-                # Create a streamer
-                streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-                # Set up generation parameters
-                generation_kwargs = {
-                    "input_ids": inputs,
-                    "max_new_tokens": 1024,
-                    "temperature": float(temperature),
-                    "top_p": float(top_p),
-                    "do_sample": True,
-                    "streamer": streamer,
-                    "eos_token_id": 128009,
-                }
-
-                # Run generation in a separate thread
-                thread = Thread(target=model.generate, kwargs=generation_kwargs)
-                thread.start()
-
-                # Yield from the streamer as tokens are generated
-                response = ""
-                for new_text in streamer:
-                    response += new_text
-                    # For each partial response, yield the text only
-                    # We'll update the state after generation is complete
-                    yield response
-
-                # After generation completes, update our state with the final response
-                state.append({"role": "assistant", "content": response})
-
-                # Return the updated state
-                return state
-
-            # Create a wrapper that connects to ChatInterface but also updates our state
-            def chat_with_state(message, history, temperature, top_p):
-                # This function is what interfaces with the ChatInterface
-                nonlocal conv_state
-
-                # Access the current state
-                current_state = conv_state.value if conv_state.value else []
-
-                # Call the main function that generates responses and updates state
-                # This is a generator function, so we need to iterate through its outputs
-                response_gen = enhanced_predict(message, history, temperature, top_p, current_state)
-
-                # For each response, yield it and also update our state at the end
-                last_response = None
-                for response in response_gen:
-                    last_response = response
-                    yield response
-
-                # After generation is complete, update our state
-                if last_response is not None:
-                    # Create a full copy of the history plus the new exchange
-                    updated_state = []
-                    # Add all previous history
-                    for msg in history:
-                        updated_state.append(msg.copy())
-                    # Add new exchange
-                    updated_state.append({"role": "user", "content": message})
-                    updated_state.append({"role": "assistant", "content": last_response})
-
-                    # Store in our state
-                    conv_state.value = updated_state
-
-                    # Debug
-                    print(f"Updated conversation state with {len(updated_state)} messages")
-                    if updated_state:
-                        last_msg = updated_state[-1]
-                        print(f"Last message: {last_msg['role']}: {last_msg['content'][:30]}...")
+            # Custom chat function wrapper to update state
+            def chat_with_state(message, history, state, temperature, top_p):
+                for partial_response, updated_state in predict(message, history, state, temperature, top_p):
+                    # Update our state with each yield
+                    state = updated_state
+                    yield partial_response, state
 
             # Create ChatInterface
             chatbot = gr.ChatInterface(
                 chat_with_state,
                 additional_inputs=[
+                    conversation_state,
                     gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
                     gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
                 ],
+                additional_outputs=[conversation_state],
                 type="messages"
             )
 
@@ -317,10 +207,10 @@ with gr.Blocks() as demo:
         feedback_modal
     )
 
-    # Connect the submit button to the submit_research_feedback function
+    # Connect the submit button to the submit_research_feedback function with the current conversation state
    submit_button.click(
        submit_research_feedback,
-        inputs=[conv_state, satisfaction, feedback_text],
+        inputs=[conversation_state, satisfaction, feedback_text],
        outputs=response_text
    )
 
 
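Overall, the commit drops the old pattern of assigning to `conv_state.value` inside the handler (which Gradio does not propagate back to per-session state) in favor of `gr.ChatInterface`'s own state threading: the `gr.State` enters through `additional_inputs` and is written back through `additional_outputs` on every yield. A minimal runnable sketch of that wiring, with a stand-in echo function instead of the app's real model (assumes a Gradio release whose `ChatInterface` supports `additional_outputs`, as the new code above does):

import gradio as gr

def chat_fn(message, history, state, temperature, top_p):
    # `state` arrives via additional_inputs; each yielded tuple is
    # (visible reply, updated state) because of additional_outputs.
    state = (state or []) + [{"role": "user", "content": message}]
    reply = f"echo (T={temperature}, p={top_p}): {message}"  # stand-in for model.generate
    yield reply, state + [{"role": "assistant", "content": reply}]

with gr.Blocks() as demo:
    conversation_state = gr.State([])
    gr.ChatInterface(
        chat_fn,
        additional_inputs=[
            conversation_state,
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
        ],
        additional_outputs=[conversation_state],
        type="messages",
    )

demo.launch()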