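# app.py: SchoolSpirit AI chat demo for a Hugging Face Space.
# Streams replies from ibm-granite/granite-3.3-2b-instruct behind a Gradio UI,
# with per-IP rate limiting and a rolling prompt trimmed to fit the model's
# context budget.
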
import os

# HF_HOME must be set before transformers/huggingface_hub are imported;
# both read the cache location at import time.
os.environ["HF_HOME"] = "/data/.huggingface"

import re
import time
import datetime
import threading
import traceback

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from transformers.utils import logging as hf_logging

LOG_FILE = "/data/requests.log"
def log(m):
    """Write a timestamped line to stdout and, best-effort, to the log file."""
    ts = datetime.datetime.now(datetime.timezone.utc).strftime("%H:%M:%S.%f")[:-3]
    line = f"[{ts}] {m}"
    print(line, flush=True)
    try:
        with open(LOG_FILE, "a") as f:
            f.write(line + "\n")
    except OSError:
        # /data may not exist outside the Space; logging must never crash the app.
        pass

MODEL_ID = "ibm-granite/granite-3.3-2b-instruct"
CTX_TOK, MAX_NEW, TEMP = 1800, 64, 0.6  # prompt token budget, reply token cap, sampling temperature
MAX_IN, RATE_N, RATE_T = 300, 5, 60     # max input chars; RATE_N requests allowed per RATE_T seconds

SYSTEM_MSG = (
    "You are **SchoolSpirit AI**, the friendly digital mascot of "
    "SchoolSpirit AI LLC, founded by Charles Norton in 2025. "
    "The company installs on‑prem AI chat mascots, fine‑tunes language models, "
    "and ships turnkey GPU servers to K‑12 schools.\n\n"
    "RULES:\n"
    "• Reply in ≤ 4 sentences unless asked for detail.\n"
    "• No personal‑data collection; no medical/legal/financial advice.\n"
    "• If uncertain, say so and suggest contacting a human.\n"
    "• If you can’t answer, politely direct the user to [email protected].\n"
    "• Keep language age‑appropriate; avoid profanity, politics, mature themes."
)
WELCOME = "Hi there! I’m SchoolSpirit AI. How can I help?"

def strip(s):
    """Collapse internal whitespace runs to single spaces and trim the ends."""
    return re.sub(r"\s+", " ", s.strip())

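# Load tokenizer and model once at startup; a failure is recorded in MODEL_ERR
# and shown to users instead of crashing the Space.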
hf_logging.set_verbosity_error()
try:
    tok = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        device_map="auto" if torch.cuda.is_available() else "cpu",
        torch_dtype=torch.float16 if torch.cuda.is_available() else "auto",
        low_cpu_mem_usage=True,
    )
    MODEL_ERR = None
    log("Model loaded")
except Exception as e:
    MODEL_ERR = f"Model load error: {e}"
    log(MODEL_ERR + "\n" + traceback.format_exc())

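# Naive per-IP sliding window: allow at most RATE_N requests per RATE_T
# seconds. In-memory and not thread-safe, which is fine for a small demo.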
VISITS = {}
def allowed(ip):
    now = time.time()
    VISITS[ip] = [t for t in VISITS.get(ip, []) if now - t < RATE_T]
    if len(VISITS[ip]) >= RATE_N:
        return False
    VISITS[ip].append(now)
    return True

def build_prompt(raw):
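    """Render the message list as plain text, dropping the oldest user/AI
    exchange until the encoded prompt fits within CTX_TOK tokens."""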
    def render(m):
        if m["role"] == "system":
            return m["content"]
        return f"{'User:' if m['role']=='user' else 'AI:'} {m['content']}"
    sys, convo = raw[0], raw[1:]
    while True:
        parts = [sys["content"]] + [render(m) for m in convo] + ["AI:"]
        if len(tok.encode("\n".join(parts), add_special_tokens=False)) <= CTX_TOK or len(convo) <= 2:
            return "\n".join(parts)
        convo = convo[2:]

def chat_fn(user_msg, hist, state, request: gr.Request):
    # chat_fn is a generator (it streams tokens below), so every early exit
    # must *yield* its outputs; a plain `return value` inside a generator
    # never reaches Gradio.
    ip = request.client.host if request else "anon"
    if not allowed(ip):
        hist.append((user_msg, "Rate limit exceeded. Please wait a minute."))
        yield hist, state, ""
        return
    user_msg = strip(user_msg or "")
    if not user_msg:
        yield hist, state, ""
        return
    if len(user_msg) > MAX_IN:
        hist.append((user_msg, f"Input is limited to {MAX_IN} characters."))
        yield hist, state, ""
        return
    if MODEL_ERR:
        hist.append((user_msg, MODEL_ERR))
        yield hist, state, ""
        return

    hist.append((user_msg, ""))
    state["raw"].append({"role": "user", "content": user_msg})

    prompt = build_prompt(state["raw"])
    ids = tok(prompt, return_tensors="pt").to(model.device).input_ids
    streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)
    threading.Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=ids,
            max_new_tokens=MAX_NEW,
            do_sample=True,  # TEMP has no effect under greedy decoding
            temperature=TEMP,
            streamer=streamer,
        ),
        daemon=True,  # never block interpreter shutdown on a stuck generation
    ).start()

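    # Stream tokens into the last history entry. Generation is cut off early
    # if the model starts inventing the next "User:"/"AI:" turn itself.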
    partial = ""
    for piece in streamer:
        partial += piece
        if "User:" in partial or "\nAI:" in partial:
            partial = re.split(r"(?:\n?User:|\n?AI:)", partial)[0].strip()
            break
        hist[-1] = (user_msg, partial)
        yield hist, state, ""

    reply = strip(partial)
    hist[-1] = (user_msg, reply)
    state["raw"].append({"role": "assistant", "content": reply})
    yield hist, state, ""

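# Minimal Blocks UI: chatbot pane, one-line textbox, and a send button, all
# wired to the same streaming handler.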
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("### SchoolSpirit AI Chat")
    bot = gr.Chatbot(value=[(None, WELCOME)], height=480)  # None hides the empty user bubble
    st = gr.State({
        "raw": [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "assistant", "content": WELCOME},
        ]
    })
    with gr.Row():
        txt = gr.Textbox(placeholder="Type your question here…", show_label=False, lines=1, scale=4)
        send = gr.Button("Send", variant="primary")
    send.click(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])
    txt.submit(chat_fn, inputs=[txt, bot, st], outputs=[bot, st, txt])

demo.queue().launch()  # queue() is required for streaming generators on older Gradio releases