Update app.py
app.py CHANGED
@@ -12,8 +12,8 @@ DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 #if torch.cuda.is_available():
-model_id = "
-model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True,
+model_id = "mistralai/Mistral-7B-Instruct-v0.1"
+model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto")
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 tokenizer.use_default_system_prompt = False

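For reference, the new loading code in the hunk above maps onto the stock transformers API. A minimal standalone sketch of the same load follows; note that device_map="auto" requires the accelerate package, and trust_remote_code=True from the diff is omitted here because this model ships no custom modeling code (both points are assumptions about the runtime environment, not part of the commit).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.1"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,  # half precision, intended for GPU inference
    device_map="auto",          # weight placement handled by accelerate
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False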
@@ -29,14 +29,19 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = []
+    '''conversation = []
     if system_prompt:
         conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
+    conversation.append({"role": "user", "content": message})'''
+    prompt = "<s>"
+    for user_prompt, bot_response in chat_history:
+        prompt += f"[INST] {user_prompt} [/INST]"
+        prompt += f" {bot_response}</s> "
+    prompt += f"[INST] {message} [/INST]"

-    input_ids = tokenizer
+    input_ids = tokenizer(conversation, return_tensors="pt")['input_ids']
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
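One thing worth flagging in the second hunk: the dictionary-based conversation is now wrapped in a triple-quoted string, yet the new input_ids line still tokenizes conversation, which would raise a NameError when generate() runs. Presumably the freshly built prompt string is what should be encoded. A hedged sketch of that line, under the assumption that prompt is the intended input (add_special_tokens=False is likewise an assumption, to avoid prepending a second BOS token since the prompt string already starts with a literal "<s>"):

    # assumption: tokenize the new prompt string, not the commented-out conversation list
    input_ids = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)["input_ids"]
    input_ids = input_ids.to(model.device)  # assumption: move tensors to the model's device before generation

Recent transformers releases also expose tokenizer.apply_chat_template(), which should produce an equivalent [INST]-formatted prompt for Mistral directly from the commented-out conversation list, making the manual string building unnecessary.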