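# SchoolSpirit AI chat Space: streams replies from a Granite-3.3 instruct model
# via Gradio, with per-IP rate limiting, prompt trimming, and request logging.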
# Point the HF cache at the Space's persistent /data volume; this must be set
# before transformers is imported, or the default cache location is used.
import os
os.environ["HF_HOME"] = "/data/.huggingface"

import re, time, datetime, threading, traceback, torch, gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging

LOG_FILE = "/data/requests.log"
def log(m):
    # UTC timestamp with millisecond precision (utcnow() is deprecated in 3.12+).
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {m}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:  # /data may be missing locally; stdout logging still works
        pass
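# Model and guardrail settings.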
MODEL_ID = "ibm-granite/granite-3.3-2b-instruct" | |
CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6 | |
MAX_IN, RATE_N, RATE_T = 300, 5, 60 | |
SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to [email protected].\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)
WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?" | |
strip = lambda s: re.sub(r"\s+", " ", s.strip()) | |
hf_logging.set_verbosity_error() | |
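# Load tokenizer and model once at startup. On failure the Space stays up and
# MODEL_ERR is echoed back to users instead of raising.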
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
        low_cpu_mem_usage=True,
    )
    MODEL_ERR = None
    log("Model loaded")
except Exception as e:
    MODEL_ERR = f"Model load error: {e}"
    log(MODEL_ERR + "\n" + traceback.format_exc())
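# In-memory sliding-window rate limiter keyed by client IP. State is per-process
# and resets whenever the Space restarts.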
VISITS = {}

def allowed(ip):
    now = time.time()
    # Keep only timestamps inside the current window, then check the quota.
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True
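# Render history as a plain "User:/AI:" transcript (not the model's chat
# template) and drop the oldest user/assistant pair until it fits in CTX_TOK tokens.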
def build_prompt(raw):
    def render(m):
        if m["role"] == "system":
            return m["content"]
        return f"{'User:' if m['role'] == 'user' else 'AI:'} {m['content']}"

    sys_msg, convo = raw[0], raw[1:]  # avoid shadowing the stdlib `sys` name
    while True:
        parts = [sys_msg["content"]] + [render(m) for m in convo] + ["AI:"]
        prompt = "\n".join(parts)
        if len(tok.encode(prompt, add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
            return prompt
        convo = convo[2:]
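# Main chat handler. This is a generator: Gradio streams every yielded
# (history, state, textbox) triple to the browser as it arrives.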
def chat_fn(user_msg, hist, state, request: gr.Request):
    # Inside a generator, a bare `return value` updates nothing in Gradio, so
    # every early exit must yield its final UI state first.
    ip = request.client.host if request else "anon"
    if not allowed(ip):
        hist.append((user_msg, "Rate limit exceeded — please wait a minute."))
        yield hist, state, ""
        return
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield hist, state, ""
        return
    if len(user_msg) > MAX_IN:
        hist.append((user_msg, f"Input >{MAX_IN} chars."))
        yield hist, state, ""
        return
    if MODEL_ERR:
        hist.append((user_msg, MODEL_ERR))
        yield hist, state, ""
        return
hist.append((user_msg, "")) | |
state["raw"].append({"role": "user", "content": user_msg}) | |
prompt = build_prompt(state["raw"]) | |
ids = tok(prompt, return_tensors="pt").to(model.device).input_ids | |
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True) | |
threading.Thread( | |
target=model.generate, | |
kwargs=dict(input_ids=ids, max_new_tokens=MAX_NEW, temperature=TEMP, streamer=streamer), | |
).start() | |
partial = "" | |
for piece in streamer: | |
partial += piece | |
if "User:" in partial or "\nAI:" in partial: | |
partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip() | |
break | |
hist[-1] = (user_msg, partial) | |
yield hist, state, "" | |
reply = strip(partial) | |
hist[-1] = (user_msg, reply) | |
state["raw"].append({"role": "assistant", "content": reply}) | |
yield hist, state, "" | |
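# Gradio UI: chat pane, per-session prompt state, and a single-line input row.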
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    bot = gr.Chatbot(value=[("", WELCOME)], height=480)
    # Per-session state: the raw message list chat_fn uses to rebuild the prompt.
    st = gr.State(
        {
            "raw": [
                {"role": "system", "content": SYSTEM_MSG},
                {"role": "assistant", "content": WELCOME},
            ]
        }
    )
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        send = gr.Button("Send", variant="primary")
    # The Send button and Enter in the textbox both invoke the streaming handler.
    send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

demo.launch()