import json
import os

import gradio as gr
from huggingface_hub import InferenceClient

# Load prompts from JSON file
def load_prompts_from_json(file_path):
    with open(file_path, 'r') as file:
        return json.load(file)

# Load prompts from 'prompts.json'
prompts = load_prompts_from_json('prompts.json')
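# Assumed structure of prompts.json (illustrative only; the real file may differ):
# {
#     "default": "You are Tutorial Master, a helpful teaching assistant. ",
#     "beginner": "Explain everything in simple terms. "
# }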

# Inference client
client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1")
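# Note: gated or rate-limited models may need authentication. InferenceClient
# accepts a token argument, e.g. (assuming a valid HF_TOKEN environment variable):
# client = InferenceClient("mistralai/Mixtral-8x7B-Instruct-v0.1", token=os.getenv("HF_TOKEN"))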

# Secret prompt from environment variable (if needed)
secret_prompt = os.getenv("SECRET_PROMPT")

def format_prompt(new_message, history, prompt_type='default'):
    # Start from the selected system prompt; fall back to the secret prompt
    prompt = prompts.get(prompt_type, secret_prompt)
    # Replay the conversation in Mixtral's [INST] ... [/INST] instruct format
    for user_msg, bot_msg in history:
        prompt += f"[INST] {user_msg} [/INST]"
        prompt += f" {bot_msg}</s> "
    prompt += f"[INST] {new_message} [/INST]"
    return prompt
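# For example, with history = [("Hi", "Hello!")], new_message = "Thanks", and a
# 'default' prompt of "You are helpful. " (illustrative), this produces:
# "You are helpful. [INST] Hi [/INST] Hello!</s> [INST] Thanks [/INST]"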

def generate(prompt, history=None, temperature=0.25, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0, prompt_type='default'):
    history = history or []
    # Clamp temperature away from zero; the endpoint rejects a value of 0.0
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=727,
    )
    formatted_prompt = format_prompt(prompt, history, prompt_type)
    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""
    for response in stream:
        output += response.token.text
        # Yield the history plus the partial reply in the (user, bot) pair
        # format that gr.Chatbot expects, so the answer streams in live
        yield history + [(prompt, output)]
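# Standalone usage sketch (outside Gradio; assumes the Inference API call succeeds):
# for pairs in generate("Explain Python decorators", []):
#     print(pairs[-1][1])  # the progressively longer bot reply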

# Minimal chatbot display: full-width bubbles, no label, copy button, or like buttons
samir_chatbot = gr.Chatbot(bubble_full_width=True, show_label=False, show_copy_button=False, likeable=False)

# Dropdown for prompt types
prompt_type_dropdown = gr.Dropdown(choices=list(prompts.keys()), label="Prompt Type", value='default')

# Minimalistic theme and Gradio demo configuration
theme = 'syddharth/gray-minimal'

# Choose how you want to handle state. Both options below assign to `demo`, so
# only the last assignment takes effect; keep whichever one you need.

# Option 1: No state management (conversation history is not kept between calls)
def generate_stateless(prompt, temperature=0.25, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0, prompt_type='default'):
    # Each call starts from an empty history
    yield from generate(prompt, [], temperature, max_new_tokens, top_p, repetition_penalty, prompt_type)

demo = gr.Interface(
    fn=generate_stateless,
    inputs=[
        gr.Textbox(lines=2, label="Input"),
        gr.Slider(0, 1, value=0.25, label="Temperature"),
        gr.Slider(1, 2048, value=512, step=1, label="Max Tokens"),
        gr.Slider(0, 1, value=0.95, label="Top P"),
        gr.Slider(1, 2, value=1.0, label="Repetition Penalty"),
        prompt_type_dropdown
    ],
    outputs=[samir_chatbot],
    title="Tutorial Master",
    theme=theme
)

# Option 2: State management for conversation history. The state must appear in
# both inputs and outputs so the updated history is written back after each turn.
def generate_with_state(prompt, history, temperature=0.25, max_new_tokens=512, top_p=0.95, repetition_penalty=1.0, prompt_type='default'):
    for pairs in generate(prompt, history, temperature, max_new_tokens, top_p, repetition_penalty, prompt_type):
        yield pairs, pairs  # the same value updates both the Chatbot and the state

demo = gr.Interface(
    fn=generate_with_state,
    inputs=[
        gr.Textbox(lines=2, label="Input"),
        "state"  # state input carrying the conversation history
    ],
    outputs=[samir_chatbot, "state"],
    title="Tutorial Master",
    theme=theme
)

# Launch the demo; queuing lets the streaming generator push incremental updates
demo.queue().launch(show_api=False)