rohitmenonhart committed on
Commit 76860a8 · verified · 1 Parent(s): ae73e63

Update app.py

Files changed (1): app.py (+9 -16)
app.py CHANGED
@@ -3,17 +3,18 @@ import gradio as gr
 
 client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
 
-def format_prompt(message, history, persona=None):
-    prompt = "<s>"
-    if persona:
-        prompt += f"[ROLE: {persona}] "
+# Define a fixed role
+DEFAULT_PERSONA = "strict teacher"
+
+def format_prompt(message, history):
+    prompt = f"<s>[ROLE: {DEFAULT_PERSONA}] "
     for user_prompt, bot_response in history:
         prompt += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
     prompt += f"[INST] {message} [/INST]"
     return prompt
 
 def generate(
-    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, persona=None
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0
 ):
     temperature = float(temperature)
     if temperature < 1e-2:
@@ -29,7 +30,7 @@ def generate(
         seed=42,
     )
 
-    formatted_prompt = format_prompt(prompt, history, persona)
+    formatted_prompt = format_prompt(prompt, history)
 
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
@@ -76,19 +77,11 @@ additional_inputs = [
         step=0.05,
         interactive=True,
         info="Penalize repeated tokens",
-    ),
-    gr.Textbox(
-        label="Persona",
-        placeholder="Describe the role (e.g., 'wise mentor', 'friendly assistant')",
-        info="Define a persona for the model to roleplay.",
-    ),
+    )
 ]
 
-
 gr.ChatInterface(
-    fn=lambda prompt, history, temperature, max_new_tokens, top_p, repetition_penalty, persona: generate(
-        prompt, history, temperature, max_new_tokens, top_p, repetition_penalty, persona
-    ),
+    fn=generate,
     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
     additional_inputs=additional_inputs,
     title="Mistral 7B v0.3"