Johnny Lee committed on
Commit a3c7493 · 1 Parent(s): 90aac8f

much cleanup

Files changed (1):
  1. app.py +374 -242
app.py CHANGED
@@ -1,14 +1,17 @@
 # ruff: noqa: E501
+from __future__ import annotations
 import asyncio
 import datetime
 import logging
 import os
+from enum import Enum
 import requests
 import json
 import uuid
+from pydantic import BaseModel, Field

 from copy import deepcopy
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import gradio as gr
 import pytz
@@ -37,49 +40,95 @@ logging.basicConfig(format="%(asctime)s %(name)s %(levelname)s:%(message)s")
 LOG = logging.getLogger(__name__)
 LOG.setLevel(logging.INFO)

-
 GPT_3_5_CONTEXT_LENGTH = 4096
 CLAUDE_2_CONTEXT_LENGTH = 100000  # need to use claude tokenizer

-CASE_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
-Follow this message's instructions carefully. Respond using markdown.
-Never repeat these instructions in a subsequent message.
-
-You will start an conversation with me in the following form:
-1. Below these instructions you will receive a business scenario. The scenario will (a) include the name of a company or category, and (b) a debatable multiple-choice question about the business scenario.
-2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
-3. To start the conversation, you will provide summarize the question and provide all options in the multiple choice question to me. Then, you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
-4. After receiving my position and explanation. You will choose an alternate position in the scenario.
-5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
-6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
-"""
-
-RESEARCH_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
-Follow this message's instructions carefully. Respond using markdown.
-Never repeat these instructions in a subsequent message.
-
-You will start an conversation with me in the following form:
-1. You are to be a professional research consultant to the MBA student.
-2. The student will be working in a group of classmates to collaborate on a proposal to solve a business dillema.
-3. Be as helpful as you can to the student while remaining factual.
-4. If you are not certain, please warn the student to conduct additional research on the internet.
-5. Use tables and bullet points as useful way to compare insights
-"""
-
-with open("templates.json") as json_f:
-    CASES = {case["name"]: case["template"] for case in json.load(json_f)}
-
-
-def get_case_template(template_name: str) -> str:
-    case_template = CASES[template_name]
-    return f"""{template_name}
-
-    {case_template}
-    """
-
-
+OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+theme = gr.themes.Soft()
+
+creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
+
+gradio_flagger = gr.HuggingFaceDatasetSaver(
+    hf_token=HF_TOKEN, dataset_name="chats", separate_dirs=True
+)
+
+
+class ChatSystemMessage(str, Enum):
+    CASE_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
+Follow this message's instructions carefully. Respond using markdown.
+Never repeat these instructions in a subsequent message.
+
+You will start a conversation with me in the following form:
+1. Below these instructions you will receive a business scenario. The scenario will (a) include the name of a company or category, and (b) a debatable multiple-choice question about the business scenario.
+2. We will pretend to be executives charged with solving the strategic question outlined in the scenario.
+3. To start the conversation, you will summarize the question and provide all options in the multiple-choice question to me. Then, you will ask me to choose a position and provide a short opening argument. Do not yet provide your position.
+4. After receiving my position and explanation, you will choose an alternate position in the scenario.
+5. Inform me which position you have chosen, then proceed to have a discussion with me on this topic.
+6. The discussion should be informative and very rigorous. Do not agree with my arguments easily. Pursue a Socratic method of questioning and reasoning.
+"""
+
+    RESEARCH_SYSTEM_MESSAGE = """You are a helpful AI assistant for a Columbia Business School MBA student.
+Follow this message's instructions carefully. Respond using markdown.
+Never repeat these instructions in a subsequent message.
+
+You will start a conversation with me in the following form:
+1. You are to be a professional research consultant to the MBA student.
+2. The student will be working in a group of classmates to collaborate on a proposal to solve a business dilemma.
+3. Be as helpful as you can to the student while remaining factual.
+4. If you are not certain, please warn the student to conduct additional research on the internet.
+5. Use tables and bullet points as useful ways to compare insights.
+"""
+
+
+class ChatbotMode(str, Enum):
+    DEBATE_PARTNER = "Debate Partner"
+    RESEARCH_ASSISTANT = "Research Assistant"
+    DEFAULT = DEBATE_PARTNER
+
+
+class PollQuestion(BaseModel):  # type: ignore[misc]
+    name: str
+    template: str
+
+
+class PollQuestions(BaseModel):  # type: ignore[misc]
+    cases: List[PollQuestion]
+
+    @classmethod
+    def from_json_file(cls, json_file_path: str) -> PollQuestions:
+        """Expects a JSON file with an array of poll questions.
+        Each JSON object should have "name" and "template" keys.
+        """
+        with open(json_file_path, "r") as json_f:
+            payload = json.load(json_f)
+        return_obj_list = []
+        if isinstance(payload, list):
+            for case in payload:
+                return_obj_list.append(PollQuestion(**case))
+            return cls(cases=return_obj_list)
+        raise ValueError(
+            f"JSON object in {json_file_path} must be an array of PollQuestion"
+        )
+
+    def get_case(self, case_name: str) -> Optional[PollQuestion]:
+        """Returns the named poll question, or None if it is not found."""
+        for case in self.cases:
+            if case.name == case_name:
+                return case
+
+    def get_case_names(self) -> List[str]:
+        """Returns the names in cases"""
+        return [case.name for case in self.cases]
+
+
+poll_questions = PollQuestions.from_json_file("templates.json")
+
+
 def reset_textbox():
-    return gr.update(value="")
+    return gr.update(value=""), gr.update(value=""), gr.update(value="")


 def auth(username, password):
@@ -98,157 +147,188 @@ def auth(username, password):
     return (username, password) in creds


-def make_llm_state(use_claude: bool = False) -> Dict[str, Any]:
-    if use_claude:
-        llm = ChatAnthropic(
-            model="claude-2",
-            anthropic_api_key=ANTHROPIC_API_KEY,
-            temperature=1,
-            max_tokens_to_sample=5000,
-            streaming=True,
-        )
-        context_length = CLAUDE_2_CONTEXT_LENGTH
-        tokenizer = tiktoken.get_encoding("cl100k_base")
-    else:
-        llm = ChatOpenAI(
-            model_name="gpt-4",
-            temperature=1,
-            openai_api_key=OPENAI_API_KEY,
-            max_retries=6,
-            request_timeout=100,
-            streaming=True,
-        )
-        context_length = GPT_3_5_CONTEXT_LENGTH
-        _, tokenizer = llm._get_encoding_model()
-    return dict(llm=llm, context_length=context_length, tokenizer=tokenizer)
-
-
-def make_template(
-    system_msg: str = CASE_SYSTEM_MESSAGE, template_name: str = "Netflix"
-) -> ChatPromptTemplate:
-    knowledge_cutoff = "Sept 2021"
-    current_date = datetime.datetime.now(pytz.timezone("America/New_York")).strftime(
-        "%Y-%m-%d"
-    )
-    if template_name in CASES.keys():
-        message_template = get_case_template(template_name)
-        system_msg += f"""
-        {message_template}
-
-        Knowledge cutoff: {knowledge_cutoff}
-        Current date: {current_date}
-        """
-    elif template_name == "Research Assistant":
-        knowledge_cutoff = "Early 2023"
-        system_msg = f"""{RESEARCH_SYSTEM_MESSAGE}
-
-        Knowledge cutoff: {knowledge_cutoff}
-        Current date: {current_date}
-        """
-
-    human_template = "{input}"
-    return ChatPromptTemplate.from_messages(
-        [
-            SystemMessagePromptTemplate.from_template(system_msg),
-            MessagesPlaceholder(variable_name="history"),
-            HumanMessagePromptTemplate.from_template(human_template),
-        ]
-    )
-
-
-def update_system_prompt(
-    template_option: str,
-    system_msg: str = CASE_SYSTEM_MESSAGE,
-    llm_option: str = "gpt-4",
-) -> Tuple[str, Dict[str, Any]]:
-    template_output = make_template(system_msg, template_option)
-    state = set_state()
-    state["template"] = template_output
-    use_claude = llm_option == "Claude 2"
-    state["llm_state"] = make_llm_state(use_claude)
-    llm = state["llm_state"]["llm"]
-    state["memory"] = ConversationTokenBufferMemory(
-        llm=llm,
-        max_token_limit=state["llm_state"]["context_length"],
-        return_messages=True,
-    )
-    state["chain"] = ConversationChain(
-        memory=state["memory"],
-        prompt=state["template"],
-        llm=llm,
-    )
-    updated_status = "Prompt Updated! Chat has reset."
-    return updated_status, state
-
-
-def update_system_prompt_mode(system_mode: str):
-    if system_mode == "Research Assistant":
-        status, state = update_system_prompt(
-            llm_option="Claude 2", template_option=system_mode
-        )
-        return state, gr.update(visible=False)
-    else:
-        status, state = update_system_prompt(template_option="Netflix")
-        return state, gr.update(visible=True, value="Netflix")
-
-
-def set_state(
-    state: Optional[gr.State] = None, metadata: Optional[Dict[str, str]] = None
-) -> Dict[str, Any]:
-    if state is None:
-        template = make_template()
-        llm_state = make_llm_state()
-        llm = llm_state["llm"]
-        memory = ConversationTokenBufferMemory(
-            llm=llm, max_token_limit=llm_state["context_length"], return_messages=True
-        )
-        chain = ConversationChain(
-            memory=memory, prompt=template, llm=llm, metadata=metadata
-        )
-        session_id = str(uuid.uuid4())
-        state = dict(
-            template=template,
-            llm_state=llm_state,
-            history=[],
-            memory=memory,
-            chain=chain,
-            session_id=session_id,
-        )
-        return state
-    else:
-        return state
+class ChatSession(BaseModel):
+    class Config:
+        arbitrary_types_allowed = True
+
+    context_length: int
+    tokenizer: tiktoken.Encoding
+    chain: ConversationChain
+    history: List[BaseMessage] = []
+    session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+
+    @staticmethod
+    def set_metadata(
+        username: str,
+        chatbot_mode: str,
+        turns_completed: int,
+        case: Optional[str] = None,
+    ) -> Dict[str, Union[str, int]]:
+        metadata = dict(
+            username=username,
+            chatbot_mode=chatbot_mode,
+            turns_completed=turns_completed,
+            case=case,
+        )
+        return metadata
+
+    @staticmethod
+    def _make_template(
+        system_msg: str, poll_question_name: Optional[str] = None
+    ) -> ChatPromptTemplate:
+        knowledge_cutoff = "Sept 2021"
+        current_date = datetime.datetime.now(
+            pytz.timezone("America/New_York")
+        ).strftime("%Y-%m-%d")
+        if poll_question_name:
+            poll_question = poll_questions.get_case(poll_question_name)
+            if poll_question:
+                message_template = poll_question.template
+                system_msg += f"""
+                {message_template}
+
+                Knowledge cutoff: {knowledge_cutoff}
+                Current date: {current_date}
+                """
+        else:
+            knowledge_cutoff = "Early 2023"
+            system_msg += f"""
+
+            Knowledge cutoff: {knowledge_cutoff}
+            Current date: {current_date}
+            """
+
+        human_template = "{input}"
+        return ChatPromptTemplate.from_messages(
+            [
+                SystemMessagePromptTemplate.from_template(system_msg),
+                MessagesPlaceholder(variable_name="history"),
+                HumanMessagePromptTemplate.from_template(human_template),
+            ]
+        )
+
+    @staticmethod
+    def _set_llm(
+        use_claude: bool,
+    ) -> Tuple[Union[ChatOpenAI, ChatAnthropic], int, tiktoken.Encoding]:
+        if use_claude:
+            llm = ChatAnthropic(
+                model="claude-2",
+                anthropic_api_key=ANTHROPIC_API_KEY,
+                temperature=1,
+                max_tokens_to_sample=5000,
+                streaming=True,
+            )
+            context_length = CLAUDE_2_CONTEXT_LENGTH
+            tokenizer = tiktoken.get_encoding("cl100k_base")
+            return llm, context_length, tokenizer
+        else:
+            llm = ChatOpenAI(
+                model_name="gpt-4",
+                temperature=1,
+                openai_api_key=OPENAI_API_KEY,
+                max_retries=6,
+                request_timeout=100,
+                streaming=True,
+            )
+            context_length = GPT_3_5_CONTEXT_LENGTH
+            _, tokenizer = llm._get_encoding_model()
+            return llm, context_length, tokenizer
+
+    def update_system_prompt(
+        self, system_msg: str, poll_question_name: Optional[str] = None
+    ) -> None:
+        self.chain.prompt = self._make_template(system_msg, poll_question_name)
+
+    def change_llm(self, use_claude: bool) -> None:
+        llm, self.context_length, self.tokenizer = self._set_llm(use_claude)
+        self.chain.llm = llm
+
+    def clear_memory(self) -> None:
+        self.chain.memory.clear()
+        self.history = []
+
+    def set_chatbot_mode(
+        self, case_mode: bool, poll_question_name: Optional[str] = None
+    ) -> None:
+        if case_mode and poll_question_name:
+            self.change_llm(use_claude=False)
+            self.update_system_prompt(
+                system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                poll_question_name=poll_question_name,
+            )
+        else:
+            self.change_llm(use_claude=True)
+            self.update_system_prompt(
+                system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE
+            )
+
+    @classmethod
+    def new(
+        cls,
+        use_claude: bool,
+        system_msg: str,
+        metadata: Dict[str, Any],
+        poll_question_name: Optional[str] = None,
+    ) -> ChatSession:
+        llm, context_length, tokenizer = cls._set_llm(use_claude)
+        memory = ConversationTokenBufferMemory(
+            llm=llm, max_token_limit=context_length, return_messages=True
+        )
+        template = cls._make_template(
+            system_msg=system_msg, poll_question_name=poll_question_name
+        )
+        chain = ConversationChain(
+            memory=memory,
+            prompt=template,
+            llm=llm,
+            metadata=metadata,
+        )
+        return cls(
+            context_length=context_length,
+            tokenizer=tokenizer,
+            chain=chain,
+        )


 async def respond(
-    inp: str,
-    state: Optional[Dict[str, Any]],
+    chat_input: str,
+    chatbot_mode: str,
+    case_input: str,
+    state: ChatSession,
     request: gr.Request,
-) -> Tuple[List[str], gr.State, Optional[str]]:
+) -> Tuple[List[str], ChatSession, str]:
     """Execute the chat functionality."""

     def prep_messages(
         user_msg: str, memory_buffer: List[BaseMessage]
     ) -> Tuple[str, List[BaseMessage]]:
-        messages_to_send = state["template"].format_messages(
+        messages_to_send = state.chain.prompt.format_messages(
             input=user_msg, history=memory_buffer
         )
-        user_msg_token_count = llm.get_num_tokens_from_messages([messages_to_send[-1]])
-        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
-        while user_msg_token_count > context_length:
+        user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
+            [messages_to_send[-1]]
+        )
+        total_token_count = state.chain.llm.get_num_tokens_from_messages(
+            messages_to_send
+        )
+        while user_msg_token_count > state.context_length:
             LOG.warning(
                 f"Pruning user message due to user message token length of {user_msg_token_count}"
             )
-            user_msg = tokenizer.decode(
-                llm.get_token_ids(user_msg)[: context_length - 100]
+            user_msg = state.tokenizer.decode(
+                state.chain.llm.get_token_ids(user_msg)[: state.context_length - 100]
             )
-            messages_to_send = state["template"].format_messages(
+            messages_to_send = state.chain.prompt.format_messages(
                 input=user_msg, history=memory_buffer
             )
-            user_msg_token_count = llm.get_num_tokens_from_messages(
                 [messages_to_send[-1]]
             )
-            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
-        while total_token_count > context_length:
+            user_msg_token_count = state.chain.llm.get_num_tokens_from_messages(
+            total_token_count = state.chain.llm.get_num_tokens_from_messages(
+                messages_to_send
+            )
+        while total_token_count > state.context_length:
             LOG.warning(
                 f"Pruning memory due to total token length of {total_token_count}"
             )
@@ -256,45 +336,76 @@ async def respond(
                 memory_buffer.pop(0)
                 continue
             memory_buffer = memory_buffer[1:]
-            messages_to_send = state["template"].format_messages(
+            messages_to_send = state.chain.prompt.format_messages(
                 input=user_msg, history=memory_buffer
             )
-            total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
+            total_token_count = state.chain.llm.get_num_tokens_from_messages(
+                messages_to_send
+            )
         return user_msg, memory_buffer

     try:
         if state is None:
-            state = set_state(metadata=dict(username=request.username))
-        llm = state["llm_state"]["llm"]
-        context_length = state["llm_state"]["context_length"]
-        tokenizer = state["llm_state"]["tokenizer"]
+            if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+                new_session = ChatSession.new(
+                    use_claude=False,
+                    system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                    metadata=ChatSession.set_metadata(
+                        username=request.username,
+                        chatbot_mode=chatbot_mode,
+                        turns_completed=0,
+                        case=case_input,
+                    ),
+                    poll_question_name=case_input,
+                )
+            else:
+                new_session = ChatSession.new(
+                    use_claude=True,
+                    system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
+                    metadata=ChatSession.set_metadata(
+                        username=request.username,
+                        chatbot_mode=chatbot_mode,
+                        turns_completed=0,
+                    ),
+                    poll_question_name=None,
+                )
+            state = new_session
+        state.chain.metadata = ChatSession.set_metadata(
+            username=request.username,
+            chatbot_mode=chatbot_mode,
+            turns_completed=len(state.history) + 1,
+            case=case_input,
+        )
         LOG.info(f"""[{request.username}] STARTING CHAIN""")
-        LOG.debug(f"History: {state['history']}")
-        LOG.debug(f"User input: {inp}")
-        inp, state["memory"].chat_memory.messages = prep_messages(
-            inp, state["memory"].buffer
+        LOG.debug(f"History: {state.history}")
+        LOG.debug(f"User input: {chat_input}")
+        chat_input, state.chain.memory.chat_memory.messages = prep_messages(
+            chat_input, state.chain.memory.buffer
+        )
+        messages_to_send = state.chain.prompt.format_messages(
+            input=chat_input, history=state.chain.memory.buffer
         )
-        messages_to_send = state["template"].format_messages(
-            input=inp, history=state["memory"].buffer
+        total_token_count = state.chain.llm.get_num_tokens_from_messages(
+            messages_to_send
         )
-        total_token_count = llm.get_num_tokens_from_messages(messages_to_send)
         LOG.debug(f"Messages to send: {messages_to_send}")
-        LOG.info(f"Tokens to send: {total_token_count}")
+        LOG.debug(f"Tokens to send: {total_token_count}")
         callback = AsyncIteratorCallbackHandler()
         run_collector = RunCollectorCallbackHandler()
         run = asyncio.create_task(
-            state["chain"].apredict(
-                input=inp,
+            state.chain.apredict(
+                input=chat_input,
                 callbacks=[callback, run_collector],
             )
         )
-        state["history"].append((inp, ""))
+        state.history.append((chat_input, ""))
         run_id = None
+        langsmith_url = None
         async for tok in callback.aiter():
-            user, bot = state["history"][-1]
+            user, bot = state.history[-1]
             bot += tok
-            state["history"][-1] = (user, bot)
-            yield state["history"], state, None
+            state.history[-1] = (user, bot)
+            yield state.history, state, None
         await run
         if run_collector.traced_runs and run_id is None:
             run_id = run_collector.traced_runs[0].id
@@ -302,112 +413,133 @@
         if run_id:
             run_collector.traced_runs = []
             try:
-                url = Client().share_run(run_id)
-                LOG.info(f"""URL : {url}""")
-                url_markdown = f"""[Shareable chat history link]({url})"""
+                langsmith_url = Client().share_run(run_id)
+                LOG.info(f"""Run ID: {run_id} \n URL : {langsmith_url}""")
+                url_markdown = (
+                    f"""[Click to view shareable chat]({langsmith_url})"""
+                )
             except Exception as exc:
                 LOG.error(exc)
                 url_markdown = "Share link not currently available"
-            yield state["history"], state, url_markdown
+            if (
+                len(state.history) > 9
+                and chatbot_mode == ChatbotMode.DEBATE_PARTNER
+            ):
+                url_markdown += """\n
+🙌 You have completed 10 exchanges with the chatbot."""
+            yield state.history, state, url_markdown
         LOG.info(f"""[{request.username}] ENDING CHAIN""")
-        LOG.debug(f"History: {state['history']}")
-        LOG.debug(f"Memory: {state['memory'].json()}")
+        LOG.debug(f"History: {state.history}")
+        LOG.debug(f"Memory: {state.chain.memory.json()}")
         data_to_flag = (
             {
-                "history": deepcopy(state["history"]),
+                "history": deepcopy(state.history),
                 "username": request.username,
                 "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
-                "session_id": state["session_id"],
+                "session_id": state.session_id,
+                "metadata": state.chain.metadata,
+                "langsmith_url": langsmith_url,
             },
         )
         LOG.debug(f"Data to flag: {data_to_flag}")
         gradio_flagger.flag(flag_data=data_to_flag, username=request.username)
     except Exception as e:
-        LOG.exception(e)
+        LOG.error(e)
         raise e


-OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-ANTHROPIC_API_KEY = os.getenv("ANTHROPIC_API_KEY")
-HF_TOKEN = os.getenv("HF_TOKEN")
+class ChatbotConfig(BaseModel):
+    app_title: str = "CBS Technology Strategy - Fall 2023"
+    chatbot_modes: List[ChatbotMode] = [mode for mode in ChatbotMode]
+    case_options: List[str] = poll_questions.get_case_names()
+    default_case_option: str = "Netflix"

-theme = gr.themes.Soft()

-creds = [(os.getenv("CHAT_USERNAME"), os.getenv("CHAT_PASSWORD"))]
+def change_chatbot_mode(
+    state: ChatSession, chatbot_mode: str, poll_question_name: str, request: gr.Request
+) -> Tuple[Any, ChatSession]:
+    """Sets the visibility of the case input field and returns the updated state."""
+    if state is None:
+        if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+            new_session = ChatSession.new(
+                use_claude=False,
+                system_msg=ChatSystemMessage.CASE_SYSTEM_MESSAGE,
+                metadata=dict(username=request.username),
+                poll_question_name=poll_question_name,
+            )
+        else:
+            new_session = ChatSession.new(
+                use_claude=True,
+                system_msg=ChatSystemMessage.RESEARCH_SYSTEM_MESSAGE,
+                metadata=dict(username=request.username),
+                poll_question_name=None,
+            )
+        state = new_session
+    if chatbot_mode == ChatbotMode.DEBATE_PARTNER:
+        state.set_chatbot_mode(case_mode=True, poll_question_name=poll_question_name)
+        state.clear_memory()
+        return gr.update(visible=True), state
+    elif chatbot_mode == ChatbotMode.RESEARCH_ASSISTANT:
+        state.set_chatbot_mode(case_mode=False)
+        state.clear_memory()
+        return gr.update(visible=False), state
+    else:
+        raise ValueError("chatbot_mode is not correctly set")

-gradio_flagger = gr.HuggingFaceDatasetSaver(HF_TOKEN, "chats")
-title = "CBS Technology Strategy - Fall 2023"
-image_url = ""
+
+config = ChatbotConfig()
 with gr.Blocks(
     theme=theme,
     analytics_enabled=False,
-    title=title,
+    title=config.app_title,
 ) as demo:
     state = gr.State()
-    gr.Markdown(f"""### {title}""")
+    gr.Markdown(f"""### {config.app_title}""")
     with gr.Tab("Chatbot"):
-        chatbot_mode = gr.Radio(
-            label="Mode",
-            choices=["Debate Partner", "Research Assistant"],
-            value="Debate Partner",
-        )
-        case_input = gr.Dropdown(
-            label="Case",
-            choices=CASES.keys(),
-            value="Netflix",
-            multiselect=False,
-        )
+        with gr.Row():
+            chatbot_mode = gr.Radio(
+                label="Mode",
+                choices=config.chatbot_modes,
+                value=ChatbotMode.DEFAULT,
+            )
+            case_input = gr.Dropdown(
+                label="Case",
+                choices=config.case_options,
+                value=config.default_case_option,
+                multiselect=False,
+            )
         chatbot = gr.Chatbot(label="ChatBot")
         with gr.Row():
             input_message = gr.Textbox(
                 placeholder="Send a message.",
-                label="Type an input and press Enter",
+                label="Type a message to begin",
                 scale=5,
             )
-            b1 = gr.Button(value="Submit")
-        share_link = gr.Markdown()
-        llm_input = gr.Dropdown(
-            label="LLM",
-            choices=["Claude 2", "GPT-4"],
-            value="GPT-4",
-            multiselect=False,
-            visible=False,
-        )
-        system_prompt_input = gr.TextArea(
-            label="System Prompt", value=CASE_SYSTEM_MESSAGE, lines=10, visible=False
-        )
-        update_system_button = gr.Button(value="Update Prompt & Reset", visible=False)
-        status_markdown = gr.Markdown(visible=False)
+            chat_submit_button = gr.Button(value="Submit")
+        status_message = gr.Markdown()
     gradio_flagger.setup([chatbot], "chats")

-    chat_bot_submit_params = dict(
-        fn=respond, inputs=[input_message, state], outputs=[chatbot, state, share_link]
-    )
-    input_message.submit(**chat_bot_submit_params)
-    b1.click(**chat_bot_submit_params)
-    chatbot_mode.change(
-        update_system_prompt_mode,
-        [chatbot_mode],
-        [state, case_input],
+    chatbot_submit_params = dict(
+        fn=respond,
+        inputs=[input_message, chatbot_mode, case_input, state],
+        outputs=[chatbot, state, status_message],
     )
-    update_system_button.click(
-        update_system_prompt,
-        [case_input, system_prompt_input, llm_input],
-        [status_markdown, state],
+    input_message.submit(**chatbot_submit_params)
+    chat_submit_button.click(**chatbot_submit_params)
+    chatbot_mode_params = dict(
+        fn=change_chatbot_mode,
+        inputs=[state, chatbot_mode, case_input],
+        outputs=[case_input, state],
     )
-    case_input.change(
-        update_system_prompt,
-        [case_input, system_prompt_input, llm_input],
-        [status_markdown, state],
+    chatbot_mode.change(**chatbot_mode_params)
+    case_input.change(**chatbot_mode_params)
+    clear_chatbot_messages_params = dict(
+        fn=reset_textbox, inputs=[], outputs=[input_message, chatbot, status_message]
     )
-    chatbot_mode.change(reset_textbox, [], [input_message])
-    chatbot_mode.change(reset_textbox, [], [chatbot])
-    update_system_button.click(reset_textbox, [], [input_message])
-    update_system_button.click(reset_textbox, [], [chatbot])
-    case_input.change(reset_textbox, [], [input_message])
-    case_input.change(reset_textbox, [], [chatbot])
-    b1.click(reset_textbox, [], [input_message])
-    input_message.submit(reset_textbox, [], [input_message])
+    chatbot_mode.change(**clear_chatbot_messages_params)
+    case_input.change(**clear_chatbot_messages_params)
+    chat_submit_button.click(**clear_chatbot_messages_params)
+    input_message.submit(**clear_chatbot_messages_params)

 demo.queue(max_size=99, concurrency_count=99, api_open=False).launch(
     debug=True, auth=auth
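
For reference, `PollQuestions.from_json_file` in the new code expects `templates.json` to be a JSON array of objects with "name" and "template" keys. A minimal sketch under that assumption (the case text below is an invented placeholder, not the repo's real template):

import json

# Hypothetical minimal templates.json payload matching the documented shape.
sample_cases = [{"name": "Netflix", "template": "Netflix faces a strategic choice..."}]
with open("templates.json", "w") as f:
    json.dump(sample_cases, f)

# Assumes the PollQuestions class defined in app.py above.
poll_questions = PollQuestions.from_json_file("templates.json")
assert poll_questions.get_case_names() == ["Netflix"]
assert poll_questions.get_case("Netflix").template.startswith("Netflix faces")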