Commit f541eb3
Parent(s): 2ddd665

Add cpu inference option for testing

app.py CHANGED
@@ -4,10 +4,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, Stopping
 import time
 import numpy as np
 from torch.nn import functional as F
-
-
-
-
+import os
+token_key = os.environ.get("HUGGING_FACE_HUB_TOKEN")
+
+if torch.cuda.is_available():
+    m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key, torch_dtype=torch.float16).cuda()
+    tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key)
+else:
+    m = AutoModelForCausalLM.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key, torch_dtype=torch.float16)
+    tok = AutoTokenizer.from_pretrained("stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key)
 generator = pipeline('text-generation', model=m, tokenizer=tok, device=0)
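Note that while the new else-branch loads the model on CPU, the unchanged pipeline call below it still passes device=0, which requests the first GPU, and torch.float16 weights are poorly supported on most CPUs. A minimal device-agnostic sketch of the same setup (the has_cuda name is illustrative, not from the commit):

import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

token_key = os.environ.get("HUGGING_FACE_HUB_TOKEN")
has_cuda = torch.cuda.is_available()

# Use float16 only when a GPU is present; CPUs generally need float32.
m = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-7b",
    use_auth_token=token_key,
    torch_dtype=torch.float16 if has_cuda else torch.float32,
)
if has_cuda:
    m = m.cuda()
tok = AutoTokenizer.from_pretrained(
    "stabilityai/stablelm-tuned-alpha-7b", use_auth_token=token_key
)

# device=0 selects the first GPU; device=-1 keeps the pipeline on CPU.
generator = pipeline(
    'text-generation', model=m, tokenizer=tok,
    device=0 if has_cuda else -1,
)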
@@ -29,8 +34,12 @@ class StopOnTokens(StoppingCriteria):
 
 def contrastive_generate(text, bad_text):
     with torch.no_grad():
-        tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
-        bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+        if torch.cuda.is_available():
+            tokens = tok(text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+            bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'].cuda()[:,:4096-1024]
+        else:
+            tokens = tok(text, return_tensors="pt")['input_ids'][:,:4096-1024]
+            bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'][:,:4096-1024]
         history = None
         bad_history = None
         curr_output = list()
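Since the two branches differ only in the .cuda() calls, one way to avoid duplicating the tokenizer calls is to tokenize once and move the tensors afterwards; a sketch, assuming the model has been placed on the same device:

device = "cuda" if torch.cuda.is_available() else "cpu"
# Truncate the prompt to leave 1024 tokens of headroom in the 4096-token context.
tokens = tok(text, return_tensors="pt")['input_ids'][:, :4096-1024].to(device)
bad_tokens = tok(bad_text, return_tensors="pt")['input_ids'][:, :4096-1024].to(device)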
@@ -83,12 +92,13 @@ def system_update(msg):
 
 
 with gr.Blocks() as demo:
+    gr.Markdown("### StableLM-tuned-Alpha-7B Chat")
     with gr.Row():
         with gr.Column():
             chatbot = gr.Chatbot([])
             clear = gr.Button("Clear")
         with gr.Column():
-            system_msg = gr.Textbox(start_message, label="System Message", interactive=True)
+            system_msg = start_message  # gr.Textbox(start_message, label="System Message", interactive=True)
             msg = gr.Textbox(label="Chat Message")
 
     msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
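One consequence of the last hunk: system_msg is now a plain string instead of a gr.Textbox, so the system_update(msg) helper visible in the hunk header presumably loses the component that fed it, and the system message is effectively frozen at start_message. If the old wiring should stay intact while hiding the field from users, a one-line alternative sketch (visible is a standard Gradio component argument; this is not what the commit does):

# Keep the Textbox (and any system_update wiring) but hide it from the UI.
system_msg = gr.Textbox(start_message, label="System Message",
                        interactive=True, visible=False)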