Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,999 Bytes
53d6350 b36fdc1 83362fe 53d6350 83362fe b36fdc1 83362fe 526f9f1 83362fe 5d4d3e2 53d6350 83362fe 0f9d9fe 53d6350 5d4d3e2 53d6350 9bd2927 490ca90 5d4d3e2 83362fe 5d4d3e2 cbefa1f 83362fe 5d4d3e2 83362fe 5d4d3e2 53d6350 83362fe 53d6350 83362fe 53d6350 5d4d3e2 83362fe 5d4d3e2 83362fe 5d4d3e2 83362fe 7a0e378 53d6350 83362fe b36fdc1 53d6350 83362fe 53d6350 83362fe 53d6350 83362fe b36fdc1 53d6350 2f49d9a cbefa1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# 0. Install custom transformers and imports
import os
import subprocess
import sys


def _pip_install(*packages):
    """Install *packages* into the interpreter running this script.

    ``sys.executable -m pip`` guarantees the packages land in the same Python
    environment that will import them; a bare ``pip`` on PATH may belong to a
    different interpreter. A failed install is reported but not fatal, matching
    the previous behaviour of ignoring the exit status.
    """
    result = subprocess.run([sys.executable, "-m", "pip", "install", *packages])
    if result.returncode != 0:
        print(
            f"WARNING: pip install {' '.join(packages)} exited with code {result.returncode}",
            file=sys.stderr,
        )


# The BitNet model needs a forked transformers build, so it must be installed
# before `transformers` is imported below.
_pip_install("git+https://github.com/shumingma/transformers.git")
_pip_install("sentence-transformers")

import threading
import torch
import torch._dynamo
# If TorchDynamo fails to compile something, log it and fall back to eager
# execution instead of raising.
torch._dynamo.config.suppress_errors = True
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces
from pathlib import Path
# 1. System prompt
# Default system message for the chat model. The UI pre-fills the editable
# "System message" textbox and the example rows with SYSTEM_PROMPT.strip(),
# so the leading/trailing newlines of this literal are not sent to the model.
SYSTEM_PROMPT = """
You are a friendly café assistant for Café Eleven. Your job is to:
1. Greet the customer warmly.
2. Help them order food and drinks from our menu.
3. Ask the customer for their desired pickup time.
4. Confirm the pickup time before ending the conversation.
5. Answer questions about ingredients, preparation, etc.
6. Handle special requests (allergies, modifications) politely.
7. Provide calorie information if asked.
Always be polite, helpful, and ensure the customer feels welcomed and cared for!
"""
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"


# 2. Load tokenizer and model
def _load_tokenizer_and_model(model_id):
    """Return ``(tokenizer, model)`` for *model_id*.

    The tokenizer is loaded first, then the causal-LM weights in bfloat16.
    ``device_map="auto"`` leaves device placement to transformers/accelerate.
    """
    tok = AutoTokenizer.from_pretrained(model_id)
    mdl = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch.bfloat16, device_map="auto"
    )
    return tok, mdl


tokenizer, model = _load_tokenizer_and_model(MODEL_ID)
print(f"Model loaded on device: {model.device}")
# 3. Load PDF files and create simple document store
from PyPDF2 import PdfReader
# Read all PDFs into a list of small chunks
def load_pdfs(folder_path=".", min_chars=20):
    """Split every PDF in *folder_path* into paragraph-sized text chunks.

    Each page's extracted text is split on blank lines (``"\\n\\n"``). Every
    piece is whitespace-stripped and kept only if it is longer than
    *min_chars* characters, which drops page numbers and other fragments.

    PDFs are matched by a case-insensitive ``.pdf`` suffix, so ``Menu.PDF`` is
    found on case-sensitive filesystems too. Only regular files are read,
    because a directory named ``x.pdf`` would make ``PdfReader`` fail. Files are
    processed in sorted order so the chunk order is deterministic.

    Args:
        folder_path: Directory to scan (not recursive).
        min_chars: Chunks with ``len(chunk) <= min_chars`` are discarded.

    Returns:
        list[str]: Kept chunks in file, page, paragraph order. Returns ``[]``
        if *folder_path* is not an existing directory.
    """
    folder = Path(folder_path)
    if not folder.is_dir():
        return []
    pdf_files = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() == ".pdf"
    )
    docs = []
    for pdf_file in pdf_files:
        reader = PdfReader(str(pdf_file))
        for page in reader.pages:
            text = page.extract_text()
            if not text:
                continue
            for para in text.split("\n\n"):
                chunk = para.strip()
                if len(chunk) > min_chars:  # keep meaningful text
                    docs.append(chunk)
    return docs
