Update app.py
app.py
CHANGED
@@ -3,7 +3,9 @@ import gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
 
-# …
+# -------------------------------------------------------------------
+# 1. Logging helpers
+# -------------------------------------------------------------------
 os.environ["HF_HOME"] = "/data/.huggingface"
 LOG_FILE = "/data/requests.log"
 
@@ -19,26 +21,34 @@ def log(msg: str):
         pass
 
 
-# …
+# -------------------------------------------------------------------
+# 2. Configuration
+# -------------------------------------------------------------------
 MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
-MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = …
+MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
 
 SYSTEM_MSG = (
-    "You are **SchoolSpirit …
+    "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
     "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
-    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to …
-    "…
-    "…
+    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
+    "K‑12 schools.\n\n"
+    "GUIDELINES:\n"
+    "• Warm, encouraging tone for students, parents, staff.\n"
+    "• Replies ≤ 4 sentences unless asked for detail.\n"
+    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
     "• No personal‑data collection or sensitive advice.\n"
-    "• …
-    "• Avoid profanity, politics, or mature themes."
+    "• No profanity, politics, or mature themes."
 )
-WELCOME_MSG = "Welcome to SchoolSpirit …
+WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
 
-strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
+def strip(s: str) -> str:
+    return re.sub(r"\s+", " ", s.strip())
 
-…
+
+# -------------------------------------------------------------------
+# 3. Load model (GPU FP‑16 → CPU fallback)
+# -------------------------------------------------------------------
 hf_logging.set_verbosity_error()
 try:
     log("Loading tokenizer …")
@@ -62,7 +72,6 @@ try:
         max_new_tokens=MAX_TOKENS,
         do_sample=True,
         temperature=0.6,
-        pad_token_id=tok.eos_token_id,
     )
     MODEL_ERR = None
     log("Model loaded ✔")
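The hunk above only changes the generation settings; dropping `pad_token_id=tok.eos_token_id` is presumably safe because recent `transformers` releases fall back to `eos_token_id` when `pad_token_id` is unset for open-ended generation. The code that actually builds `tok`, `model`, and `gen` (new lines 55-71) is collapsed in this view. Below is a minimal sketch of the "GPU FP‑16 → CPU fallback" pattern that the section comment names, reusing the `tok`/`model`/`gen` names from the rest of the file; the commit's real body may differ.

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    try:
        # Assumed pattern: try half-precision weights on the GPU first …
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID, torch_dtype=torch.float16, device_map="auto"
        )
    except Exception:
        # … then fall back to a plain full-precision CPU load.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

    gen = pipeline(
        "text-generation",
        model=model,
        tokenizer=tok,
        max_new_tokens=MAX_TOKENS,  # 64, per the config hunk
        do_sample=True,
        temperature=0.6,
    )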
@@ -71,30 +80,41 @@ except Exception as exc:  # noqa: BLE001
     log(MODEL_ERR)
 
 
-# …
-def chat_fn(user_msg: str, history: list[dict]):
+# -------------------------------------------------------------------
+# 4. Chat callback
+# -------------------------------------------------------------------
+def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
     """
-    history …
+    history: list of (user, assistant) tuples (Gradio default)
+    state  : dict carrying system_prompt + raw_history for the model
+    Returns updated history (for UI) and state (for next round)
     """
     if MODEL_ERR:
-        return history + […
+        return history + [(user_msg, MODEL_ERR)], state
 
     user_msg = strip(user_msg or "")
     if not user_msg:
-        return history + […
+        return history + [(user_msg, "Please type something.")], state
     if len(user_msg) > MAX_INPUT_CH:
         warn = f"Message too long (>{MAX_INPUT_CH} chars)."
-        return history + […
-
-    # …
-    …
-    # …
-    convo = [m for m in …
-    …
+        return history + [(user_msg, warn)], state
+
+    # ------------------------------------------------ Prompt assembly
+    raw_hist = state.get("raw", [])
+    raw_hist.append({"role": "user", "content": user_msg})
+    # keep system + last N exchanges
+    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
+    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
+
+    prompt = "\n".join(
+        [
+            m["content"]
+            if m["role"] == "system"
+            else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
+            for m in raw_hist
+        ]
+        + ["AI:"]
+    )
 
     try:
         raw = gen(prompt)[0]["generated_text"]
@@ -104,22 +124,28 @@ def chat_fn(user_msg: str, history: list[dict]):
         log("❌ Inference error:\n" + traceback.format_exc())
         reply = "Sorry—backend crashed. Please try again later."
 
-    …
-        ],
-    )
-    …
-    )
+    # ------------------------------------------------ Update state + UI history
+    raw_hist.append({"role": "assistant", "content": reply})
+    state["raw"] = raw_hist
+    history.append((user_msg, reply))
+    return history, state
+
+
+# -------------------------------------------------------------------
+# 5. Launch
+# -------------------------------------------------------------------
+with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
+    chatbot = gr.Chatbot(
+        value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
+    )
+    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
+    with gr.Row():
+        txt = gr.Textbox(
+            scale=4, placeholder="Type your question here...", show_label=False
+        )
+        send = gr.Button("Send", variant="primary")
+
+    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+
+demo.launch()
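The rewritten `chat_fn` threads a `state` dict through Gradio rather than re-deriving model context from the visible chat, and the prompt keeps the system message plus only the last `MAX_TURNS` user/assistant exchanges. A self-contained replay of that windowing and formatting logic, copied from the expressions in the hunk above (the dummy messages are illustrative only):

    MAX_TURNS = 4
    SYSTEM_MSG = "You are SchoolSpirit AI."  # abbreviated stand-in

    # Ten alternating turns; only the last MAX_TURNS * 2 = 8 should survive.
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + [
        {"role": "user" if i % 2 == 0 else "assistant", "content": f"msg {i}"}
        for i in range(10)
    ]

    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo

    prompt = "\n".join(
        [
            m["content"]
            if m["role"] == "system"
            else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
            for m in raw_hist
        ]
        + ["AI:"]
    )
    print(prompt)
    # You are SchoolSpirit AI.
    # User: msg 2
    # AI: msg 3
    # …
    # AI: msg 9
    # AI: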
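One step the diff leaves collapsed (new lines 121-123) is how `reply` is recovered from `raw`: a `text-generation` pipeline returns the prompt together with the completion by default (`return_full_text=True`), so the callback presumably slices off the echoed prompt and stops at the next speaker tag. A hedged sketch of that step, not the commit's exact lines:

    # raw == prompt + completion (return_full_text defaults to True), so keep
    # only the newly generated tail …
    reply = strip(raw[len(prompt):])
    # … and cut anything past the point where the model starts a new "User:" turn.
    reply = reply.split("User:", 1)[0].strip()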