import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
from threading import Thread

# Load the tokenizer and model once at startup; bfloat16 halves memory versus fp32
checkpoint = "marin-community/marin-8b-instruct"
device = "cuda"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16).to(device)

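# ZeroGPU (Hugging Face Spaces): each predict() call is granted a GPU for up to 120 s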
@spaces.GPU(duration=120)
def predict(message, history, temperature, top_p):
    # On the first turn, prepend a system prompt describing the model
    if len(history) == 0:
        history.append({"role": "system", "content": """
You are a helpful, knowledgeable, and versatile AI assistant powered by Marin 8B Instruct (deeper-starling-05-15), which was trained by the Marin team.
Knowledge cutoff: July 2024

## MODEL FACTS:
- 8B parameter Llama 3-style architecture
- 4096 hidden size, 14336 feedforward size
- 32 layers, 32 attention heads, 8 KV heads
- Trained on diverse datasets: Nemotron-CC, DCLM, Starcoder, Proofpile 2, FineMath, Dolma, Wikipedia, StackExchange, arXiv papers, and specialized instruction datasets
- LICENSE: Apache 2.0

## INTERACTION GUIDELINES:
- Respond helpfully to user queries while maintaining factual accuracy
- Think step-by-step when approaching complex reasoning or math problems
- Clearly state limitations and uncertainties when appropriate
- Aim for concise, useful responses that directly address user needs
- Use Markdown formatting for code blocks and structured content

## LIMITATIONS:
- May occasionally generate incorrect information
- Encourage users to exercise caution with your own outputs
- Not intended for fully autonomous use
- Responses should be verified for critical applications

## ABOUT THE MARIN PROJECT:
- Marin is an open lab for building foundation models collaboratively
- The project emphasizes transparency by sharing all aspects of model development: code, data, experiments, and documentation in real-time
- The project documents its entire process through GitHub issues, pull requests, code, execution traces, and WandB reports
- Anyone can contribute to Marin by exploring new architectures, algorithms, datasets, or evaluations
- If users ask how to learn more about Marin, point them to https://marin.community

Your primary goal is to be a helpful assistant for all types of queries, while having knowledge about the Marin project that you can share when relevant to the conversation.
"""})
    history.append({"role": "user", "content": message})
    # Render the conversation with the model's chat template; the template
    # typically inserts the BOS token itself, so skip the tokenizer's default one
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt", add_special_tokens=False).to(device)
    
    # Create a streamer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Set up generation parameters
    generation_kwargs = {
        "input_ids": inputs,
        "max_new_tokens": 1024,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "do_sample": True,
        "streamer": streamer,
        "eos_token_id": 128009,
    }
    
    # Run generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    
    # Yield from the streamer as tokens are generated
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

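# Chat UI: ChatInterface passes the message, history, and the two sliders to predict()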
with gr.Blocks() as demo:
    chatbot = gr.ChatInterface(
        predict,
        additional_inputs=[
            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
        ],
        type="messages"
    )

demo.launch()