# Build the in-memory document store from the PDFs in the working directory.
# NOTE(review): with no PDFs this list is empty; retrieval must cope with that.
document_chunks = load_pdfs(folder_path=".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")

# 4. Create embeddings
# all-MiniLM-L6-v2 is a small, fast sentence-embedding model.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
# Normalized to unit length, so a dot product with a normalized query
# embedding equals cosine similarity.
doc_embeddings = embedder.encode(sentences=document_chunks, normalize_embeddings=True)
# 5. Retrieval function
def retrieve_context(question, top_k=3):
    """Return the *top_k* document chunks most similar to *question*.

    The question is embedded with the same normalized encoder as the chunks, so
    the dot product against ``doc_embeddings`` is cosine similarity.
    ``torch.topk`` returns the best matches in descending score order. The
    chunks are joined with blank lines.

    Returns an empty string when there are no chunks (e.g. no PDFs were found)
    or when ``top_k <= 0``, instead of failing on an empty embedding matrix or
    a negative ``k``.
    """
    if not document_chunks or top_k <= 0:
        return ""
    question_embedding = embedder.encode(question, normalize_embeddings=True)
    # as_tensor shares memory with the NumPy arrays instead of copying the
    # whole embedding matrix on every query.
    scores = torch.as_tensor(doc_embeddings) @ torch.as_tensor(question_embedding)
    top_indices = torch.topk(scores, k=min(top_k, len(scores))).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)
# 6. Chat respond function
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template message dicts.

    Accepts ``(user, bot)`` pairs ("tuples" format) and
    ``{"role": ..., "content": ...}`` dicts ("messages" format). Empty turns
    are skipped. In the dict format, only user/assistant turns with string
    content are kept.
    """
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a sampled reply to *message*, grounded in retrieved PDF text.

    Args:
        message: The new user message.
        history: Prior turns, either as ``(user, bot)`` pairs or as
            role/content dicts.
        system_message: System prompt placed first in the conversation.
        max_tokens: ``max_new_tokens`` for generation.
        temperature: Sampling temperature (sampling is always enabled).
        top_p: Nucleus-sampling probability mass.

    Yields:
        str: The reply generated so far (cumulative), for Gradio streaming.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker thread
        is re-raised here, instead of leaving the stream waiting forever.
    """
    context = retrieve_context(message)
    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    # Retrieved chunks are attached only to the current turn. Skip the header
    # entirely when retrieval found nothing.
    user_content = f"{message}\n\nRelevant menu info:\n{context}" if context else message
    messages.append({"role": "user", "content": user_content})

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    errors = []

    def _generate():
        try:
            model.generate(**generate_kwargs)
        except Exception as err:
            errors.append(err)
            # Push the stop signal so the consumer loop below terminates;
            # otherwise iterating the streamer would block forever.
            streamer.end()

    # generate() blocks, so it runs in a worker thread while this generator
    # drains decoded text from the streamer.
    thread = threading.Thread(target=_generate)
    thread.start()
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
    thread.join()
    if errors:
        raise errors[0]
# 7. Gradio UI
# Default generation settings, shared by the example rows and the input widgets.
_DEFAULT_SYSTEM = SYSTEM_PROMPT.strip()
_DEFAULT_MAX_TOKENS = 512
_DEFAULT_TEMPERATURE = 0.7
_DEFAULT_TOP_P = 0.95

_EXAMPLE_QUESTIONS = (
    "What kinds of burgers do you have?",
    "Do you have any gluten-free pastries?",
)

demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    # Each example row is [message, *values for additional_inputs], in order.
    examples=[
        [question, _DEFAULT_SYSTEM, _DEFAULT_MAX_TOKENS, _DEFAULT_TEMPERATURE, _DEFAULT_TOP_P]
        for question in _EXAMPLE_QUESTIONS
    ],
    # Passed to respond() after (message, history), in this order.
    additional_inputs=[
        gr.Textbox(value=_DEFAULT_SYSTEM, label="System message"),
        gr.Slider(
            minimum=1, maximum=2048, value=_DEFAULT_MAX_TOKENS, step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1, maximum=4.0, value=_DEFAULT_TEMPERATURE, step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1, maximum=1.0, value=_DEFAULT_TOP_P, step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# 8. Launch
if __name__ == "__main__":
    demo.launch()
|