phk0 commited on
Commit
027d8aa
·
verified ·
1 Parent(s): 02fcf54

added system_message

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -5,18 +5,18 @@ client = InferenceClient(
5
  "mistralai/Mistral-7B-Instruct-v0.3"
6
  )
7
 
8
- def format_prompt(message, history, system_prompt=None):
9
  prompt = "<s>"
10
  for user_prompt, bot_response in history:
11
  prompt += f"[INST] {user_prompt} [/INST]"
12
  prompt += f" {bot_response}</s> "
13
- if system_prompt:
14
- prompt += f"[SYS] {system_prompt} [/SYS]"
15
  prompt += f"[INST] {message} [/INST]"
16
  return prompt
17
 
18
  def generate(
19
- prompt, history, system_prompt=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
20
  ):
21
  temperature = float(temperature)
22
  if temperature < 1e-2:
@@ -32,7 +32,7 @@ def generate(
32
  seed=42,
33
  )
34
 
35
- formatted_prompt = format_prompt(prompt, history, system_prompt)
36
 
37
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
38
  output = ""
@@ -44,7 +44,7 @@ def generate(
44
 
45
 
46
  additional_inputs=[
47
- gr.Slider(
48
  label="System message",
49
  value="",
50
  interactive=True,
@@ -93,5 +93,5 @@ gr.ChatInterface(
93
  fn=generate,
94
  chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
95
  additional_inputs=additional_inputs,
96
- title="""Mistral 7B v0.3"""
97
  ).launch(show_api=False)
 
5
  "mistralai/Mistral-7B-Instruct-v0.3"
6
  )
7
 
8
+ def format_prompt(message, history, system_message=None):
9
  prompt = "<s>"
10
  for user_prompt, bot_response in history:
11
  prompt += f"[INST] {user_prompt} [/INST]"
12
  prompt += f" {bot_response}</s> "
13
+ if system_message:
14
+ prompt += f"[SYS] {system_message} [/SYS]"
15
  prompt += f"[INST] {message} [/INST]"
16
  return prompt
17
 
18
  def generate(
19
+ prompt, history, system_message=None, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
20
  ):
21
  temperature = float(temperature)
22
  if temperature < 1e-2:
 
32
  seed=42,
33
  )
34
 
35
+ formatted_prompt = format_prompt(prompt, history, system_message)
36
 
37
  stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
38
  output = ""
 
44
 
45
 
46
  additional_inputs=[
47
+ gr.TextArea(
48
  label="System message",
49
  value="",
50
  interactive=True,
 
93
  fn=generate,
94
  chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
95
  additional_inputs=additional_inputs,
96
+ title="""mistralai/Mistral-7B-Instruct-v0.3"""
97
  ).launch(show_api=False)