# SchoolSpiritAI / app.py
import os

# HF_HOME must be set before transformers/huggingface_hub are imported,
# otherwise the cache location has already been resolved at import time.
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, threading, traceback, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging
LOG_FILE = "/data/requests.log"
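# Minimal logger: always prints to stdout; also appends to the persistent
# /data volume so request logs survive Space restarts when it is mounted.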
def log(m):
    # Timezone-aware timestamp (datetime.utcnow() is deprecated in Python 3.12+).
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {m}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:  # missing or read-only /data volume; stdout logging still works
        pass
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6  # prompt token budget, reply token cap, sampling temperature
MAX_IN, RATE_N, RATE_T = 300, 5, 60  # max input chars; RATE_N requests per RATE_T seconds per IP
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to [email protected].\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)
WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?"
strip = lambda s: re.sub(r"\s+", " ", s.strip())  # collapse runs of whitespace
hf_logging.set_verbosity_error()  # keep transformers warnings out of the Space logs
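# Load the tokenizer and model once at startup. On failure the app stays up
# and MODEL_ERR is surfaced to users instead of crashing the Space.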
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
        low_cpu_mem_usage=True,
    )
    MODEL_ERR = None
    log("Model loaded")
except Exception as e:
    MODEL_ERR = f"Model load error: {e}"
    log(MODEL_ERR + "\n" + traceback.format_exc())
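# Naive per-IP sliding-window rate limiter. State is in-process only (resets
# on restart), and concurrent requests may race on VISITS; that is acceptable
# for a soft limit on a single-worker Space.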
VISITS = {}
def allowed(ip):
    now = time.time()
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True
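# Render the conversation as a plain "User:/AI:" transcript, dropping the
# oldest exchange until the encoded prompt fits within the CTX_TOK budget.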
def build_prompt(raw):
    def render(m):
        if m["role"] == "system":
            return m["content"]
        return f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}"

    sys_msg, convo = raw[0], raw[1:]
    while True:
        parts = [sys_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
            return "\n".join(parts)
        convo = convo[2:]  # drop the oldest user/assistant pair and retry
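# Streaming chat handler. Gradio renders each yield from this generator as an
# incremental UI update: validation failures yield once and stop; otherwise
# generated tokens are streamed into the newest chat row as they arrive.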
def chat_fn(user_msg, hist, state, request: gr.Request):
    ip = request.client.host if request else "anon"
    # Because this function is a generator, bare `return value` never reaches
    # the UI; each early exit must yield its update first.
    if not allowed(ip):
        hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
        yield hist, state, ""
        return
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield hist, state, ""
        return
    if len(user_msg) > MAX_IN:
        hist.append((user_msg, f"Input >{MAX_IN} chars."))
        yield hist, state, ""
        return
    if MODEL_ERR:
        hist.append((user_msg, MODEL_ERR))
        yield hist, state, ""
        return
    hist.append((user_msg, ""))
    state["raw"].append({"role": "user", "content": user_msg})
    prompt = build_prompt(state["raw"])
    ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=ids,
            max_new_tokens=MAX_NEW,
            do_sample=True,  # without this, temperature is ignored and decoding is greedy
            temperature=TEMP,
            streamer=streamer,
        ),
        daemon=True,
    ).start()
    partial = ""
    for piece in streamer:
        partial += piece
        # Cut the reply short if the model starts inventing the next turn.
        if "User:" in partial or "\nAI:" in partial:
            partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
            break
        hist[-1] = (user_msg, partial)
        yield hist, state, ""
    reply = strip(partial)
    hist[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield hist, state, ""
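# UI wiring: the Chatbot widget holds the visible (user, bot) tuples, while
# gr.State carries the raw role-tagged history that build_prompt consumes.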
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    bot = gr.Chatbot(value=[("", WELCOME)], height=480)
    st = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        send = gr.Button("Send", variant="primary")
    send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

demo.launch()