from threading import Thread
import torch
import spaces
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer

# Update model configuration for Mistral-small-24B
MODEL_ID = "mistralai/Mistral-Small-24B-Instruct-2501"
CHAT_TEMPLATE = "mistral"  # Mistral uses its own chat template
MODEL_NAME = MODEL_ID.split("/")[-1]
CONTEXT_LENGTH = 32768  # Mistral supports longer context
COLOR = "black"
EMOJI = "🌪️"  # Mistral-themed emoji
DESCRIPTION = f"This is {MODEL_NAME}, a powerful 24B-parameter language model from Mistral AI."

def load_system_message():
    try:
        with open('system_message.txt', 'r', encoding='utf-8') as file:
            return file.read().strip()
    except FileNotFoundError:
        print("Warning: system_message.txt not found. Using default message.")
        return "You are a helpful assistant. First recognize the user request and then reply carefully with thinking."
    except Exception as e:
        print(f"Error loading system message: {e}")
        return "You are a helpful assistant. First recognize the user request and then reply carefully with thinking."

SYSTEM_MESSAGE = load_system_message()

@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_new_tokens, top_k, repetition_penalty, top_p):
    # Format history using Mistral's chat template
    messages = [{"role": "system", "content": system_prompt or SYSTEM_MESSAGE}]
    
    for user, assistant in history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    
    messages.append({"role": "user", "content": message})
    
    # Render the conversation with the model's chat template and append the generation prompt
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    
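    # Stream only newly generated tokens (skip the prompt and special tokens) back to the UI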
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids, attention_mask = enc.input_ids, enc.attention_mask

    if input_ids.shape[1] > CONTEXT_LENGTH:
        input_ids = input_ids[:, -CONTEXT_LENGTH:]
        attention_mask = attention_mask[:, -CONTEXT_LENGTH:]

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )

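    # Run generation in a background thread so the streamer can yield tokens as they are produced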
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

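    # Accumulate streamed tokens and yield the growing reply so the chat UI updates live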
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)

# Load model with optimized settings for Mistral-24B
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,  # Enable nested (double) quantization to save additional memory
    bnb_4bit_quant_type="nf4"  # Normal-float 4 (NF4) for better precision than plain 4-bit
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Set the pad token to be the same as the end of sequence token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16
)

# Create Gradio interface
gr.ChatInterface(
    predict,
    title=EMOJI + " " + MODEL_NAME,
    description=DESCRIPTION,
    examples=[
        ['What are the key differences between classical and quantum computing?'],
        ['Explain the concept of recursive neural networks in simple terms.'],
        ['How does transfer learning work in large language models?'],
        ['What are the ethical considerations in AI development?']
    ],
    additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False),
    additional_inputs=[
        gr.Textbox(SYSTEM_MESSAGE, label="System prompt", visible=False),  # Hidden system prompt
        gr.Slider(0, 1, 0.7, label="Temperature"),  # Adjusted default for Mistral
        gr.Slider(0, 32768, 12000, label="Max new tokens"),  # Increased for longer context
        gr.Slider(1, 100, 50, label="Top K sampling"),
        gr.Slider(0, 2, 1.1, label="Repetition penalty"),
        gr.Slider(0, 1, 0.95, label="Top P sampling"),
    ],
).queue().launch()