Artples committed on
Commit 130e4a8 · verified · 1 Parent(s): 7073a02

Update app.py

Files changed (1)
app.py +12 -79
app.py CHANGED
@@ -1,11 +1,8 @@
 import os
-from threading import Thread
-from typing import Iterator
-
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
@@ -19,20 +16,16 @@ This Space demonstrates [L-MChat](https://huggingface.co/collections/Artples/l-m
 if not torch.cuda.is_available():
     DESCRIPTION += "\n<p>Running on CPU! This demo does not work on CPU.</p>"
 
-# Dictionary to manage model details
 model_details = {
     "Fast-Model": "Artples/L-MChat-Small",
     "Quality-Model": "Artples/L-MChat-7b"
 }
 
-# Initialize models and tokenizers based on availability
 models = {name: AutoModelForCausalLM.from_pretrained(model_id, device_map="auto") for name, model_id in model_details.items()}
 tokenizers = {name: AutoTokenizer.from_pretrained(model_id) for name, model_id in model_details.items()}
-for tokenizer in tokenizers.values():
-    tokenizer.use_default_system_prompt = False
 
 @spaces.GPU(enable_queue=True, duration=90)
-def generate(
+async def generate(
     model_choice: str,
     message: str,
     chat_history: list[tuple[str, str]],
@@ -42,86 +35,26 @@ def generate(
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
-) -> Iterator[str]:
+) -> str:
     model = models[model_choice]
     tokenizer = tokenizers[model_choice]
 
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation = [{"role": "system", "content": system_prompt}] if system_prompt else []
+    conversation += [msg for user, assistant in chat_history for msg in ({"role": "user", "content": user}, {"role": "assistant", "content": assistant})]
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True)
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True, truncation=True, max_length=MAX_INPUT_TOKEN_LENGTH)
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
+    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, top_k=top_k, repetition_penalty=repetition_penalty)
+    output_text = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
 
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    return output_text
 
 chat_interface = gr.ChatInterface(
     theme='ehristoforu/RE_Theme',
     fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Dropdown(label="Model Choice", choices=list(model_details.keys()), value="Quality-Model"),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
+    additional_inputs=[gr.Textbox(label="System prompt", lines=6), gr.Dropdown(label="Model Choice", choices=list(model_details.keys()), value="Quality-Model")],
     examples=[
         ["Hello there! How are you doing?"],
         ["Can you explain briefly to me what is the Python programming language?"],
@@ -136,4 +69,4 @@ with gr.Blocks(css="style.css") as demo:
     chat_interface.render()
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.launch()
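The new code swaps the threaded TextIteratorStreamer pipeline for a single blocking generate-and-decode call. A minimal standalone sketch of that path, outside Gradio, is below; the MAX_INPUT_TOKEN_LENGTH value and the example prompt are assumptions (the constant's definition sits outside the hunks shown), while the model id and sampling parameters come from the diff, and the sketch assumes the checkpoint ships a chat template:

from transformers import AutoModelForCausalLM, AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; not shown in the diff

model_id = "Artples/L-MChat-Small"  # the "Fast-Model" entry from model_details
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Build the conversation the same way the new generate() does.
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello there! How are you doing?"},
]

# apply_chat_template renders the messages with the model's chat template and,
# with add_generation_prompt=True, appends the assistant turn marker.
input_ids = tokenizer.apply_chat_template(
    conversation,
    return_tensors="pt",
    add_generation_prompt=True,
    truncation=True,
    max_length=MAX_INPUT_TOKEN_LENGTH,
).to(model.device)

# One blocking call replaces the removed TextIteratorStreamer/Thread pair.
output_ids = model.generate(
    input_ids,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.2,
)

# Decode only the tokens generated after the prompt.
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))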
 
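One wiring detail matters when trimming additional_inputs as this commit does: gr.ChatInterface calls fn(message, history, *additional_inputs), so the handler's trailing parameters must line up, in order, with the components in the additional_inputs list. A hedged stub illustrating that contract (the respond function and its reply format are illustrative, not from the Space):

import gradio as gr

# gr.ChatInterface passes the message and history first, then each
# additional input in list order: here system_prompt, then model_choice.
def respond(message: str, history: list[tuple[str, str]], system_prompt: str, model_choice: str) -> str:
    return f"[{model_choice}] echo: {message}"

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Dropdown(label="Model Choice", choices=["Fast-Model", "Quality-Model"], value="Quality-Model"),
    ],
)

if __name__ == "__main__":
    demo.launch()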