# 0. Install the custom transformers fork (BitNet support) and other dependencies
import os
os.system("pip install git+https://github.com/shumingma/transformers.git")
os.system("pip install sentence-transformers")
os.system("pip install PyPDF2")  # used below to read the menu PDFs

import threading
import torch
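# Suppress torch.compile (dynamo) errors so generation falls back to eager
# execution if the custom BitNet kernels are not compile-compatible.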
import torch._dynamo
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces

from pathlib import Path

# 1. System prompt
SYSTEM_PROMPT = """
You are a friendly café assistant for Café Eleven. Your job is to:
1. Greet the customer warmly.
2. Help them order food and drinks from our menu.
3. Ask the customer for their desired pickup time.
4. Confirm the pickup time before ending the conversation.
5. Answer questions about ingredients, preparation, etc.
6. Handle special requests (allergies, modifications) politely.
7. Provide calorie information if asked.
Always be polite, helpful, and ensure the customer feels welcomed and cared for!
"""

MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# 2. Load tokenizer and model (bfloat16 weights; device_map="auto" uses the GPU when available)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

print(f"Model loaded on device: {model.device}")

# 3. Load PDF files and create simple document store
from PyPDF2 import PdfReader

# Read all PDFs in a folder into a list of small text chunks.
# Note: splitting on blank lines is naive; PDF extraction often yields only
# single newlines, so a page may end up as one large chunk.
def load_pdfs(folder_path="."):
    docs = []
    for pdf_file in Path(folder_path).glob("*.pdf"):
        reader = PdfReader(str(pdf_file))
        for page in reader.pages:
            text = page.extract_text()
            if text:
                for para in text.split("\n\n"):
                    if len(para.strip()) > 20:  # keep only meaningful text
                        docs.append(para.strip())
    return docs

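# Menu PDFs are expected alongside this script (the Space's working directory).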
document_chunks = load_pdfs(".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")

# 4. Create embeddings once, at startup
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # fast, small embedding model
doc_embeddings = embedder.encode(document_chunks, normalize_embeddings=True)

# 5. Retrieval function: cosine similarity via dot product of normalized embeddings
doc_embeddings_tensor = torch.tensor(doc_embeddings)  # convert once, not per query

def retrieve_context(question, top_k=3):
    if not document_chunks:
        return ""  # no PDFs found; let the model answer without extra context
    question_embedding = torch.tensor(embedder.encode(question, normalize_embeddings=True))
    scores = doc_embeddings_tensor @ question_embedding
    top_indices = torch.topk(scores, k=min(top_k, len(scores))).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)
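
# Quick smoke test (hypothetical query; uncomment to try locally):
# print(retrieve_context("Do you have any vegan desserts?"))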

# 6. Chat respond function
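# @spaces.GPU asks Hugging Face (ZeroGPU) Spaces to attach a GPU while this
# function runs; it has no effect when running locally.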
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
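    # Retrieve menu passages relevant to this message; they are appended to
    # the latest user turn so the model can ground its answer.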
    context = retrieve_context(message)
    
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

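    # Stream the reply: model.generate runs on a background thread while
    # TextIteratorStreamer yields partial text for the UI to display.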
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

# 7. Gradio UI
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
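    # Each example row is [message, system_message, max_tokens, temperature, top_p],
    # matching the additional_inputs declared below.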
    examples=[
        [
            "What kinds of burgers do you have?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
        [
            "Do you have any gluten-free pastries?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(
            value=SYSTEM_PROMPT.strip(),
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
)

# 8. Launch
if __name__ == "__main__":
    demo.launch()