sergey21000 committed on
Commit c0c6dc4 · verified · 1 Parent(s): a34817c

Upload 2 files

Files changed (2)
  1. app.py +289 -0
  2. requirements.txt +2 -0
app.py ADDED
from pathlib import Path
from typing import Union, List, Dict, Tuple, Optional
from tqdm import tqdm

import requests
from llama_cpp import Llama
import gradio as gr


# ================== VARIABLES =============================

MODELS_PATH = Path('models')
MODELS_PATH.mkdir(exist_ok=True)
DEFAULT_GGUF_URL = 'https://huggingface.co/bartowski/google_gemma-3-1b-it-GGUF/resolve/main/google_gemma-3-1b-it-Q8_0.gguf'


GENERATE_KWARGS = dict(
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.0,
)

LLAMA_MODEL_KWARGS = dict(
    n_gpu_layers=-1,  # offload all layers to the GPU when one is available
    verbose=False,
    n_ctx=4096,
)


# ================== ANNOTATIONS ========================

CHAT_HISTORY = List[Optional[Dict[str, Optional[str]]]]
MODEL_DICT = Dict[str, Llama]


# ================== FUNCS =============================

def download_file(file_url: str, file_path: Union[str, Path]) -> None:
    response = requests.get(file_url, stream=True)
    if response.status_code != 200:
        raise Exception(f'The file is not available for download at the link: {file_url}')
    total_size = int(response.headers.get('content-length', 0))
    progress_tqdm = tqdm(desc='Loading GGUF file', total=total_size, unit='iB', unit_scale=True)
    progress_gradio = gr.Progress()
    completed_size = 0
    with open(file_path, 'wb') as file:
        for data in response.iter_content(chunk_size=4096):
            size = file.write(data)
            progress_tqdm.update(size)
            completed_size += size
            desc = f'Loading GGUF file, {completed_size/1024**3:.3f}/{total_size/1024**3:.3f} GB'
            progress_gradio(completed_size/total_size, desc=desc)


def check_support_system_role(model: Optional[Llama]) -> bool:
    # some chat templates raise 'System role not supported'; detect that marker in the template text
    if model is None:
        return False
    return 'System role not supported' not in model.metadata.get('tokenizer.chat_template', '')


def download_gguf_and_init_model(gguf_url: str, model_dict: MODEL_DICT) -> Tuple[MODEL_DICT, bool, str]:
    log = ''
    if not gguf_url.endswith('.gguf'):
        log += 'The link must be a direct link to a GGUF file\n'
        return model_dict, check_support_system_role(model_dict.get('model')), log

    gguf_filename = gguf_url.rsplit('/', maxsplit=1)[-1]
    model_path = MODELS_PATH / gguf_filename
    progress = gr.Progress()

    if not model_path.is_file():
        progress(0.3, desc='Step 1/2: Loading GGUF model file')
        try:
            download_file(gguf_url, model_path)
            log += f'Model file {gguf_filename} successfully loaded\n'
        except Exception as ex:
            # remove a partially downloaded file so a retry does not pick it up
            if model_path.is_file():
                model_path.unlink()
            log += f'Error loading model from link {gguf_url}, error:\n{ex}\n'
            curr_model = model_dict.get('model')
            if curr_model is None:
                log += 'Model is missing from dictionary "model_dict"\n'
                return model_dict, False, log
            curr_model_filename = Path(curr_model.model_path).name
            log += f'Current initialized model: {curr_model_filename}\n'
            return model_dict, check_support_system_role(curr_model), log
    else:
        log += f'Model file {gguf_filename} is already downloaded, initializing model...\n'

    progress(0.7, desc='Step 2/2: Model initialization')
    model = Llama(model_path=str(model_path), **LLAMA_MODEL_KWARGS)
    model_dict = {'model': model}
    support_system_role = check_support_system_role(model)
    log += f'Model {gguf_filename} initialized\n'
    return model_dict, support_system_role, log


def user_message_to_chatbot(user_message: str, chatbot: CHAT_HISTORY) -> Tuple[str, CHAT_HISTORY]:
    if user_message:
        chatbot.append({'role': 'user', 'content': user_message})
    return '', chatbot


def bot_response_to_chatbot(
    chatbot: CHAT_HISTORY,
    model_dict: MODEL_DICT,
    system_prompt: str,
    support_system_role: bool,
    history_len: int,
    do_sample: bool,
    *generate_args,
):
    model = model_dict.get('model')
    if model is None:
        gr.Info('Model not initialized')
        yield chatbot
        return

    if len(chatbot) == 0 or chatbot[-1]['role'] == 'assistant':
        yield chatbot
        return

    messages = []
    if support_system_role and system_prompt:
        messages.append({'role': 'system', 'content': system_prompt})

    if history_len != 0:
        # take the last history_len user/assistant pairs before the current user message
        messages.extend(chatbot[:-1][-(history_len*2):])

    messages.append(chatbot[-1])

    gen_kwargs = dict(zip(GENERATE_KWARGS.keys(), generate_args))
    gen_kwargs['top_k'] = int(gen_kwargs['top_k'])
    if not do_sample:
        # greedy decoding: top_k=1 always picks the most likely token
        gen_kwargs['top_p'] = 0.0
        gen_kwargs['top_k'] = 1
        gen_kwargs['repeat_penalty'] = 1.0

    stream_response = model.create_chat_completion(
        messages=messages,
        stream=True,
        **gen_kwargs,
    )

    chatbot.append({'role': 'assistant', 'content': ''})
    for chunk in stream_response:
        token = chunk['choices'][0]['delta'].get('content')
        if token is not None:
            chatbot[-1]['content'] += token
            yield chatbot


def get_system_prompt_component(interactive: bool) -> gr.Textbox:
    value = '' if interactive else 'System prompt is not supported by this model'
    return gr.Textbox(value=value, label='System prompt', interactive=interactive)


def get_generate_args(do_sample: bool) -> List[gr.Slider]:
    generate_args = [
        gr.Slider(minimum=0.1, maximum=3, value=GENERATE_KWARGS['temperature'], step=0.1, label='temperature', visible=do_sample),
        gr.Slider(minimum=0, maximum=1, value=GENERATE_KWARGS['top_p'], step=0.01, label='top_p', visible=do_sample),
        gr.Slider(minimum=1, maximum=50, value=GENERATE_KWARGS['top_k'], step=1, label='top_k', visible=do_sample),
        gr.Slider(minimum=1, maximum=5, value=GENERATE_KWARGS['repeat_penalty'], step=0.1, label='repeat_penalty', visible=do_sample),
    ]
    return generate_args


# =============== INIT MODEL =============================

start_model_dict, start_support_system_role, start_load_log = download_gguf_and_init_model(
    gguf_url=DEFAULT_GGUF_URL, model_dict={},
)


# ================== INTERFACE =============================

theme = gr.themes.Base(primary_hue='green', secondary_hue='yellow', neutral_hue='zinc').set(
    loader_color='rgb(0, 255, 0)',
    slider_color='rgb(0, 200, 0)',
    body_text_color_dark='rgb(0, 200, 0)',
    button_secondary_background_fill_dark='green',
)

# css = None
css = '''
.gradio-container {
    width: 70% !important;
    margin: 0 auto !important;
}
'''

with gr.Blocks(theme=theme, css=css) as interface:
    model_dict = gr.State(start_model_dict)
    support_system_role = gr.State(start_support_system_role)

    # ================= CHAT BOT PAGE ======================
    with gr.Tab('Chatbot'):
        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    type='messages',  # new in gradio 5+
                    show_copy_button=True,
                    height=480,
                )
                user_message = gr.Textbox(label='User')

                with gr.Row():
                    user_message_btn = gr.Button('Send')
                    stop_btn = gr.Button('Stop')
                    clear_btn = gr.Button('Clear')

                system_prompt = get_system_prompt_component(interactive=support_system_role.value)

            with gr.Column(scale=1, min_width=80):
                with gr.Group():
                    gr.Markdown('Length of message history')
                    history_len = gr.Slider(
                        minimum=0,
                        maximum=10,
                        value=0,
                        step=1,
                        info='Number of previous messages taken into account in history',
                        label='history_len',
                        show_label=False,
                    )

                with gr.Group():
                    gr.Markdown('Generation parameters')
                    do_sample = gr.Checkbox(
                        value=False,
                        label='do_sample',
                        info='Activate random sampling',
                    )
                    generate_args = get_generate_args(do_sample.value)
                    do_sample.change(
                        fn=get_generate_args,
                        inputs=do_sample,
                        outputs=generate_args,
                        show_progress=False,
                    )

        generate_event = gr.on(
            triggers=[user_message.submit, user_message_btn.click],
            fn=user_message_to_chatbot,
            inputs=[user_message, chatbot],
            outputs=[user_message, chatbot],
        ).then(
            fn=bot_response_to_chatbot,
            inputs=[chatbot, model_dict, system_prompt, support_system_role, history_len, do_sample, *generate_args],
            outputs=[chatbot],
        )
        stop_btn.click(
            fn=None,
            inputs=None,
            outputs=None,
            cancels=generate_event,
        )
        clear_btn.click(
            fn=lambda: None,
            inputs=None,
            outputs=[chatbot],
        )

    # ================= LOAD MODELS PAGE ======================
    with gr.Tab('Load model'):
        gguf_url = gr.Textbox(
            value='',
            label='Link to GGUF',
            placeholder='URL link to the model in GGUF format',
        )
        load_model_btn = gr.Button('Download GGUF and initialize model')
        load_log = gr.Textbox(
            value=start_load_log,
            label='Model loading status',
            lines=3,
        )

        load_model_btn.click(
            fn=download_gguf_and_init_model,
            inputs=[gguf_url, model_dict],
            outputs=[model_dict, support_system_role, load_log],
        ).success(
            fn=get_system_prompt_component,
            inputs=[support_system_role],
            outputs=[system_prompt],
        )

    gr.HTML("""<h3 style='text-align: center'>
    <a href="https://github.com/sergey21000/gradio-llamacpp-chatbot" target='_blank'>GitHub Repository</a></h3>
    """)

if __name__ == '__main__':
    interface.launch(server_name='0.0.0.0', server_port=7860)
requirements.txt ADDED
llama_cpp_python==0.3.8
gradio>5
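For reference, a minimal sketch of the same streaming call that bot_response_to_chatbot makes, runnable outside Gradio. It assumes the default GGUF file from DEFAULT_GGUF_URL has already been downloaded into models/ (the filename below is that default; any other local GGUF path works the same way):

from llama_cpp import Llama

# assumes the default model file from app.py is already in models/
model = Llama(
    model_path='models/google_gemma-3-1b-it-Q8_0.gguf',
    n_gpu_layers=-1,
    n_ctx=4096,
    verbose=False,
)
stream = model.create_chat_completion(
    messages=[{'role': 'user', 'content': 'Hello!'}],
    stream=True,
    temperature=0.2,
    top_p=0.95,
    top_k=40,
    repeat_penalty=1.0,
)
for chunk in stream:
    token = chunk['choices'][0]['delta'].get('content')  # None for role-only chunks
    if token is not None:
        print(token, end='', flush=True)

To run the app itself: pip install -r requirements.txt, then python app.py; the interface is served on 0.0.0.0:7860 as configured in interface.launch.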