Cafe-Chatbot / app.py
Copain22's picture
Update app.py
53d6350 verified
raw
history blame
5 kB
# 0. Install custom transformers and imports
import os
import subprocess
import sys

# Packages installed at startup: the custom transformers fork (presumably
# required for the BitNet model architecture — confirm) and
# sentence-transformers for the retrieval embeddings.
_STARTUP_REQUIREMENTS = (
    "git+https://github.com/shumingma/transformers.git",
    "sentence-transformers",
)

# Invoke pip through the running interpreter (``sys.executable -m pip``) with an
# argument list. Unlike ``os.system("pip install ...")`` this cannot pick up a
# ``pip`` that belongs to a different Python on PATH, and no shell is involved.
# Installation stays best-effort (check=False), as before, but a failure is now
# reported on stderr instead of being silently ignored.
for _requirement in _STARTUP_REQUIREMENTS:
    _result = subprocess.run(
        [sys.executable, "-m", "pip", "install", _requirement],
        check=False,
    )
    if _result.returncode != 0:
        print(
            f"WARNING: 'pip install {_requirement}' exited with code "
            f"{_result.returncode}",
            file=sys.stderr,
        )
import threading
import torch
import torch._dynamo
# Make TorchDynamo fall back to eager execution instead of raising when graph
# compilation fails. NOTE(review): presumably a workaround for the custom
# BitNet model code — confirm it is still needed.
torch._dynamo.config.suppress_errors = True
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
# Hugging Face Spaces helper; provides the @spaces.GPU decorator applied to respond().
import spaces
from pathlib import Path
# 1. System prompt: the default instructions for the assistant. Its stripped
# text pre-fills the "System message" textbox and the example rows of the chat
# UI; respond() receives whatever that textbox holds as `system_message`.
SYSTEM_PROMPT = """
You are a friendly café assistant for Café Eleven. Your job is to:
1. Greet the customer warmly.
2. Help them order food and drinks from our menu.
3. Ask the customer for their desired pickup time.
4. Confirm the pickup time before ending the conversation.
5. Answer questions about ingredients, preparation, etc.
6. Handle special requests (allergies, modifications) politely.
7. Provide calorie information if asked.
Always be polite, helpful, and ensure the customer feels welcomed and cared for!
"""
# Hugging Face Hub id of the causal language model used for generation.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# 2. Load tokenizer and model. Weights are loaded in bfloat16, and
# device_map="auto" lets transformers decide where to place them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(f"Model loaded on device: {model.device}")
# 3. Load PDF files and create simple document store
from PyPDF2 import PdfReader
# Read all PDFs into a list of small chunks
def load_pdfs(folder_path="."):
docs = []
for pdf_file in Path(folder_path).glob("*.pdf"):
reader = PdfReader(str(pdf_file))
for page in reader.pages:
text = page.extract_text()
if text:
for para in text.split("\n\n"):
if len(para.strip()) > 20: # keep meaningful text
docs.append(para.strip())
return docs
# Build the in-memory document store from every PDF in the working directory.
document_chunks = load_pdfs(folder_path=".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")
# 4. Create embeddings. normalize_embeddings=True yields unit-length vectors,
# so a plain dot product between them is their cosine similarity.
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # Fast small model
doc_embeddings = embedder.encode(
    document_chunks,
    normalize_embeddings=True,
)
# 5. Retrieval function
def retrieve_context(question, top_k=3):
    """Return the ``top_k`` PDF chunks most similar to *question*.

    Similarity is the dot product between the normalised question embedding
    and the normalised chunk embeddings in ``doc_embeddings`` (i.e. cosine
    similarity). The chosen chunks are joined with blank lines, best match
    first.

    Returns an empty string when there are no chunks to search or ``top_k``
    is not positive. Previously an empty document store made the matrix
    product fail on every request.
    """
    if not document_chunks or top_k <= 0:
        return ""
    question_embedding = embedder.encode(question, normalize_embeddings=True)
    # torch.as_tensor shares memory with NumPy inputs, so the whole embedding
    # matrix is no longer copied on every query as torch.tensor() did.
    scores = torch.as_tensor(doc_embeddings) @ torch.as_tensor(question_embedding)
    k = min(top_k, scores.numel())
    top_indices = torch.topk(scores, k=k).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)
# 6. Chat respond function
@spaces.GPU
def respond(
message: str,
history: list[tuple[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
context = retrieve_context(message)
messages = [{"role": "system", "content": system_message}]
for user_msg, bot_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if bot_msg:
messages.append({"role": "assistant", "content": bot_msg})
messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(
tokenizer, skip_prompt=True, skip_special_tokens=True
)
generate_kwargs = dict(
**inputs,
streamer=streamer,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
do_sample=True,
)
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()
response = ""
for new_text in streamer:
response += new_text
yield response
# 7. Gradio UI
# Defaults shared by the example rows and the additional-input widgets.
_DEFAULT_SYSTEM_MESSAGE = SYSTEM_PROMPT.strip()
_DEFAULT_MAX_TOKENS = 512
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_TOP_P = 0.95
_EXAMPLE_QUESTIONS = (
    "What kinds of burgers do you have?",
    "Do you have any gluten-free pastries?",
)

# Chat UI. Each example row is the question followed by one value per
# additional input, in the same order as `additional_inputs`; respond()
# receives those inputs as system_message, max_tokens, temperature, top_p.
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    examples=[
        [
            question,
            _DEFAULT_SYSTEM_MESSAGE,
            _DEFAULT_MAX_TOKENS,
            _DEFAULT_TEMPERATURE,
            _DEFAULT_TOP_P,
        ]
        for question in _EXAMPLE_QUESTIONS
    ],
    additional_inputs=[
        gr.Textbox(value=_DEFAULT_SYSTEM_MESSAGE, label="System message"),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=_DEFAULT_MAX_TOKENS,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=_DEFAULT_TEMPERATURE,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=_DEFAULT_TOP_P,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)
# 8. Launch
def _main():
    """Start the Gradio server hosting the café chat UI."""
    demo.launch()


if __name__ == "__main__":
    _main()