acecalisto3 committed
Commit 5b2924b · verified · 1 Parent(s): b212d6f

Update app.py

Files changed (1)
  1. app.py +11 -47
app.py CHANGED
@@ -3,56 +3,10 @@ import requests
 import os
 import logging
 from datetime import datetime
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
 from transformers import pipeline
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-messages = [
-    {"role": "user", "content": "Who are you?"},
-]
+# Initialize the pipeline
 pipe = pipeline("text-generation", model="mistralai/Mixtral-8x7B-Instruct-v0.1")
-pipe(messages)
-
-def stream_chat(
-    message: str,
-    history: list,
-    system_prompt: str,
-    temperature: float = 0.8,
-    max_new_tokens: int = 1024,
-    top_p: float = 1.0,
-    top_k: int = 20,
-    penalty: float = 1.2,
-):
-    conversation = [
-        {"role": "system", "content": system_prompt}
-    ]
-    for prompt, answer in history:
-        conversation.extend([
-            {"role": "user", "content": prompt},
-            {"role": "assistant", "content": answer},
-        ])
-
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
-
-    streamer = TextStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=temperature != 0,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        eos_token_id=[128001, 128008, 128009],
-        streamer=streamer,
-    )
-
-    output = model.generate(**generate_kwargs)
-    return tokenizer.decode(output[0], skip_special_tokens=True)
 
 app = Flask(__name__)
 
@@ -71,6 +25,16 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+def stream_chat(message: str, history: list, system_prompt: str):
+    conversation = [{"role": "system", "content": system_prompt}]
+    for prompt, answer in history:
+        conversation.extend([
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": answer},
+        ])
+    conversation.append({"role": "user", "content": message})
+    return pipe(conversation)
+
 @app.route('/')
 def index():
     return render_template('index.html')
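A usage note on the new stream_chat: in recent transformers releases, passing a chat-format message list to a text-generation pipeline applies the model's chat template and returns the conversation with the generated assistant turn appended as the last message; note that despite its name, the rewritten function returns the full pipeline output rather than a token stream. The sketch below shows one way the Flask app might expose it. The /chat endpoint, JSON field names, and default system prompt are assumptions for illustration, since the commit only shows the index route.

# Hypothetical route, not part of the commit: wires stream_chat into Flask.
# Assumes the `app`, `pipe`, and `stream_chat` defined in app.py above.
from flask import request, jsonify

@app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json()
    result = stream_chat(
        message=data['message'],
        history=data.get('history', []),  # list of [user, assistant] pairs
        system_prompt=data.get('system_prompt', 'You are a helpful assistant.'),
    )
    # pipe(conversation) returns [{'generated_text': [...messages...]}] where
    # the last message is the model's reply (recent transformers chat-format
    # pipeline behavior with the default return_full_text=True).
    reply = result[0]['generated_text'][-1]['content']
    return jsonify({'reply': reply})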