Locon213 commited on
Commit
bf7602e
·
verified ·
1 Parent(s): b7645ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -27
app.py CHANGED
@@ -1,21 +1,43 @@
1
  from peft import PeftModel
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
 
 
 
 
 
 
3
  import gradio as gr
 
4
 
5
- # Загрузка модели и токенизатора
6
  base_model = AutoModelForCausalLM.from_pretrained(
7
  "Qwen/Qwen2.5-0.5B-Instruct",
8
- device_map="auto"
 
 
9
  )
 
 
10
  model = PeftModel.from_pretrained(base_model, "Locon213/ThinkLite")
 
 
 
 
 
 
 
 
 
 
 
11
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
12
 
13
- # Конфигурация генерации
14
  generation_config = GenerationConfig(
15
  temperature=0.7,
16
  top_p=0.9,
17
  top_k=50,
18
- max_new_tokens=512,
19
  repetition_penalty=1.1,
20
  do_sample=True
21
  )
@@ -27,35 +49,57 @@ def format_prompt(message, history):
27
  prompt += f"<<<USER>>> {message}\n<<<ASSISTANT>>>"
28
  return prompt
29
 
30
- def generate_response(message, history):
31
- # Форматируем промпт с историей чата
32
  formatted_prompt = format_prompt(message, history)
33
-
34
- # Токенизация и генерация
35
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
36
- outputs = model.generate(
 
 
 
 
 
 
 
 
37
  **inputs,
38
  generation_config=generation_config,
 
39
  pad_token_id=tokenizer.eos_token_id
40
  )
41
 
42
- # Декодирование и извлечение ответа
43
- response = tokenizer.decode(outputs[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
44
 
45
- return response.strip()
46
-
47
- # Создание чат-интерфейса
48
- chat_interface = gr.ChatInterface(
49
- fn=generate_response,
50
- examples=[
51
- "Объясни квантовую запутанность простыми словами",
52
- "Как научиться программировать?",
53
- "Напиши стихотворение про ИИ"
54
- ],
55
- title="ThinkLite Chat",
56
- description="Общайтесь с ThinkLite - адаптированной версией Qwen2.5-0.5B-Instruct",
57
- theme="soft"
58
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  if __name__ == "__main__":
61
- chat_interface.launch()
 
1
  from peft import PeftModel
2
+ from transformers import (
3
+ AutoModelForCausalLM,
4
+ AutoTokenizer,
5
+ GenerationConfig,
6
+ TextIteratorStreamer
7
+ )
8
+ import torch
9
  import gradio as gr
10
+ from threading import Thread
11
 
12
+ # Загрузка и объединение модели с адаптерами
13
  base_model = AutoModelForCausalLM.from_pretrained(
14
  "Qwen/Qwen2.5-0.5B-Instruct",
15
+ device_map="auto",
16
+ torch_dtype=torch.float16,
17
+ low_cpu_mem_usage=True
18
  )
19
+
20
+ # Объединение основной модели с адаптерами
21
  model = PeftModel.from_pretrained(base_model, "Locon213/ThinkLite")
22
+ model = model.merge_and_unload()
23
+
24
+ # Применяем оптимизации для CPU
25
+ model = torch.quantization.quantize_dynamic(
26
+ model,
27
+ {torch.nn.Linear},
28
+ dtype=torch.qint8
29
+ )
30
+ model.config.use_cache = True
31
+
32
+ # Загрузка токенизатора
33
  tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
34
 
35
+ # Конфигурация генерации с оптимизированными параметрами
36
  generation_config = GenerationConfig(
37
  temperature=0.7,
38
  top_p=0.9,
39
  top_k=50,
40
+ max_new_tokens=256, # Уменьшено для экономии памяти
41
  repetition_penalty=1.1,
42
  do_sample=True
43
  )
 
49
  prompt += f"<<<USER>>> {message}\n<<<ASSISTANT>>>"
50
  return prompt
51
 
52
+ def generate_stream(message, history):
 
53
  formatted_prompt = format_prompt(message, history)
 
 
54
  inputs = tokenizer(formatted_prompt, return_tensors="pt").to(model.device)
55
+
56
+ streamer = TextIteratorStreamer(
57
+ tokenizer,
58
+ skip_prompt=True,
59
+ skip_special_tokens=True,
60
+ timeout=30
61
+ )
62
+
63
+ generation_kwargs = dict(
64
  **inputs,
65
  generation_config=generation_config,
66
+ streamer=streamer,
67
  pad_token_id=tokenizer.eos_token_id
68
  )
69
 
70
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
71
+ thread.start()
72
 
73
+ partial_message = ""
74
+ for new_token in streamer:
75
+ partial_message += new_token
76
+ yield partial_message
77
+
78
+ # Создание интерфейса с оптимизированным дизайном
79
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
80
+ gr.Markdown("# ThinkLite Chat (Optimized)")
81
+ gr.Markdown("🚀 Версия с потоковым выводом и оптимизацией для CPU")
82
+
83
+ chatbot = gr.Chatbot(height=400)
84
+ msg = gr.Textbox(label="Ваше сообщение")
85
+ clear_btn = gr.Button("Очистить историю")
86
+
87
+ def user(message, chat_history):
88
+ return "", chat_history + [[message, None]]
89
+
90
+ def bot(chat_history):
91
+ message = chat_history[-1][0]
92
+ history = chat_history[:-1]
93
+
94
+ chat_history[-1][1] = ""
95
+ for response in generate_stream(message, history):
96
+ chat_history[-1][1] = response
97
+ yield chat_history
98
+
99
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
100
+ bot, chatbot, chatbot
101
+ )
102
+ clear_btn.click(lambda: [], None, chatbot, queue=False)
103
 
104
  if __name__ == "__main__":
105
+ demo.queue(max_size=10).launch(debug=False)