File size: 6,912 Bytes
b376f12
 
 
 
 
 
 
 
 
 
 
 
ef3faba
b376f12
 
 
 
 
 
 
b8d64ca
19342c6
 
b8d64ca
6997fc5
b8d64ca
 
 
da09cca
b376f12
 
 
 
 
 
 
a7575a1
b376f12
da09cca
 
 
 
b376f12
 
 
 
03c2ae6
 
 
 
db24877
03c2ae6
 
 
 
b376f12
 
 
 
 
 
 
 
da09cca
 
 
 
 
03c2ae6
da09cca
b376f12
 
35e6309
b376f12
 
 
03c2ae6
b376f12
03c2ae6
 
 
b376f12
03c2ae6
b376f12
 
 
 
 
 
 
 
 
 
 
 
 
 
03c2ae6
edd0bac
 
 
 
 
 
 
 
 
03c2ae6
edd0bac
 
 
db24877
edd0bac
 
 
 
 
 
 
 
 
 
03c2ae6
 
da09cca
6efc321
5ce66fb
da09cca
 
 
 
3ba38dc
da09cca
19342c6
 
3ba38dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
da09cca
 
 
03c2ae6
 
d7174fa
03c2ae6
 
 
 
 
da09cca
 
b376f12
da09cca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""Template Demo for IBM Granite Hugging Face spaces."""

from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from threading import Thread

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from themes.research_monochrome import theme

# Current local date for the system prompt, e.g. "January 5, 2025" (no leading
# zero on the day). Built from the date's fields rather than strftime's "%-d",
# which is a glibc/BSD-only extension that Windows' strftime rejects.
_today = datetime.today()  # noqa: DTZ002
today_date = f"{_today:%B} {_today.day}, {_today:%Y}"

# System prompt placed first in every conversation sent to the model.
SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024.
Today's Date: {today_date}.
You are Granite, developed by IBM. You are a helpful AI assistant"""
# Page heading / browser-tab title and the HTML blurb rendered beneath it.
TITLE = "IBM Granite 3.1 8b Instruct"
DESCRIPTION = """
<p>Granite 3.1 8b instruct is an open-source LLM supporting a 128k context window. Start with one of the sample prompts
or enter your own. Keep in mind that AI can occasionally make mistakes.
<span class="gr_docs_link">
<a href="https://www.ibm.com/granite/docs/">View Documentation <i class="fa fa-external-link"></i></a>
</span>
</p>
"""
# Total token budget shared by the prompt and the generated reply.
MAX_INPUT_TOKEN_LENGTH = 128_000
# Default generation settings; also used as the initial values of the UI sliders.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Hugging Face Hub repository providing both the weights and the tokenizer.
# Kept in one place so the model and tokenizer can never drift apart.
MODEL_ID = "ibm-granite/granite-3.1-8b-instruct"

if not torch.cuda.is_available():
    # Warn rather than abort: loading still proceeds without a GPU.
    print("This demo may not work on CPU.")

# Half-precision weights; device_map="auto" lets transformers/accelerate place
# them on whatever device(s) are available.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Opt out of any tokenizer-supplied default system prompt (legacy flag,
# presumably a no-op for this tokenizer); the app sends its own SYS_PROMPT.
tokenizer.use_default_system_prompt = False


