File size: 5,752 Bytes
53d6350
b36fdc1
83362fe
53d6350
83362fe
 
b36fdc1
83362fe
 
526f9f1
83362fe
 
 
 
5d4d3e2
53d6350
83362fe
 
f3f8525
0f9d9fe
53d6350
f8f0b32
53d6350
 
5d4d3e2
53d6350
9bd2927
 
 
 
 
 
 
 
490ca90
5d4d3e2
83362fe
5d4d3e2
cbefa1f
83362fe
 
 
 
 
5d4d3e2
 
83362fe
5d4d3e2
f8f0b32
53d6350
 
f8f0b32
53d6350
f3f8525
 
 
 
 
 
 
 
 
f8f0b32
f3f8525
 
 
f8f0b32
f3f8525
 
 
 
 
f8f0b32
f3f8525
 
 
f8f0b32
53d6350
 
f8f0b32
53d6350
 
 
 
 
 
 
f8f0b32
53d6350
 
f8f0b32
 
 
53d6350
 
 
 
83362fe
 
 
 
 
 
 
 
 
53d6350
 
83362fe
 
 
 
 
 
53d6350
5d4d3e2
83362fe
 
 
 
5d4d3e2
83362fe
 
 
 
 
 
 
 
 
 
 
 
 
5d4d3e2
83362fe
 
 
 
7a0e378
f8f0b32
83362fe
 
b36fdc1
53d6350
83362fe
 
53d6350
83362fe
 
 
 
 
 
53d6350
83362fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b36fdc1
 
53d6350
2f49d9a
f8f0b32
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# 0. Install custom transformers and imports
import os
import subprocess
import sys

# Install with the running interpreter's pip (``python -m pip``) so packages
# land in the same environment this script imports from. check_call raises
# CalledProcessError on a failed install instead of silently continuing the
# way an ignored os.system exit status did.
subprocess.check_call(
    [
        sys.executable,
        "-m",
        "pip",
        "install",
        "git+https://github.com/shumingma/transformers.git",
    ]
)
subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])

import threading
import torch
import torch._dynamo
# Let TorchDynamo compilation failures fall back to eager execution rather
# than raising.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces
import pdfplumber

from pathlib import Path
from PyPDF2 import PdfReader

# 1. System prompt
# Default instructions for the assistant. The UI pre-fills its
# "System message" box with this text after stripping the outer newlines.
SYSTEM_PROMPT = (
    "\n"
    "You are a friendly café assistant for Café Eleven. Your job is to:\n"
    "1. Greet the customer warmly.\n"
    "2. Help them order food and drinks from our menu.\n"
    "3. Ask the customer for their desired pickup time.\n"
    "4. Confirm the pickup time before ending the conversation.\n"
    "5. Answer questions about ingredients, preparation, etc.\n"
    "6. Handle special requests (allergies, modifications) politely.\n"
    "7. Provide calorie information if asked.\n"
    "Always be polite, helpful, and ensure the customer feels welcomed and cared for!\n"
)

# Hugging Face Hub id of the causal LM loaded below.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# 2. Load tokenizer and model
# Weights are requested in bfloat16; device_map="auto" hands placement of the
# weights over to transformers.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
)
print(f"Model loaded on device: {model.device}")

# 3. Load PDF files 
def _split_into_sections(lines):
    """Group text lines into chunks that each start at a heading line.

    Blank (whitespace-only) lines are dropped and every line is stripped. A
    line is a heading when ``str.isupper()`` is true for it and it has at most
    six whitespace-separated words. Each following non-heading line is
    appended to the current chunk with " | ". Lines that appear before the
    first heading form a chunk of their own.
    """
    sections = []
    current = None
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            continue
        if line.isupper() and len(line.split()) <= 6:
            # A new heading closes the chunk built so far.
            if current:
                sections.append(current)
            current = line
        elif current:
            current += f" | {line}"
        else:
            current = line
    if current:
        sections.append(current)
    return sections


