acecalisto3 committed on
Commit
dbf2fcd
·
verified ·
1 Parent(s): 5b2924b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -0
app.py CHANGED
@@ -4,6 +4,53 @@ import os
4
  import logging
5
  from datetime import datetime
6
  from transformers import pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# Initialize the pipeline
# NOTE(review): this eagerly loads the full Mixtral-8x7B checkpoint at import
# time — confirm the host has enough memory/VRAM before startup.
pipe = pipeline("text-generation", model="mistralai/Mixtral-8x7B-Instruct-v0.1")
 
4
  import logging
5
  from datetime import datetime
6
  from transformers import pipeline
7
import spaces  # Hugging Face Spaces helper (GPU decorator support)
# Bug fix: torch, AutoTokenizer and AutoModelForCausalLM were used below but
# never imported, which raised NameError at import time.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Run on GPU when available; model and input tensors must share this device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the model and tokenizer.
# Bug fix: "mixtral/instruct-v0.1" is not a valid Hugging Face repo id — use
# the same checkpoint the pipeline elsewhere in this file loads.
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
16
def stream_chat(
    message: str,
    history: list,
    system_prompt: str,
    temperature: float = 0.8,
    max_new_tokens: int = 1024,
    top_p: float = 1.0,
    top_k: int = 20,
    penalty: float = 1.2,
):
    """Generate an assistant reply for `message` given prior chat `history`.

    Args:
        message: the new user message.
        history: list of (user, assistant) string pairs from earlier turns.
        system_prompt: system message prepended to the conversation.
        temperature / max_new_tokens / top_p / top_k: sampling controls.
        penalty: repetition penalty passed to `generate`.

    Returns:
        Only the newly generated assistant text (prompt tokens stripped).
    """
    # Bug fix: TextStreamer was used without ever being imported (NameError).
    # Function-scope import keeps this block self-contained.
    from transformers import TextStreamer

    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Streams tokens to stdout as they are produced.
    # Bug fix: `timeout` is a TextIteratorStreamer parameter, not a
    # TextStreamer one, so it was dropped here.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    output = model.generate(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=temperature != 0,  # greedy decode when temperature is 0
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        # Bug fix: `penalty` was accepted but silently ignored before.
        repetition_penalty=penalty,
        # Bug fix: the previous hard-coded ids [128001, 128008, 128009] are
        # Llama-3 token ids, outside Mixtral's vocabulary; use the model's
        # own EOS token instead.
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # Bug fix: decode only the newly generated tokens — the old code returned
    # the echoed prompt along with the reply.
    return tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
54
 
55
  # Initialize the pipeline
56
  pipe = pipeline("text-generation", model="mistralai/Mixtral-8x7B-Instruct-v0.1")