Tonic committed on
Commit
4ce621a
·
unverified ·
1 Parent(s): 66a9100

add singleton to avoid threading issues

Browse files
Files changed (1) hide show
  1. app.py +26 -10
app.py CHANGED
@@ -1,16 +1,30 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 
 
4
 
5
  def load_model():
6
- model_id = "microsoft/bitnet-b1.58-2B-4T"
7
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
8
- model = AutoModelForCausalLM.from_pretrained(
9
- model_id,
10
- torch_dtype=torch.bfloat16,
11
- trust_remote_code=True
12
- )
13
- return model, tokenizer
 
 
 
 
 
 
 
 
 
 
14
 
15
  def manage_history(history):
16
  # Limit to 3 turns (each turn is user + assistant = 2 messages)
@@ -141,4 +155,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
141
  )
142
 
143
  if __name__ == "__main__":
144
- demo.launch(ssr_mode=False)
 
 
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
4
+
5
+ # Singleton for model and tokenizer
6
+ _model = None
7
+ _tokenizer = None
8
 
9
  def load_model():
10
+ global _model, _tokenizer
11
+ if _model is None or _tokenizer is None:
12
+ model_id = "microsoft/bitnet-b1.58-2B-4T"
13
+ _tokenizer = AutoTokenizer.from_pretrained(
14
+ model_id,
15
+ trust_remote_code=True
16
+ )
17
+ config = AutoConfig.from_pretrained(
18
+ model_id,
19
+ trust_remote_code=True
20
+ )
21
+ _model = AutoModelForCausalLM.from_pretrained(
22
+ model_id,
23
+ config=config,
24
+ torch_dtype=torch.bfloat16,
25
+ trust_remote_code=True
26
+ )
27
+ return _model, _tokenizer
28
 
29
  def manage_history(history):
30
  # Limit to 3 turns (each turn is user + assistant = 2 messages)
 
155
  )
156
 
157
  if __name__ == "__main__":
158
+ # Preload model to avoid threading issues
159
+ load_model()
160
+ demo.launch(ssr_mode=False, share=True)