def load_pdfs(folder_path="."):
    """Extract heading-delimited text chunks from every ``*.pdf`` in a folder.

    Files are processed in sorted path order, so the chunk list (and therefore
    retrieval tie-breaking) is the same on every run and filesystem. A section
    may continue across page breaks within one file, but never across files.
    Pages with no extractable text are skipped. The search is not recursive.

    Args:
        folder_path: Directory to scan; a missing directory yields no chunks.

    Returns:
        A list of chunk strings, ordered by file and then by position in it.
    """
    docs = []
    for pdf_file in sorted(Path(folder_path).glob("*.pdf")):
        lines = []
        with pdfplumber.open(str(pdf_file)) as pdf:
            for page in pdf.pages:
                text = page.extract_text()
                if text:
                    lines.extend(text.split("\n"))
        docs.extend(_split_into_sections(lines))
    return docs


# Chunk every PDF in the working directory once, when this module runs.
document_chunks = load_pdfs(folder_path=".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")

# 4. Create embeddings
# Embeddings are unit-normalized, so a plain dot product with a normalized
# query embedding equals cosine similarity.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")  # Fast small model
doc_embeddings = embedder.encode(document_chunks, normalize_embeddings=True)

# 5. Retrieval function with float32 fix
def retrieve_context(question, top_k=3):
    """Return the ``top_k`` PDF chunks most similar to ``question``.

    The question is embedded with the same normalized encoder as the chunks.
    Each chunk is scored by the dot product with the question (cosine
    similarity for unit vectors). The best chunks are returned joined by a
    blank line, highest score first.

    Args:
        question: User text to match against the indexed menu chunks.
        top_k: Maximum number of chunks to return; values <= 0 yield "".

    Returns:
        The selected chunks as one string, or "" when nothing was indexed.
    """
    # With no chunks indexed (e.g. no PDFs found), doc_embeddings lacks an
    # embedding dimension and the matmul below would raise a size mismatch.
    if not document_chunks or top_k <= 0:
        return ""

    question_embedding = torch.as_tensor(
        embedder.encode(question, normalize_embeddings=True), dtype=torch.float32
    )
    # as_tensor reuses the NumPy buffer when it is already float32 on CPU,
    # instead of copying the whole corpus matrix on every query; the dtype
    # argument still casts if it is not.
    doc_embeds = torch.as_tensor(doc_embeddings, dtype=torch.float32)
    scores = doc_embeds @ question_embedding
    top_indices = torch.topk(scores, k=min(top_k, scores.shape[0])).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)

# 6. Chat respond function
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts both history shapes a ``gr.ChatInterface`` can pass:
    ``[user_text, assistant_text]`` pairs ("tuples") and
    ``{"role": ..., "content": ...}`` dicts ("messages").

    Empty turns are skipped. In the dict format, only user/assistant turns
    with non-empty string content are kept.
    """
    messages = []
    for turn in history or ():
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


@spaces.GPU
def respond(
    message: str,
    history: list,
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream an assistant reply to ``message``, grounded in retrieved menu text.

    The top menu chunks from ``retrieve_context(message)`` are appended to the
    current user turn. Prior turns come from ``history``, in either Gradio
    tuples or messages format. Generation runs in a background thread and is
    consumed through a ``TextIteratorStreamer``.

    Args:
        message: The new user message.
        history: Previous conversation turns supplied by Gradio.
        system_message: System prompt placed first in the conversation.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The cumulative response text after each newly streamed chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        here once the stream has been closed.
    """
    context = retrieve_context(message)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    # Retrieved menu snippets ride along on the current user turn only.
    messages.append(
        {"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"}
    )

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): if the chat template already emits a BOS token, tokenizing
    # the rendered prompt with default special tokens may add a second one -
    # confirm against the model's template.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),  # in case the slider delivers a float
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    errors = []

    def _generate():
        # If generate() raises, it never signals end-of-stream, and the loop
        # below would block forever on the streamer (no timeout is set).
        # Record the error and close the stream so the caller can surface it.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    thread.join()
    if errors:
        raise errors[0]

# 7. Gradio ChatInterface
# Each example row is the chat message followed by one value per entry in
# additional_inputs, in declaration order: system message, max new tokens,
# temperature, top-p.
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    examples=[
        [question, SYSTEM_PROMPT.strip(), 512, 0.7, 0.95]
        for question in (
            "What kinds of burgers do you have?",
            "Do you have any gluten-free pastries?",
        )
    ],
    additional_inputs=[
        gr.Textbox(label="System message", value=SYSTEM_PROMPT.strip()),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=512),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)

# 8. Launch
if __name__ == "__main__":
    # share=True additionally asks Gradio for a public share link.
    demo.launch(share=True)