phanerozoic committed (verified)
Commit 61ca5d6 · Parent: 502a6b6

Update app.py

Files changed (1):
app.py  +75 -76
app.py CHANGED
@@ -1,13 +1,10 @@
-import os, re, time, datetime, traceback, torch
-import gradio as gr
+import os, re, time, datetime, traceback, torch, gradio as gr
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from transformers.utils import logging as hf_logging
 
-# -------------------------------------------------------------------
-# 1. Logging helpers
-# -------------------------------------------------------------------
+# ───────────────── logging ─────────────────────────────────────────
 os.environ["HF_HOME"] = "/data/.huggingface"
-LOG_FILE = "/data/requests.log"
+LOG = "/data/requests.log"
 
 
 def log(msg: str):
@@ -15,47 +12,39 @@ def log(msg: str):
     line = f"[{ts}] {msg}"
     print(line, flush=True)
     try:
-        with open(LOG_FILE, "a") as f:
+        with open(LOG, "a") as f:
             f.write(line + "\n")
     except FileNotFoundError:
         pass
 
 
-# -------------------------------------------------------------------
-# 2. Configuration
-# -------------------------------------------------------------------
+# ───────────────── config ──────────────────────────────────────────
 MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
-MAX_TURNS, MAX_TOKENS, MAX_INPUT_CH = 4, 64, 300
+MAX_PAIRS = 4  # user/assistant pairs to keep
+MAX_TOKENS = 128
+MAX_INPUT_CH = 300
 
 SYSTEM_MSG = (
     "You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, "
     "founded by Charles Norton in 2025. The company installs on‑prem AI chat "
-    "mascots, offers custom fine‑tuning, and ships turnkey GPU hardware to "
-    "K‑12 schools.\n\n"
-    "GUIDELINES:\n"
-    "• Warm, encouraging tone for students, parents, staff.\n"
-    "• Replies ≤ 4 sentences unless asked for detail.\n"
-    "• If unsure/out‑of‑scope: say so and suggest human follow‑up.\n"
-    "• No personal‑data collection or sensitive advice.\n"
-    "• No profanity, politics, or mature themes."
+    "mascots, offers custom fine‑tuning, and supplies GPU servers to K‑12 schools.\n\n"
+    "RULES:\n"
+    "• Friendly, concise (≤ 4 sentences) unless user wants detail.\n"
+    "• If unsure or out of scope, say so and suggest human follow‑up.\n"
+    "• No personal‑data collection, no medical/legal/financial advice.\n"
+    "• Avoid profanity, politics, and mature themes."
 )
 WELCOME_MSG = "Welcome to SchoolSpirit AI! Do you have any questions?"
 
-
-def strip(s: str) -> str:
-    return re.sub(r"\s+", " ", s.strip())
+strip = lambda s: re.sub(r"\s+", " ", s.strip())
 
-
-# -------------------------------------------------------------------
-# 3. Load model (GPU FP‑16 → CPU fallback)
-# -------------------------------------------------------------------
+
+# ───────────────── model load (GPU fp16 → CPU) ─────────────────────
 hf_logging.set_verbosity_error()
 try:
-    log("Loading tokenizer …")
     tok = AutoTokenizer.from_pretrained(MODEL_ID)
-
     if torch.cuda.is_available():
-        log("GPU detected → FP‑16")
+        log("GPU detected → FP16")
         model = AutoModelForCausalLM.from_pretrained(
             MODEL_ID, device_map="auto", torch_dtype=torch.float16
         )
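Note between hunks: the CPU fallback and the pipeline construction are unchanged, so the diff collapses them (old lines 62-70, new lines 51-59). For orientation, here is a minimal sketch of what that hidden region presumably contains, inferred from the "(GPU fp16 → CPU)" banner and the visible pipeline kwargs; the else branch and its log message are assumptions, not the committed code:

# Sketch only: reconstructs collapsed diff context, not shown above.
if torch.cuda.is_available():
    log("GPU detected → FP16")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map="auto", torch_dtype=torch.float16
    )
else:
    log("No GPU → loading on CPU")  # assumed message
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)  # default fp32 on CPU

gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tok,
    max_new_tokens=MAX_TOKENS,  # raised from 64 to 128 in this commit
    do_sample=True,
    temperature=0.65,  # was 0.6
)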
@@ -71,81 +60,91 @@ try:
         tokenizer=tok,
         max_new_tokens=MAX_TOKENS,
         do_sample=True,
-        temperature=0.6,
+        temperature=0.65,
     )
     MODEL_ERR = None
-    log("Model loaded ✔")
 except Exception as exc:  # noqa: BLE001
     MODEL_ERR, gen = f"Model load error: {exc}", None
     log(MODEL_ERR)
 
 
-# -------------------------------------------------------------------
-# 4. Chat callback
-# -------------------------------------------------------------------
-def chat_fn(user_msg: str, history: list[tuple[str, str]], state: dict):
+# ───────────────── helper ──────────────────────────────────────────
+def build_prompt(msgs):
+    """Granite likes ### markers"""
+    lines = [f"### System:\n{SYSTEM_MSG}"]
+    for m in msgs:
+        if m["role"] == "user":
+            lines.append(f"### User:\n{m['content']}")
+        elif m["role"] == "assistant":
+            lines.append(f"### Assistant:\n{m['content']}")
+    lines.append("### Assistant:")
+    return "\n".join(lines)
+
+
+def trim(msgs):
+    """Keep system + last MAX_PAIRS*2 messages"""
+    convo = [m for m in msgs if m["role"] != "system"]
+    return [{"role": "system", "content": SYSTEM_MSG}] + convo[-MAX_PAIRS * 2 :]
+
+
+# ───────────────── chat callback ───────────────────────────────────
+def chat_fn(user_msg, history, state):
     """
-    history: list of (user, assistant) tuples (Gradio default)
-    state  : dict carrying system_prompt + raw_history for the model
-    Returns updated history (for UI) and state (for next round)
+    user_msg : str
+    history  : list[dict] for UI (assistant & user only)
+    state    : {"msgs": full_message_history_with_system}
     """
     if MODEL_ERR:
-        return history + [(user_msg, MODEL_ERR)], state
+        history.append({"role": "assistant", "content": MODEL_ERR})
+        return history, state
 
     user_msg = strip(user_msg or "")
     if not user_msg:
-        return history + [(user_msg, "Please type something.")], state
+        history.append({"role": "assistant", "content": "Please type something."})
+        return history, state
     if len(user_msg) > MAX_INPUT_CH:
-        warn = f"Message too long (>{MAX_INPUT_CH} chars)."
-        return history + [(user_msg, warn)], state
-
-    # ------------------------------------------------ Prompt assembly
-    raw_hist = state.get("raw", [])
-    raw_hist.append({"role": "user", "content": user_msg})
-    # keep system + last N exchanges
-    convo = [m for m in raw_hist if m["role"] != "system"][-MAX_TURNS * 2 :]
-    raw_hist = [{"role": "system", "content": SYSTEM_MSG}] + convo
-
-    prompt = "\n".join(
-        [
-            m["content"]
-            if m["role"] == "system"
-            else f'{"User" if m["role"]=="user" else "AI"}: {m["content"]}'
-            for m in raw_hist
-        ]
-        + ["AI:"]
-    )
+        history.append(
+            {
+                "role": "assistant",
+                "content": f"Message too long (>{MAX_INPUT_CH} characters).",
+            }
+        )
+        return history, state
+
+    # Update raw history
+    state["msgs"].append({"role": "user", "content": user_msg})
+    state["msgs"] = trim(state["msgs"])
 
+    prompt = build_prompt(state["msgs"])
     try:
         raw = gen(prompt)[0]["generated_text"]
-        reply = strip(raw.split("AI:", 1)[-1])
-        reply = re.split(r"\b(?:User:|AI:)", reply, 1)[0].strip()
+        reply = strip(raw.split("### Assistant:", 1)[-1])
     except Exception:
         log("❌ Inference error:\n" + traceback.format_exc())
         reply = "Sorry—backend crashed. Please try again later."
 
-    # ------------------------------------------------ Update state + UI history
-    raw_hist.append({"role": "assistant", "content": reply})
-    state["raw"] = raw_hist
-    history.append((user_msg, reply))
+    # Append to histories
+    state["msgs"].append({"role": "assistant", "content": reply})
+    history.append({"role": "assistant", "content": reply})
     return history, state
 
 
-# -------------------------------------------------------------------
-# 5. Launch
-# -------------------------------------------------------------------
+# ───────────────── UI ──────────────────────────────────────────────
 with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
     chatbot = gr.Chatbot(
-        value=[("", WELCOME_MSG)], height=480, label="SchoolSpirit AI"
+        value=[WELCOME_MSG],
+        label="SchoolSpirit AI",
+        height=480,
+        type="messages",
     )
-    state = gr.State({"raw": [{"role": "system", "content": SYSTEM_MSG}]})
-    with gr.Row():
-        txt = gr.Textbox(
-            scale=4, placeholder="Type your question here...", show_label=False
-        )
-        send = gr.Button("Send", variant="primary")
+    txt = gr.Textbox(
+        placeholder="Type your question here…",
+        show_label=False,
+        container=False,
+    )
+    state = gr.State({"msgs": [{"role": "system", "content": SYSTEM_MSG}]})
 
-    send.click(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
-    txt.submit(chat_fn, inputs=[txt, chatbot, state], outputs=[chatbot, state])
+    txt.submit(chat_fn, [txt, chatbot, state], [chatbot, state])
+    txt.submit(lambda _: "", None, txt)  # clear textbox
 
 demo.launch()
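The new build_prompt/trim pair replaces the inline "User:/AI:" prompt assembly. For reference, this is what build_prompt renders for a fresh session after one user turn (the question is made up for illustration); the model completes after the trailing "### Assistant:" marker:

# Illustration only; assumes build_prompt and SYSTEM_MSG from app.py above.
msgs = [
    {"role": "system", "content": SYSTEM_MSG},
    {"role": "user", "content": "What does SchoolSpirit AI do?"},  # hypothetical turn
]
print(build_prompt(msgs))
# ### System:
# You are **SchoolSpirit AI**, the digital mascot for SchoolSpirit AI LLC, ...
# ### User:
# What does SchoolSpirit AI do?
# ### Assistant:

One caveat, untested here: the text-generation pipeline returns prompt plus completion by default, and raw.split("### Assistant:", 1)[-1] anchors on the first marker. From the second turn onward the prompt already contains earlier "### Assistant:" blocks, so raw.rsplit("### Assistant:", 1)[-1], which anchors on the last marker, looks like the safer extraction.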
 
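Two details in the new UI block may deserve a follow-up commit; I have not run this Space, so treat these as suggestions rather than fixes. With type="messages", Gradio expects each chat entry to be a {"role": ..., "content": ...} dict (or a gr.ChatMessage), so seeding value=[WELCOME_MSG] with a bare string may not render as intended, and chat_fn appends only the assistant reply to history, so the user's own bubble never appears. Also, with inputs=None Gradio calls the handler with no arguments, which the one-argument lambda _: "" would not accept. A sketch of the adjusted wiring, reusing the names from app.py:

# Sketch only; assumes gr, chat_fn, WELCOME_MSG, SYSTEM_MSG defined as above.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    chatbot = gr.Chatbot(
        value=[{"role": "assistant", "content": WELCOME_MSG}],  # dict, not bare str
        label="SchoolSpirit AI",
        height=480,
        type="messages",
    )
    txt = gr.Textbox(
        placeholder="Type your question here…",
        show_label=False,
        container=False,
    )
    state = gr.State({"msgs": [{"role": "system", "content": SYSTEM_MSG}]})

    txt.submit(chat_fn, [txt, chatbot, state], [chatbot, state])
    txt.submit(lambda: "", None, txt)  # zero-arg fn matches inputs=None

demo.launch()

Inside chat_fn, a matching one-liner before generation, history.append({"role": "user", "content": user_msg}), would echo the user's message back to the UI.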