fukugawa committed
Commit 5c5f097 · 1 Parent(s): a7c7893
docs/leaderboard_header.md CHANGED
@@ -1,7 +1,7 @@
 # 🚀 IndieBot Arena
 
 インディーボットアリーナはChatbot Arenaのインディーズ版のようなWebアプリです。
-Chatbot Arenaにインスパイアされて開発されましが、以下のような違いがあります。
+Chatbot Arenaにインスパイアされて開発しましたが、以下のような違いがあります。
 
 - 誰でもモデルを登録してコンペに参加可能
 - 重みのファイルサイズの階級別で戦う
@@ -13,4 +13,5 @@ Chatbot Arenaにインスパイアされて開発されましが、以下のよ
 Google公式モデルをみんなで倒しましょう!
 
 【更新履歴】
+- 2025/04/03 チャットがストリーミング方式になり応答速度が向上しました。
 - 2025/03/31 ベータ版を公開。機能は一通り完成していますが、性能をテスト中。
indiebot_arena/config.py CHANGED
@@ -7,7 +7,7 @@ LANGUAGE = "ja"
 DEBUG = os.getenv("DEBUG", "False").lower() in ["true", "1", "yes"]
 LOCAL_TESTING = os.getenv("LOCAL_TESTING", "False").lower() in ["true", "1", "yes"]
 MODEL_SELECTION_MODE = os.getenv("MODEL_SELECTION_MODE", "random")
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "512"))
 
 if LOCAL_TESTING:
   MAX_NEW_TOKENS = 20
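This change lowers the default context budget from 4096 to 512 tokens. Because the value is read through os.getenv, a deployment can still raise it without touching the code; a minimal sketch of that override pattern, assuming the variable is set before config.py is imported (the 1024 value is only an illustration):

import os

os.environ["MAX_INPUT_TOKEN_LENGTH"] = "1024"  # e.g. set in the Space's environment settings

# Same pattern as indiebot_arena/config.py: the env var wins, otherwise the default applies.
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "512"))
print(MAX_INPUT_TOKEN_LENGTH)  # 1024 here; 512 when the variable is unset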
indiebot_arena/ui/battle.py CHANGED
@@ -1,11 +1,13 @@
 import hashlib
 import os
 import random
+from collections.abc import Iterator
+from threading import Thread
 
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
 from indiebot_arena.config import MODEL_SELECTION_MODE, MAX_INPUT_TOKEN_LENGTH, MAX_NEW_TOKENS
 from indiebot_arena.service.arena_service import ArenaService
@@ -17,8 +19,8 @@ docs_path = os.path.join(base_dir, "docs", "battle_header.md")
 
 
 @spaces.GPU(duration=30)
-def generate(message: str, chat_history: list, model_id: str, max_new_tokens: int = MAX_NEW_TOKENS,
-             temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2) -> str:
+def generate(chat_history: list, model_id: str, max_new_tokens: int = MAX_NEW_TOKENS,
+             temperature: float = 0.6, top_p: float = 0.9, top_k: int = 50, repetition_penalty: float = 1.2) -> Iterator[str]:
   tokenizer = AutoTokenizer.from_pretrained(model_id)
   model = AutoModelForCausalLM.from_pretrained(
     model_id,
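The return annotation moving from str to Iterator[str] in the hunk above is what turns generate into a streaming handler: Gradio re-renders an output component each time an event handler yields. A toy, self-contained example of that generator-handler behavior, separate from the arena code (the components and timings are placeholders):

import time
import gradio as gr

def count_up(n):
    # Generator handler: Gradio updates the Textbox on every yield.
    text = ""
    for i in range(int(n)):
        time.sleep(0.2)
        text += f"{i} "
        yield text

demo = gr.Interface(fn=count_up, inputs=gr.Number(value=5), outputs=gr.Textbox())

if __name__ == "__main__":
    demo.launch()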
@@ -26,48 +28,57 @@ def generate(message: str, chat_history: list, model_id: str, max_new_tokens: in
     torch_dtype=torch.bfloat16,
     use_safetensors=True
   )
+  model.eval()
 
-  conversation = chat_history.copy()
-  conversation.append({"role": "user", "content": message})
-  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+  input_ids = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True, return_tensors="pt")
   if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-    gr.Warning(f"Trimmed input as it exceeded {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
   input_ids = input_ids.to(model.device)
-  outputs = model.generate(
-    input_ids=input_ids,
+
+  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+  generate_kwargs = dict(
+    {"input_ids": input_ids},
+    streamer=streamer,
     max_new_tokens=max_new_tokens,
     do_sample=True,
     top_p=top_p,
     top_k=top_k,
     temperature=temperature,
+    num_beams=1,
    repetition_penalty=repetition_penalty,
   )
-  response = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
-  return response
+  t = Thread(target=model.generate, kwargs=generate_kwargs)
+  t.start()
+
+  outputs = []
+  for text in streamer:
+    outputs.append(text)
+    yield "".join(outputs)
 
 
-def format_chat_history(history):
-  conversation = []
-  for user_msg, assistant_msg in history:
-    conversation.append({"role": "user", "content": user_msg})
-    if assistant_msg:
-      conversation.append({"role": "assistant", "content": assistant_msg})
-  return conversation
+def update_user_message(user_message, history_a, history_b, weight_class_radio):
+  new_history_a = history_a + [{"role": "user", "content": user_message}]
+  new_history_b = history_b + [{"role": "user", "content": user_message}]
+  return "", new_history_a, new_history_b, gr.update(interactive=False)
 
 
-def submit_message(message, history_a, history_b, model_a, model_b):
-  history_a.append((message, ""))
-  history_b.append((message, ""))
-  conv_history_a = format_chat_history(history_a[:-1])
-  conv_history_b = format_chat_history(history_b[:-1])
+def bot1_response(history, model_id):
+  history = history.copy()
+  history.append({"role": "assistant", "content": ""})
+  conv_history = history[:-1]
+  for text in generate(conv_history, model_id):
+    history[-1]["content"] = text
+    yield history, gr.update(interactive=True), gr.update(interactive=True)
 
-  response_a = generate(message, conv_history_a, model_a)
-  response_b = generate(message, conv_history_b, model_b)
 
-  history_a[-1] = (message, response_a)
-  history_b[-1] = (message, response_b)
-  return history_a, history_b, "", gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=False)
+def bot2_response(history, model_id):
+  history = history.copy()
+  history.append({"role": "assistant", "content": ""})
+  conv_history = history[:-1]
+  for text in generate(conv_history, model_id):
+    history[-1]["content"] = text
+    yield history, gr.update(interactive=True), gr.update(interactive=True)
 
 
 def get_random_values(model_labels):
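The rewritten body above follows the usual transformers streaming recipe: model.generate runs on a background thread while a TextIteratorStreamer yields decoded text, which the function re-emits as a growing string. A stripped-down sketch of that recipe on its own, assuming a small placeholder model rather than the arena's contest entries:

from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model, small enough to run on CPU
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids

# skip_prompt=True keeps the echoed prompt out of the stream; the streamer is a plain iterator.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
thread = Thread(target=model.generate,
                kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=20))
thread.start()

chunks = []
for text in streamer:
    chunks.append(text)
    print("".join(chunks))  # cumulative response, same shape as what generate() yields
thread.join()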
@@ -175,8 +186,8 @@ def battle_content(dao, language):
       outputs=[model_dropdown_a, model_dropdown_b, dropdown_options_state]
     )
     with gr.Row():
-      chatbot_a = gr.Chatbot(label="Chatbot A")
-      chatbot_b = gr.Chatbot(label="Chatbot B")
+      chatbot_a = gr.Chatbot(label="Chatbot A", type="messages")
+      chatbot_b = gr.Chatbot(label="Chatbot B", type="messages")
     with gr.Row():
       vote_a_btn = gr.Button("A is better", variant="primary", interactive=False)
       vote_b_btn = gr.Button("B is better", variant="primary", interactive=False)
@@ -190,10 +201,23 @@
       vote_message = gr.Textbox(show_label=False, interactive=False, visible=False)
     with gr.Column(scale=1):
       next_battle_btn = gr.Button("次のバトルへ", variant="primary", interactive=False, visible=False, elem_id="next_battle_btn")
-    user_input.submit(
-      fn=submit_message,
-      inputs=[user_input, chatbot_a, chatbot_b, model_dropdown_a, model_dropdown_b],
-      outputs=[chatbot_a, chatbot_b, user_input, vote_a_btn, vote_b_btn, weight_class_radio]
+    user_event = user_input.submit(
+      update_user_message,
+      inputs=[user_input, chatbot_a, chatbot_b, weight_class_radio],
+      outputs=[user_input, chatbot_a, chatbot_b, weight_class_radio],
+      queue=False
+    )
+    user_event.then(
+      bot1_response,
+      inputs=[chatbot_a, model_dropdown_a],
+      outputs=[chatbot_a, vote_a_btn, vote_b_btn],
+      queue=True
+    )
+    user_event.then(
+      bot2_response,
+      inputs=[chatbot_b, model_dropdown_b],
+      outputs=[chatbot_b, vote_a_btn, vote_b_btn],
+      queue=True
     )
     vote_a_btn.click(
       fn=on_vote_a_click,
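The new wiring follows a common Gradio pattern: the submit event records the user turn with queue=False, then .then() chains the bot generators so each chat panel streams its reply. A minimal sketch of the same pattern with a single messages-format Chatbot, using a hypothetical echo reply in place of the arena's generate:

import time
import gradio as gr

def add_user_message(message, history):
    # Clear the textbox and append the user turn in "messages" format.
    return "", history + [{"role": "user", "content": message}]

def bot_response(history):
    # Stream an assistant turn by yielding the growing history back to the Chatbot.
    history = history + [{"role": "assistant", "content": ""}]
    reply = "You said: " + history[-2]["content"]
    for i in range(1, len(reply) + 1):
        time.sleep(0.02)
        history[-1]["content"] = reply[:i]
        yield history

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(type="messages")
    user_input = gr.Textbox()
    user_input.submit(
        add_user_message, inputs=[user_input, chatbot], outputs=[user_input, chatbot], queue=False
    ).then(
        bot_response, inputs=chatbot, outputs=chatbot
    )

if __name__ == "__main__":
    demo.launch()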