phanerozoic committed on
Commit
94cf8c8
·
verified ·
1 Parent(s): e8212fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -48
app.py CHANGED
@@ -3,7 +3,9 @@ import gradio as gr
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from transformers.utils import logging as hf_logging
5
 
6
- # ---------- Logging ---------------------------------------------------------
 
 
7
  os.environ["HF_HOME"] = "/data/.huggingface"
8
  LOG_FILE = "/data/requests.log"
9
 
@@ -19,26 +21,34 @@ def log(msg: str):
19
  pass
20
 
21
 
22
- # ---------- Config ----------------------------------------------------------
 
 
23
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
24
- MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 6, 128, 400
25
 
26
  SYSTEM_MSG = (
27
- "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
28
  "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
29
- "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to schools.\n\n"
30
- "Guidelines:\n"
31
- "• Warm, concise answers (max 4 sentences).\n"
 
 
 
32
  "• No personal‑data collection or sensitive advice.\n"
33
- "• If unsure, say so and suggest a human follow‑up.\n"
34
- "• Avoid profanity, politics, or mature themes."
35
  )
36
- WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
37
 
38
- strip = lambda s: re.sub(r"\s+", " ", s.strip())
39
 
 
 
40
 
41
- # ---------- Load model (GPU FP‑16 → CPU fallback) ---------------------------
 
 
 
42
  hf_logging.set_verbosity_error()
43
  try:
44
  log("Loading tokenizer …")
@@ -62,7 +72,6 @@ try:
62
  max_new_tokens=MAX_TOKENS,
63
  do_sample=True,
64
  temperature=0.6,
65
- pad_token_id=tok.eos_token_id,
66
  )
67
  MODEL_ERR = None
68
  log("Model loaded ✔")
@@ -71,30 +80,41 @@ except Exception as exc: # noqa: BLE001
71
  log(MODEL_ERR)
72
 
73
 
74
- # ---------- Chat callback ---------------------------------------------------
75
- def chat_fn(user_msg: str, history: list[dict]):
 
 
76
  """
77
- history comes in/out as list[{'role':'user'|'assistant','content':str}, …]
 
 
78
  """
79
  if MODEL_ERR:
80
- return history + [{"role": "assistant", "content": MODEL_ERR}]
81
 
82
  user_msg = strip(user_msg or "")
83
  if not user_msg:
84
- return history + [{"role": "assistant", "content": "Please type something."}]
85
  if len(user_msg) > MAX_INPUT_CH:
86
  warn = f"Message too long (>{MAX_INPUT_CH} chars)."
87
- return history + [{"role": "assistant", "content": warn}]
88
-
89
- # Append user to history
90
- history.append({"role": "user", "content": user_msg})
91
-
92
- # Keep system + last N messages
93
- convo = [m for m in history if m["role"] != "system"][-MAX_TURNS * 2 :]
94
- prompt_parts = [SYSTEM_MSG] + [
95
- f"{'User' if m['role']=='user' else 'AI'}: {m['content']}" for m in convo
96
- ] + ["AI:"]
97
- prompt = "\n".join(prompt_parts)
 
 
 
 
 
 
 
98
 
99
  try:
100
  raw = gen(prompt)[0]["generated_text"]
@@ -104,22 +124,28 @@ def chat_fn(user_msg: str, history: list[dict]):
104
  log("❌ Inference error:\n" + traceback.format_exc())
105
  reply = "Sorry—backend crashed. Please try again later."
106
 
107
- history.append({"role": "assistant", "content": reply})
108
- return history
109
-
110
-
111
- # ---------- Launch ----------------------------------------------------------
112
- gr.ChatInterface(
113
- fn=chat_fn,
114
- chatbot=gr.Chatbot(
115
- height=480,
116
- type="messages",
117
- value=[
118
- {"role": "assistant", "content": WELCOME_MSG}
119
- ], # ONE welcome bubble
120
- ),
121
- additional_inputs=None,
122
- title="SchoolSpirit AI Chat",
123
- theme=gr.themes.Soft(primary_hue="blue"),
124
- examples=None,
125
- ).launch()
 
 
 
 
 
 
 
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from transformers.utils import logging as hf_logging
5
 
6
+ # -------------------------------------------------------------------
7
+ # 1. Logging helpers
8
+ # -------------------------------------------------------------------
9
  os.environ["HF_HOME"] = "/data/.huggingface"
