File size: 4,521 Bytes
53d6350
b36fdc1
83362fe
8ca6b5f
83362fe
 
b36fdc1
83362fe
 
526f9f1
83362fe
 
 
 
5d4d3e2
83362fe
 
8ca6b5f
53d6350
 
5d4d3e2
53d6350
9bd2927
 
 
 
 
 
 
 
490ca90
5d4d3e2
83362fe
5d4d3e2
cbefa1f
83362fe
 
 
 
 
5d4d3e2
 
83362fe
5d4d3e2
8ca6b5f
 
 
 
 
 
 
 
53d6350
8ca6b5f
 
f8f0b32
8ca6b5f
53d6350
8ca6b5f
 
 
 
 
 
 
 
83362fe
 
 
 
 
 
 
 
 
53d6350
 
83362fe
 
 
 
 
 
53d6350
5d4d3e2
83362fe
 
 
 
5d4d3e2
83362fe
 
 
 
 
 
 
 
 
 
 
 
 
5d4d3e2
83362fe
 
 
 
7a0e378
8ca6b5f
83362fe
 
b36fdc1
8ca6b5f
83362fe
 
53d6350
83362fe
 
 
 
 
 
8ca6b5f
83362fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b36fdc1
 
8ca6b5f
2f49d9a
f8f0b32
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
# 0. Install custom transformers and imports
import os
import subprocess
import sys


