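# Streaming chat demo for the WillHeld/soft-raccoon checkpoint (the "Tootsie 8B"
# model trained with the Marin framework). Runs on Hugging Face Spaces; `spaces`
# supplies the ZeroGPU integration used by the @spaces.GPU decorator below.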
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread

checkpoint = "WillHeld/soft-raccoon"
device = "cuda"

# Load the tokenizer and model once at startup so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

# Each request may hold the ZeroGPU device for up to 120 seconds.
@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p):
    # Seed a fresh conversation with the model's identity/system prompt.
    if len(history) == 0:
        history.append({"role": "system", "content": (
            "You are the Tootsie 8B advanced language model trained using Marin, a framework "
            "developed by Stanford's Center for Research on Foundation Models (CRFM).\n\n"
            "Marin is a framework designed for training large language models in an entirely "
            "open fashion with a focus on legibility, scalability, and reproducibility.\n\n"
            "CRFM (Center for Research on Foundation Models) is a research center at Stanford "
            "University dedicated to studying foundation models - large-scale AI systems trained "
            "on broad data that can be adapted to a wide range of downstream tasks.\n\n"
            "Your training using this framework emphasizes clear reasoning, consistent outputs, "
            "and scalable performance across various tasks. Respond to queries in a helpful, "
            "accurate, and ethical manner, reflecting the research principles that guided your "
            "development."
        )})
    history.append({"role": "user", "content": message})
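    # Render the full conversation with the chat template, then tokenize it for generation.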
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    
    # Create a streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
        # 128009 is <|eot_id|>, the end-of-turn token in Llama 3-style tokenizers.
        "eos_token_id": 128009,
    }
    
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Yield from the streamer as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

# Build the UI: a ChatInterface that drives predict() and exposes the sampling controls.
with gr.Blocks() as demo:
    chatbot = gr.ChatInterface(
        predict,
        additional_inputs=[
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
        ],
        # "messages" format: history arrives as a list of {"role": ..., "content": ...} dicts.
        type="messages"
    )

demo.launch()