rohitmenonhart committed
Commit ae73e63 · verified · 1 Parent(s): f41bbc3

Update app.py

Files changed (1): app.py (+23 -19)
app.py CHANGED
@@ -1,21 +1,19 @@
 from huggingface_hub import InferenceClient
 import gradio as gr
 
-client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.3"
-)
+client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.3")
 
-
-def format_prompt(message, history):
-    prompt = "<s>"
-    for user_prompt, bot_response in history:
-        prompt += f"[INST] {user_prompt} [/INST]"
-        prompt += f" {bot_response}</s> "
-    prompt += f"[INST] {message} [/INST]"
-    return prompt
+def format_prompt(message, history, persona=None):
+    prompt = "<s>"
+    if persona:
+        prompt += f"[ROLE: {persona}] "
+    for user_prompt, bot_response in history:
+        prompt += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
 
 def generate(
-    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
+    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0, persona=None
 ):
     temperature = float(temperature)
     if temperature < 1e-2:
@@ -31,7 +29,7 @@ def generate(
         seed=42,
     )
 
-    formatted_prompt = format_prompt(prompt, history)
+    formatted_prompt = format_prompt(prompt, history, persona)
 
     stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
     output = ""
@@ -42,7 +40,7 @@ def generate(
     return output
 
 
-additional_inputs=[
+additional_inputs = [
     gr.Slider(
         label="Temperature",
        value=0.9,
@@ -78,18 +76,24 @@ additional_inputs=[
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
-    )
+    ),
+    gr.Textbox(
+        label="Persona",
+        placeholder="Describe the role (e.g., 'wise mentor', 'friendly assistant')",
+        info="Define a persona for the model to roleplay.",
+    ),
 ]
 
 
 gr.ChatInterface(
-    fn=generate,
+    fn=lambda prompt, history, temperature, max_new_tokens, top_p, repetition_penalty, persona: generate(
+        prompt, history, temperature, max_new_tokens, top_p, repetition_penalty, persona
+    ),
     chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
     additional_inputs=additional_inputs,
-    title="""Mistral 7B v0.3"""
+    title="Mistral 7B v0.3"
 ).launch(show_api=False)
 
-
 gr.load("models/ehristoforu/dalle-3-xl-v2").launch()
 
-gr.load("models/microsoft/Phi-3-mini-4k-instruct").launch()
+gr.load("models/microsoft/Phi-3-mini-4k-instruct").launch()
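For reference, a minimal standalone sketch of the prompt string the new format_prompt builds; the sample history and persona below are illustrative, not part of the commit:

# Copy of the new format_prompt from the diff above, exercised with
# made-up inputs to show the resulting prompt string.
def format_prompt(message, history, persona=None):
    prompt = "<s>"
    if persona:
        prompt += f"[ROLE: {persona}] "
    for user_prompt, bot_response in history:
        prompt += f"[INST] {user_prompt} [/INST] {bot_response}</s> "
    prompt += f"[INST] {message} [/INST]"
    return prompt

history = [("Hi", "Hello! How can I help?")]  # illustrative prior turn
print(format_prompt("Tell me a story", history, persona="wise mentor"))
# <s>[ROLE: wise mentor] [INST] Hi [/INST] Hello! How can I help?</s> [INST] Tell me a story [/INST]

Note that the [ROLE: ...] prefix is an ad-hoc tag this commit introduces for the persona feature; it is not part of Mistral's standard <s>[INST] ... [/INST] instruct template, so the model treats it as plain prompt text.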