@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    Builds a chat prompt (system prompt, then ``chat_history``, then
    ``message``), runs ``model.generate`` on a background thread and yields the
    accumulated reply after every chunk emitted by the streamer.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"`` history).
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        repetition_penalty: Repetition penalty forwarded to ``model.generate``.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept when sampling
            (0 disables top-k filtering).
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        The reply generated so far (cumulative text, not just the new chunk).

    Raises:
        Exception: Whatever ``model.generate`` raised on the worker thread,
            re-raised here once streaming has stopped.
    """
    # Gradio sliders can hand back whole numbers as ``int`` (e.g. a penalty of
    # 2), while transformers' logits processors type-check their arguments, so
    # normalise every knob before building the request.
    temperature = float(temperature)
    repetition_penalty = float(repetition_penalty)
    top_p = float(top_p)
    top_k = int(top_k)
    max_new_tokens = int(max_new_tokens)

    # System prompt first, then the earlier turns, then the new user message.
    conversation = [
        {"role": "system", "content": SYS_PROMPT},
        *chat_history,
        {"role": "user", "content": message},
    ]
    input_ids = tokenizer.apply_chat_template(
        conversation,
        return_tensors="pt",
        add_generation_prompt=True,
    )

    # Leave room for the reply inside the context window. Trim from the LEFT so
    # the newest message and the assistant generation prompt survive; letting
    # the tokenizer truncate (right side by default) would cut those off.
    max_input_tokens = max(1, MAX_INPUT_TOKEN_LENGTH - max_new_tokens)
    if input_ids.shape[-1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {max_input_tokens} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        # Single unpadded sequence: every prompt token is real.
        "attention_mask": torch.ones_like(input_ids),
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k)
    else:
        # transformers rejects a sampling temperature of 0; treat it as a
        # request for deterministic greedy decoding instead of crashing.
        generate_kwargs["do_sample"] = False

    errors: list[Exception] = []

    def _run_generation() -> None:
        """Run ``model.generate`` on the worker thread, capturing any failure."""
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            errors.append(exc)
            # Unblock the loop below, which would otherwise wait forever for tokens.
            streamer.end()

    worker = Thread(target=_run_generation)
    worker.start()

    outputs: list[str] = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    worker.join()
    if errors:
        raise errors[0]


# Static assets shipped alongside this script, consumed by gr.Blocks below.
css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

# Advanced generation settings, rendered inside the "Advanced Settings"
# accordion. Each slider starts at the matching module-level default.
temperature_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TEMPERATURE, step=0.1, label="Temperature", elem_classes=["gr_accordion_element"]
)
top_p_slider = gr.Slider(
    minimum=0, maximum=1.0, value=TOP_P, step=0.05, label="Top P", elem_classes=["gr_accordion_element"]
)
# A top-k of 0 means "no top-k filtering" to transformers.
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"]
)
# transformers requires a strictly positive repetition penalty (0 raises
# ValueError), so the range starts one step above zero.
repetition_penalty_slider = gr.Slider(
    minimum=0.05,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
)
# Collapsed container that holds the sliders above in the chat UI.
chat_interface_accordion = gr.Accordion(label="Advanced Settings", open=False)

# Page layout: title, description, then the chat UI. css_paths / head_paths
# load the stylesheet and extra <head> markup whose paths are defined above;
# `theme` is the project-local research_monochrome theme.
with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    gr.HTML(f"<h1>{TITLE}</h1>", elem_classes=["gr_title"])
    gr.HTML(DESCRIPTION)
    # Chat UI backed by the streaming `generate` generator function.
    chat_interface = gr.ChatInterface(
        fn=generate,
        # Sample prompts; each inner list holds the single `message` value.
        examples=[
            ["Explain the concept of quantum computing to someone with no background in physics or computer science."],
            ["What is OpenShift?"],
            ["What's the importance of low latency inference?"],
            ["Help me boost productivity habits."],
            [
                """Explain the following code in a concise manner:

```java
import java.util.ArrayList;
import java.util.List;

public class Main {

    public static void main(String[] args) {
        int[] arr = {1, 5, 3, 4, 2};
        int diff = 3;
        List<Pair> pairs = findPairs(arr, diff);
        for (Pair pair : pairs) {
            System.out.println(pair.x + " " + pair.y);
        }
    }

    public static List<Pair> findPairs(int[] arr, int diff) {
        List<Pair> pairs = new ArrayList<>();
        for (int i = 0; i < arr.length; i++) {
            for (int j = i + 1; j < arr.length; j++) {
                if (Math.abs(arr[i] - arr[j]) < diff) {
                    pairs.add(new Pair(arr[i], arr[j]));
                }
            }
        }

        return pairs;
    }
}

class Pair {
    int x;
    int y;
    public Pair(int x, int y) {
        this.x = x;
        this.y = y;
    }
}
```"""
            ],
            [
                """Generate a Java code block from the following explanation:

The code in the Main class finds all pairs in an array whose absolute difference is less than a given value.

The findPairs method takes two arguments: an array of integers and a difference value. It iterates over the array and compares each element to every other element in the array. If the absolute difference between the two elements is less than the difference value, a new Pair object is created and added to a list.

The Pair class is a simple data structure that stores two integers.

The main method creates an array of integers, initializes the difference value, and calls the findPairs method to find all pairs in the array. Finally, the code iterates over the list of pairs and prints each pair to the console."""  # noqa: E501
            ],
        ],
        # Button captions, one per entry in `examples`, in the same order.
        # NOTE(review): "Explain and document your code" — the matching prompt
        # only asks for a concise explanation, not documentation.
        example_labels=[
            "Explain quantum computing",
            "What is OpenShift?",
            "Importance of low latency inference",
            "Boosting productivity habits",
            "Explain and document your code",
            "Generate Java Code",
        ],
        # Examples are not pre-run through the model at startup.
        cache_examples=False,
        # History is a list of {"role": ..., "content": ...} dicts, matching
        # generate()'s `chat_history` parameter.
        type="messages",
        # Passed positionally after (message, history); this order must match
        # generate()'s temperature, repetition_penalty, top_p, top_k,
        # max_new_tokens parameters.
        additional_inputs=[
            temperature_slider,
            repetition_penalty_slider,
            top_p_slider,
            top_k_slider,
            max_new_tokens_slider,
        ],
        # Render the sliders inside the collapsed "Advanced Settings" accordion.
        additional_inputs_accordion=chat_interface_accordion,
    )

if __name__ == "__main__":
    # Configure the request queue on the Blocks app, then start the server.
    queued_demo = demo.queue()
    queued_demo.launch()