10
  LOG_FILE = "/data/requests.log"
11
 
 
21
  pass
22
 
23
 
24
+ # -------------------------------------------------------------------
25
+ # 2. Configuration
26
+ # -------------------------------------------------------------------
27
  MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
28
+ MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
29
 
30
  SYSTEM_MSG = (
31
+ "You are **SchoolSpiritAI**, the digital mascot for SchoolSpiritAILLC, "
32
  "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
33
+ "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
34
+ "K‑12 schools.\n\n"
35
+ "GUIDELINES:\n"
36
+ "• Warm, encouraging tone for students, parents, staff.\n"
37
+ "• Replies ≤ 4 sentences unless asked for detail.\n"
38
+ "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
39
  "• No personal‑data collection or sensitive advice.\n"
40
+ "• No profanity, politics, or mature themes."
 
41
  )
42
+ WELCOME_MSG = "Welcome to SchoolSpiritAI! Do you have any questions?"
43
 
 
44
 
45
def strip(s: str) -> str:
    """Trim *s* and squeeze every internal run of whitespace to one space."""
    collapsed = re.sub(r"\s+", " ", s)
    return collapsed.strip()
47
 
48
+
49
+ # -------------------------------------------------------------------
50
+ # 3. Load model (GPU FP‑16 → CPU fallback)
51
+ # -------------------------------------------------------------------
52
  hf_logging.set_verbosity_error()
53
  try:
54
  log("Loading tokenizer …")
 
72
  max_new_tokens=MAX_TOKENS,
73
  do_sample=True,
74
  temperature=0.6,
 
75
  )
76
  MODEL_ERR = None
77
  log("Model loaded ✔")
 
80
  log(MODEL_ERR)
81
 
82
 
83
+ # -------------------------------------------------------------------
84
+ # 4. Chat callback
85
+ # -------------------------------------------------------------------
86
+ def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
87
  """
88
+ history: list of (user, assistant) tuples (Gradio default)
89
+ state : dict carrying system_prompt + raw_history for the model
90
+ Returns updated history (for UI) and state (for next round)
91
  """
92
  if MODEL_ERR:
93
+ return history + [(user_msg, MODEL_ERR)], state
94
 
95
  user_msg = strip(user_msg or "")
96
  if not user_msg:
97
+ return history + [(user_msg, "Please type something.")], state
98
  if len(user_msg) > MAX_INPUT_CH:
99
  warn = f"Message too long (>{MAX_INPUT_CH} chars)."
100
+ return history + [(user_msg, warn)], state
101
+
102
+ # ------------------------------------------------ Prompt assembly
103
+ raw_hist = state.get("raw", [])
104
+ raw_hist.append({"role": "user", "content": user_msg})
105
+ # keep system + last N exchanges
106
+ convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
107
+ raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
108
+
109
+ prompt = "\n".join(
110
+ [
111
+ m["content"]
112
+ if m["role"] == "system"
113
+ else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
114
+ for m in raw_hist
115
+ ]
116
+ + ["AI:"]
117
+ )
118
 
119
  try:
120
  raw = gen(prompt)[0]["generated_text"]
 
124
  log("❌ Inference error:\n" + traceback.format_exc())
125
  reply = "Sorry—backend crashed. Please try again later."
126
 
127
+ # ------------------------------------------------ Update state + UI history
128
+ raw_hist.append({"role": "assistant", "content": reply})
129
+ state["raw"] = raw_hist
130
+ history.append((user_msg, reply))
131
+ return history, state
132
+
133
+
134
# -------------------------------------------------------------------
# 5. Launch
# -------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    # One welcome bubble; ("", msg) renders as an assistant-only message.
    chatbot = gr.Chatbot(
        value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
    )
    # Model-side conversation state: system prompt + raw role/content turns,
    # carried separately from the (user, assistant) tuples the UI shows.
    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
    with gr.Row():
        txt = gr.Textbox(
            scale=4, placeholder="Type your question here...", show_label=False
        )
        send = gr.Button("Send", variant="primary")

    # Fix: clear the textbox after each message is dispatched — the original
    # left the just-sent text sitting in the input, forcing the user to
    # delete it manually before typing the next question.
    send.click(
        chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]
    ).then(lambda: "", None, txt)
    txt.submit(
        chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state]
    ).then(lambda: "", None, txt)

demo.launch()