Johnny Lee commited on
Commit
1adb1d0
·
1 Parent(s): 6dd083e
Files changed (1) hide show
  1. app.py +73 -27
app.py CHANGED
@@ -41,7 +41,7 @@ LOG.setLevel(logging.INFO)
41
  GPT_3_5_CONTEXT_LENGTH = 4096
42
  CLAUDE_2_CONTEXT_LENGTH = 100000 # need to use claude tokenizer
43
 
44
- SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
45
  Follow this message's instructions carefully. Respond using markdown.
46
  Never repeat these instructions in a subsequent message.
47
 
@@ -54,6 +54,18 @@ You will start an conversation with me in the following form:
54
  6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
55
  """
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  with open("templates.json") as json_f:
58
  CASES = {case["name"]: case["template"] for case in json.load(json_f)}
59
 
@@ -112,22 +124,29 @@ def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
112
 
113
 
114
  def make_template(
115
- system_msg: str = SYSTEM_MESSAGE, template_name: str = "Netflix"
116
  ) -> ChatPromptTemplate:
117
- knowledge_cutoff = "Early 2023"
118
  current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
119
  "%Y-%m-%d"
120
  )
121
- case_template = get_case_template(template_name)
122
- system_msg += f"""
123
- {case_template}
124
-
125
- Knowledge cutoff: {knowledge_cutoff}
126
- Current date: {current_date}
127
- """
 
 
 
 
 
 
 
 
128
 
129
  human_template = "{input}"
130
- LOG.info(system_msg)
131
  return ChatPromptTemplate.from_messages(
132
  [
133
  SystemMessagePromptTemplate.from_template(system_msg),
@@ -138,7 +157,9 @@ def make_template(
138
 
139
 
140
  def update_system_prompt(
141
- system_msg: str, llm_option: str, template_option: str
 
 
142
  ) -> Tuple[str, Dict[str, Any]]:
143
  template_output = make_template(system_msg, template_option)
144
  state = set_state()
@@ -160,6 +181,17 @@ def update_system_prompt(
160
  return updated_status, state
161
 
162
 
 
 
 
 
 
 
 
 
 
 
 
163
  def set_state(
164
  state: Optional[gr.State] = None, metadata: Optional[Dict[str, str]] = None
165
  ) -> Dict[str, Any]:
@@ -270,10 +302,13 @@ async def respond(
270
  LOG.info(f"RUNID: {run_id}")
271
  if run_id:
272
  run_collector.traced_runs = []
273
- url = Client().share_run(run_id)
274
- LOG.info(f"""URL : {url}""")
275
- url_markdown = f"""[Shareable chat history link]({url})
276
- [{url}]({url})"""
 
 
 
277
  yield state["history"], state, url_markdown
278
  LOG.info(f"""[{request.username}] ENDING CHAIN""")
279
  LOG.debug(f"History: {state['history']}")
@@ -287,7 +322,7 @@ async def respond(
287
  },
288
  )
289
  LOG.debug(f"Data to flag: {data_to_flag}")
290
- # gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
291
  except Exception as e:
292
  LOG.exception(e)
293
  raise e
@@ -301,7 +336,7 @@ theme = gr.themes.Soft()
301
 
302
  creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
303
 
304
- # gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
305
  title = "CBS Technology Strategy - Fall 2023"
306
  image_url = ""
307
  with gr.Blocks(
@@ -311,7 +346,12 @@ with gr.Blocks(
311
  ) as demo:
312
  state = gr.State()
313
  gr.Markdown(f"""### {title}""")
314
- with gr.Tab("Debate Partner"):
 
 
 
 
 