def _pip_install(requirement: str) -> None:
    """Best-effort install of *requirement* into the running interpreter.

    Invokes pip as ``<this interpreter> -m pip`` so the package lands in the
    environment that will actually import it, rather than whichever ``pip``
    happens to be first on PATH. A failed install is reported on stderr but
    does not raise, so start-up continues exactly as it did before.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", requirement],
        check=False,
    )
    if result.returncode != 0:
        print(
            f"Warning: 'pip install {requirement}' exited with status "
            f"{result.returncode}",
            file=sys.stderr,
        )


# These must run before the `transformers` / `docx` imports further down.
_pip_install("git+https://github.com/shumingma/transformers.git")
_pip_install("python-docx")

import threading
import torch
import torch._dynamo
# Tell TorchDynamo to suppress compilation errors instead of raising them; per
# the PyTorch documentation the affected code then runs without compilation.
# NOTE(review): presumably needed by the custom transformers build — confirm.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces
from docx import Document

# 1. System prompt
# Default system message for the assistant. The text deliberately starts and
# ends with a newline; callers apply .strip() where they need it trimmed.
SYSTEM_PROMPT = (
    "\n"
    "You are a friendly café assistant for Café Eleven. Your job is to:\n"
    "1. Greet the customer warmly.\n"
    "2. Help them order food and drinks from our menu.\n"
    "3. Ask the customer for their desired pickup time.\n"
    "4. Confirm the pickup time before ending the conversation.\n"
    "5. Answer questions about ingredients, preparation, etc.\n"
    "6. Handle special requests (allergies, modifications) politely.\n"
    "7. Provide calorie information if asked.\n"
    "Always be polite, helpful, and ensure the customer feels welcomed and cared for!\n"
)

# Checkpoint identifier handed to both from_pretrained() calls below.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# 2. Load tokenizer and model
# Both objects are created once at import time and used as module globals by
# respond() for every chat request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # request bfloat16 weights
    device_map="auto"  # leave device placement to the library
)

# Log the device the loaded model reports, as a start-up sanity check.
print(f"Model loaded on device: {model.device}")

# 3. Load Menu Text from Word document
def load_menu_text(docx_path):
    """Read a Word document and return its non-blank paragraphs as text.

    Args:
        docx_path: Location of the .docx file, passed straight to ``Document``.

    Returns:
        The whitespace-stripped text of every paragraph that is not empty
        after stripping, joined with single newlines in document order.
    """
    stripped_paragraphs = (
        paragraph.text.strip() for paragraph in Document(docx_path).paragraphs
    )
    return "\n".join(text for text in stripped_paragraphs if text)

# The menu is read once at import time; retrieve_context() searches this text.
MENU_TEXT = load_menu_text("menu.docx")
print("Loaded menu text from Word document.")

# 4. Simple retrieval function (search inside MENU_TEXT)
def retrieve_context(question, top_k=3, menu_text=None):
    """Return the menu lines that best match the words in *question*.

    Matching is a case-insensitive substring test of each query word against
    each menu line. Punctuation in the question is treated as a separator, so
    "pastries?" still matches a line containing "pastries". Lines are ranked
    by how many distinct query words they contain; ties keep menu order.

    Args:
        question: Free-text customer question.
        top_k: Maximum number of menu lines to return.
        menu_text: Newline-separated menu text to search. Defaults to the
            module-level ``MENU_TEXT`` when omitted.

    Returns:
        Up to ``top_k`` matching lines joined by blank lines, or a fixed
        apology string when the question has no words or nothing matches.
    """
    no_match = "Sorry, I couldn't find relevant menu information."
    if menu_text is None:
        menu_text = MENU_TEXT

    # Replace every non-alphanumeric character with a space before splitting,
    # so trailing "?" / "," / "-" never become part of a search word.
    cleaned = "".join(ch if ch.isalnum() else " " for ch in question.lower())
    # dict.fromkeys removes repeated words while keeping first-seen order.
    words = list(dict.fromkeys(cleaned.split()))
    if not words:
        return no_match

    scored = []
    for line in menu_text.split("\n"):
        lowered = line.lower()
        hits = sum(word in lowered for word in words)
        if hits:
            scored.append((hits, line))
    if not scored:
        return no_match

    # list.sort is stable (also with reverse=True), so equally scored lines
    # stay in their original menu order.
    scored.sort(key=lambda pair: pair[0], reverse=True)
    return "\n\n".join(line for _, line in scored[:top_k])

# 5. Chat respond function
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream the assistant's reply to *message* as it is generated.

    Args:
        message: The customer's latest message.
        history: Earlier turns, iterated as ``(user_text, assistant_text)``
            pairs; empty entries in a pair are skipped.
        system_message: Content of the leading system message.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Yields:
        The reply accumulated so far, once per chunk produced by the streamer.
    """
    menu_snippets = retrieve_context(message)

    # Rebuild the whole conversation in chat-template form.
    conversation = [{"role": "system", "content": system_message}]
    for past_user, past_bot in history:
        for role, text in (("user", past_user), ("assistant", past_bot)):
            if text:
                conversation.append({"role": role, "content": text})
    # The retrieved menu lines ride along inside the final user turn.
    conversation.append(
        {
            "role": "user",
            "content": f"{message}\n\nRelevant menu info:\n{menu_snippets}",
        }
    )

    chat_prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    encoded = tokenizer(chat_prompt, return_tensors="pt").to(model.device)

    token_stream = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_args = {
        **encoded,
        "streamer": token_stream,
        "max_new_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": True,
    }

    # generate() runs on a worker thread so this generator can consume the
    # streamer while text is still being produced.
    worker = threading.Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    partial_reply = ""
    for chunk in token_stream:
        partial_reply += chunk
        yield partial_reply

# 6. Gradio ChatInterface
# Values shared by the example rows and by the initial state of the controls.
_default_system_message = SYSTEM_PROMPT.strip()
_example_questions = (
    "What kinds of burgers do you have?",
    "Do you have gluten-free pastries?",
)

demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant based on real menu loaded from Word document!",
    # Each example row is: message, system message, max new tokens,
    # temperature, top-p — the same order as the additional inputs below.
    examples=[
        [question, _default_system_message, 512, 0.7, 0.95]
        for question in _example_questions
    ],
    # Extra controls, listed in the order respond() receives them after
    # (message, history).
    additional_inputs=[
        gr.Textbox(label="System message", value=_default_system_message),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# 7. Launch
# Start the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    # share=True is forwarded to Gradio's launch(); per the Gradio docs it
    # requests a publicly reachable share link.
    demo.launch(share=True)