acecalisto3 committed · Commit 7bc4eb7 · verified · Parent(s): 6cb7bb4

Update app.py

Files changed (1):
  1. app.py +14 -7
app.py CHANGED
@@ -3,11 +3,17 @@ import requests
 import os
 import logging
 from datetime import datetime
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import TextStreamer
 
-
-device = "cuda"
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
-@spaces.GPU()
+# Load the model and tokenizer
+model_name = "mixtral/instruct-v0.1"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+
 def stream_chat(
     message: str,
     history: list,
@@ -18,9 +24,6 @@ def stream_chat(
     top_k: int = 20,
     penalty: float = 1.2,
 ):
-    print(f'message: {message}')
-    print(f'history: {history}')
-
     conversation = [
         {"role": "system", "content": system_prompt}
     ]
@@ -32,9 +35,9 @@ def stream_chat(
 
     conversation.append({"role": "user", "content": message})
 
-    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
+    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
+    streamer = TextStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     generate_kwargs = dict(
         input_ids=input_ids,
@@ -46,6 +49,10 @@ def stream_chat(
         eos_token_id=[128001, 128008, 128009],
         streamer=streamer,
     )
+
+    output = model.generate(**generate_kwargs)
+    return tokenizer.decode(output[0], skip_special_tokens=True)
+
 app = Flask(__name__)
 
 # Configure logging
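
A note on the streaming behavior this commit changes: TextIteratorStreamer is pull-based (the caller iterates over decoded chunks while generate() runs on another thread), whereas TextStreamer prints tokens to stdout while generate() blocks, so the updated stream_chat now returns one final string. For reference, a minimal sketch of the thread-plus-iterator pattern the old code implied, in case token-by-token streaming over the API is still wanted; the helper name stream_tokens is hypothetical and not part of this commit:

from threading import Thread

from transformers import TextIteratorStreamer


def stream_tokens(model, tokenizer, generate_kwargs):
    """Hypothetical helper: yields decoded text chunks as they are generated."""
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
    )
    # generate() runs in a background thread and feeds the streamer;
    # iterating the streamer on this thread pulls chunks out as they arrive.
    kwargs = dict(generate_kwargs, streamer=streamer)
    thread = Thread(target=model.generate, kwargs=kwargs)
    thread.start()
    for chunk in streamer:
        yield chunk
    thread.join()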
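
And a minimal smoke test for the new blocking path. The hunks truncate the middle of the stream_chat signature, so apart from top_k and penalty the parameter names below are assumptions inferred from the function body, not confirmed by the diff:

# Assumed call shape; only top_k=20 and penalty=1.2 defaults are visible
# in the diff, and system_prompt is inferred from its use in the body.
reply = stream_chat(
    message="Hello!",
    history=[],
    system_prompt="You are a helpful assistant.",
    top_k=20,
    penalty=1.2,
)
print(reply)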