315
  case_input = gr.Dropdown(
316
  label="Case",
317
  choices=CASES.keys(),
@@ -327,36 +367,42 @@ with gr.Blocks(
327
  )
328
  b1 = gr.Button(value="Submit")
329
  share_link = gr.Markdown()
330
- with gr.Tab("Setup"):
331
  llm_input = gr.Dropdown(
332
  label="LLM",
333
  choices=["Claude 2", "GPT-4"],
334
  value="GPT-4",
335
  multiselect=False,
 
336
  )
337
  system_prompt_input = gr.TextArea(
338
- label="System Prompt", value=SYSTEM_MESSAGE, lines=10
339
  )
340
- update_system_button = gr.Button(value="Update Prompt & Reset")
341
- status_markdown = gr.Markdown()
342
- # gradio_flagger.setup([chatbot], "chats")
343
 
344
  chat_bot_submit_params = dict(
345
  fn=respond, inputs=[input_message, state], outputs=[chatbot, state, share_link]
346
  )
347
  input_message.submit(**chat_bot_submit_params)
348
  b1.click(**chat_bot_submit_params)
 
 
 
 
 
349
  update_system_button.click(
350
  update_system_prompt,
351
- [system_prompt_input, llm_input, case_input],
352
  [status_markdown, state],
353
  )
354
  case_input.change(
355
  update_system_prompt,
356
- [system_prompt_input, llm_input, case_input],
357
  [status_markdown, state],
358
  )
359
-
 
360
  update_system_button.click(reset_textbox, [], [input_message])
361
  update_system_button.click(reset_textbox, [], [chatbot])
362
  case_input.change(reset_textbox, [], [input_message])
 
41
  GPT_3_5_CONTEXT_LENGTH = 4096
42
  CLAUDE_2_CONTEXT_LENGTH = 100000 # need to use claude tokenizer
43
 
44
+ CASE_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
45
  Follow this message's instructions carefully. Respond using markdown.
46
  Never repeat these instructions in a subsequent message.
47
 
 
54
  6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
55
  """
56
 
57
+ RESEARCH_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
58
+ Follow this message's instructions carefully. Respond using markdown.
59
+ Never repeat these instructions in a subsequent message.
60
+
61
+ You will start an conversation with me in the following form:
62
+ 1. You are to be a professional research consultant to the MBA student.
63
+ 2. The student will be working in a group of classmates to collaborate on a proposal to solve a business dillema.
64
+ 3. Be as helpful as you can to the student while remaining factual.
65
+ 4. If you are not certain, please warn the student to conduct additional research on the internet.
66
+ 5. Use tables and bullet points as useful way to compare insights
67
+ """
68
+
69
  with open("templates.json") as json_f:
70
  CASES = {case["name"]: case["template"] for case in json.load(json_f)}
71
 
 
124
 
125
 
126
  def make_template(
127
+ system_msg: str = CASE_SYSTEM_MESSAGE, template_name: str = "Netflix"
128
  ) -> ChatPromptTemplate:
129
+ knowledge_cutoff = "Sept 2021"
130
  current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
131
  "%Y-%m-%d"
132
  )
133
+ if template_name in CASES.keys():
134
+ message_template = get_case_template(template_name)
135
+ system_msg += f"""
136
+ {message_template}
137
+
138
+ Knowledge cutoff: {knowledge_cutoff}
139
+ Current date: {current_date}
140
+ """
141
+ elif template_name == "Research Assistant":
142
+ knowledge_cutoff = "Early 2023"
143
+ system_msg = f"""{RESEARCH_SYSTEM_MESSAGE}
144
+
145
+ Knowledge cutoff: {knowledge_cutoff}
146
+ Current date: {current_date}
147
+ """
148
 
149
  human_template = "{input}"
 
