Tawkat committed
Commit df7e831 · verified · 1 Parent(s): c88ab80

Update app.py

Files changed (1):
  app.py +10 -5
app.py CHANGED
@@ -12,8 +12,8 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 #if torch.cuda.is_available():
-model_id = "meta-llama/Llama-2-7b-chat-hf"
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True, token=HF_TOKEN, device_map="auto")
+model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 tokenizer.use_default_system_prompt = False

@@ -29,14 +29,19 @@
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
+    '''conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
+    conversation.append({"role": "user", "content": message})'''
+    prompt = "<s>"
+    for user_prompt, bot_response in chat_history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"

-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    input_ids = tokenizer(prompt, return_tensors="pt")['input_ids']
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")