SandLogicTechnologies committed (verified)
Commit dd0c774 · 1 Parent(s): 80a7d2c

Update app.py

Files changed (1): app.py (+31 −35)

app.py CHANGED
@@ -7,69 +7,67 @@ import spaces
 import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 
-MAX_MAX_NEW_TOKENS = 8096
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
 DESCRIPTION = """\
 Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
 For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
 """
 
-# if not torch.cuda.is_available():
-#     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
-
-if torch.cuda.is_available():
-    model_id = "SandLogicTechnologies/Shakti-2.5B"
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        device_map="auto",
-        torch_dtype=torch.bfloat16,
-        token=os.getenv("SHAKTI")
-    )
-
-# tokenizer.use_default_system_prompt = False
+MAX_MAX_NEW_TOKENS = 2048
+DEFAULT_MAX_NEW_TOKENS = 1024
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+model_id = "SandLogicTechnologies/Shakti-2.5B"
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.bfloat16,
+    token=os.getenv("SHAKTI")
+)
+model.eval()
 
 
-@spaces.GPU
+@spaces.GPU(duration=90)
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str,
     max_new_tokens: int = 1024,
-    temperature: float = 0,
+    temperature: float = 0.6,
+    top_p: float = 0.9,
+    top_k: int = 50,
+    repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
     conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    # else:
-    #     conversation.append(os.getenv("PROMPT"))
-
     for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.extend(
+            [
+                {"role": "user", "content": user},
+                {"role": "assistant", "content": assistant},
+            ]
+        )
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
         temperature=temperature,
         num_beams=1,
+        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
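
Note: the hunk ends at `t.start()`; the streaming tail of `generate` (old lines 76–82, unchanged by this commit) is not shown. For context, a sketch of the conventional tail in these Gradio streaming demos — an assumption, not part of the diff:

```python
# Sketch (assumed, not from this commit): drain the streamer on the main
# thread while model.generate runs in the background Thread started above.
outputs = []
for text in streamer:       # blocks up to timeout=20.0 s waiting for each chunk
    outputs.append(text)
    yield "".join(outputs)  # yield the accumulated reply so Gradio re-renders it
```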
@@ -83,7 +81,6 @@ def generate(
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
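
The hunk cuts off mid-`gr.Slider`; the unchanged lines between old 89 and 121 presumably define one slider per sampling parameter in the new `generate` signature. A hypothetical sketch of that `additional_inputs` list, with ranges assumed rather than taken from the commit:

```python
additional_inputs=[
    gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS,
              step=1, value=DEFAULT_MAX_NEW_TOKENS),
    gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.6),
    gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
    gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
    gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
],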
@@ -122,16 +119,15 @@ chat_interface = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Instruct on how to commit insider trading in the stock market"],
+        ["Tell me a story"], ["write a short poem which is hard to sing"], ['मुझे भारतीय इतिहास के बारे में बताएं']
     ],
     cache_examples=False,
 )
 
 with gr.Blocks(css="style.css", fill_height=True) as demo:
     gr.Markdown(DESCRIPTION)
-    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
     chat_interface.render()
-    # gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
     demo.queue(max_size=20).launch()
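
Two closing observations. The Hindi example added above means "Tell me about Indian history", matching the vernacular-language claim in DESCRIPTION. And the new `device` variable is assigned but never used: `device_map="auto"` decides model placement and `input_ids.to(model.device)` handles inputs. To reproduce the new loading and sampling path outside the Space, a minimal sketch, assuming access to the repo and a Hugging Face token exported in the same SHAKTI env var that app.py reads:

```python
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "SandLogicTechnologies/Shakti-2.5B"
token = os.getenv("SHAKTI")  # HF access token, same env var as in app.py

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16, token=token
)
model.eval()

# Mirror generate(): chat template plus the commit's new sampling defaults.
messages = [{"role": "user", "content": "Tell me a story"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    output = model.generate(
        input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.2,
    )
# Decode only the newly generated tokens, skipping the prompt.
print(tokenizer.decode(output[0, input_ids.shape[1]:], skip_special_tokens=True))
```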
 