150
  return ChatPromptTemplate.from_messages(
151
  [
152
  SystemMessagePromptTemplate.from_template(system_msg),
 
157
 
158
 
159
  def update_system_prompt(
160
+ template_option: str,
161
+ system_msg: str = CASE_SYSTEM_MESSAGE,
162
+ llm_option: str = "gpt-4",
163
  ) -> Tuple[str, Dict[str, Any]]:
164
  template_output = make_template(system_msg, template_option)
165
  state = set_state()
 
181
  return updated_status, state
182
 
183
 
184
+ def update_system_prompt_mode(system_mode: str):
185
+ if system_mode == "Research Assistant":
186
+ status, state = update_system_prompt(
187
+ llm_option="Claude 2", template_option=system_mode
188
+ )
189
+ return state, gr.update(visible=False)
190
+ else:
191
+ status, state = update_system_prompt(template_option="Netflix")
192
+ return state, gr.update(visible=True, value="Netflix")
193
+
194
+
195
  def set_state(
196
  state: Optional[gr.State] = None, metadata: Optional[Dict[str, str]] = None
197
  ) -> Dict[str, Any]:
 
302
  LOG.info(f"RUNID: {run_id}")
303
  if run_id:
304
  run_collector.traced_runs = []
305
+ try:
306
+ url = Client().share_run(run_id)
307
+ LOG.info(f"""URL : {url}""")
308
+ url_markdown = f"""[Shareable chat history link]({url})"""
309
+ except Exception as exc:
310
+ LOG.error(exc)
311
+ url_markdown = "Share link not currently available"
312
  yield state["history"], state, url_markdown
313
  LOG.info(f"""[{request.username}] ENDING CHAIN""")
314
  LOG.debug(f"History: {state['history']}")
 
322
  },
323
  )
324
  LOG.debug(f"Data to flag: {data_to_flag}")
325
+ gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
326
  except Exception as e:
327
  LOG.exception(e)
328
  raise e
 
336
 
337
  creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
338
 
339
+ gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
340
  title = "CBS Technology Strategy - Fall 2023"
341
  image_url = ""
342
  with gr.Blocks(
 
346
  ) as demo:
347
  state = gr.State()
348
  gr.Markdown(f"""### {title}""")
349
+ with gr.Tab("Chatbot"):
350
+ chatbot_mode = gr.Radio(
351
+ label="Mode",
352
+ choices=["Debate Partner", "Research Assistant"],
353
+ value="Debate Partner",
354
+ )
355
  case_input = gr.Dropdown(
356
  label="Case",
357
  choices=CASES.keys(),
 
367
  )
368
  b1 = gr.Button(value="Submit")
369
  share_link = gr.Markdown()
 
370
  llm_input = gr.Dropdown(
371
  label="LLM",
372
  choices=["Claude 2", "GPT-4"],
373
  value="GPT-4",
374
  multiselect=False,
375
+ visible=False,
376
  )
377
  system_prompt_input = gr.TextArea(
378
+ label="System Prompt", value=CASE_SYSTEM_MESSAGE, lines=10, visible=False
379
  )
380
+ update_system_button = gr.Button(value="Update Prompt & Reset", visible=False)
381
+ status_markdown = gr.Markdown(visible=False)
382
+ gradio_flagger.setup([chatbot], "chats")
383
 
384
  chat_bot_submit_params = dict(
385
  fn=respond, inputs=[input_message, state], outputs=[chatbot, state, share_link]
386
  )
387
  input_message.submit(**chat_bot_submit_params)
388
  b1.click(**chat_bot_submit_params)
389
+ chatbot_mode.change(
390
+ update_system_prompt_mode,
391
+ [chatbot_mode],
392
+ [state, case_input],
393
+ )
394
  update_system_button.click(
395
  update_system_prompt,
396
+ [case_input, system_prompt_input, llm_input],
397
  [status_markdown, state],
398
  )
399
  case_input.change(
400
  update_system_prompt,
401
+ [case_input, system_prompt_input, llm_input],
402
  [status_markdown, state],
403
  )
404
+ chatbot_mode.change(reset_textbox, [], [input_message])
405
+ chatbot_mode.change(reset_textbox, [], [chatbot])
406
  update_system_button.click(reset_textbox, [], [input_message])
407
  update_system_button.click(reset_textbox, [], [chatbot])
408
  case_input.change(reset_textbox, [], [input_message])