Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,752 Bytes
53d6350 b36fdc1 83362fe 53d6350 83362fe b36fdc1 83362fe 526f9f1 83362fe 5d4d3e2 53d6350 83362fe f3f8525 0f9d9fe 53d6350 f8f0b32 53d6350 5d4d3e2 53d6350 9bd2927 490ca90 5d4d3e2 83362fe 5d4d3e2 cbefa1f 83362fe 5d4d3e2 83362fe 5d4d3e2 f8f0b32 53d6350 f8f0b32 53d6350 f3f8525 f8f0b32 f3f8525 f8f0b32 f3f8525 f8f0b32 f3f8525 f8f0b32 53d6350 f8f0b32 53d6350 f8f0b32 53d6350 f8f0b32 53d6350 83362fe 53d6350 83362fe 53d6350 5d4d3e2 83362fe 5d4d3e2 83362fe 5d4d3e2 83362fe 7a0e378 f8f0b32 83362fe b36fdc1 53d6350 83362fe 53d6350 83362fe 53d6350 83362fe b36fdc1 53d6350 2f49d9a f8f0b32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
# 0. Install custom transformers and imports
import os
# Runtime installs run before the `transformers` / `sentence_transformers`
# imports further down — presumably so those imports resolve to the packages
# installed here (a fork of transformers from GitHub, plus sentence-transformers).
# NOTE(review): the os.system() return codes are not checked, so a failed
# install is not reported at this point — confirm best-effort is intended.
os.system("pip install git+https://github.com/shumingma/transformers.git")
os.system("pip install sentence-transformers")
import threading
import torch
import torch._dynamo
# Ask TorchDynamo to suppress its errors rather than raise them.
torch._dynamo.config.suppress_errors = True
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces
import pdfplumber
from pathlib import Path
# NOTE(review): PdfReader is not referenced in the visible part of this file;
# PDF text extraction below goes through pdfplumber.
from PyPDF2 import PdfReader
# 1. System prompt
# Default system message offered in the chat UI. The text starts and ends with
# a newline; the UI code strips it before use.
SYSTEM_PROMPT = (
    "\n"
    "You are a friendly café assistant for Café Eleven. Your job is to:\n"
    "1. Greet the customer warmly.\n"
    "2. Help them order food and drinks from our menu.\n"
    "3. Ask the customer for their desired pickup time.\n"
    "4. Confirm the pickup time before ending the conversation.\n"
    "5. Answer questions about ingredients, preparation, etc.\n"
    "6. Handle special requests (allergies, modifications) politely.\n"
    "7. Provide calorie information if asked.\n"
    "Always be polite, helpful, and ensure the customer feels welcomed and cared for!\n"
)

# Hub identifier of the causal language model loaded below.
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# 2. Load tokenizer and model
# Tokenizer and model both come from the same MODEL_ID checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Weights are requested in bfloat16; device placement is left to
# device_map="auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
)
# Log where the model ended up so the startup output shows the device in use.
print(f"Model loaded on device: {model.device}")
# 3. Load PDF files
def load_pdfs(folder_path="."):
    """Extract the text of every PDF in *folder_path* as section chunks.

    Each page's text is read with pdfplumber and split into lines. A line is
    treated as a section heading when ``str.isupper()`` is true for it and it
    has at most six whitespace-separated words; a heading closes the chunk in
    progress and starts a new one. Every other non-blank line is appended to
    the current chunk, separated by ``" | "``. A chunk may continue across the
    pages of one PDF but is always closed at the end of that PDF.

    Args:
        folder_path: Directory searched (non-recursively) for ``*.pdf`` files.

    Returns:
        list[str]: Chunks ordered by PDF path, then by reading order within
        each PDF. Empty when no PDF is found or none yields any text.
    """
    docs = []
    # Sorted so that chunk order (and therefore the row order of the embedding
    # matrix built from it) is reproducible; glob itself promises no ordering.
    for pdf_file in sorted(Path(folder_path).glob("*.pdf")):
        current_section = None  # chunk being accumulated for this PDF
        with pdfplumber.open(str(pdf_file)) as pdf:
            for page in pdf.pages:
                text = page.extract_text()
                if not text:
                    # Page produced no text; nothing to add.
                    continue
                for raw_line in text.split("\n"):
                    line = raw_line.strip()
                    if not line:
                        continue
                    if line.isupper() and len(line.split()) <= 6:
                        # Heading: flush the previous chunk, start a new one.
                        if current_section:
                            docs.append(current_section)
                        current_section = line
                    elif current_section:
                        current_section += f" | {line}"
                    else:
                        # Text before the first heading starts its own chunk.
                        current_section = line
        # Flush whatever is still open so it cannot leak into the next PDF.
        if current_section:
            docs.append(current_section)
    return docs
# Build the retrieval corpus once, at import time, from the PDFs found in the
# current working directory.
document_chunks = load_pdfs(".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")
# 4. Create embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2") # Fast small model
# Encode every chunk up front; retrieve_context() reads these vectors for each
# question. normalize_embeddings=True matches the setting used there for the
# question vector.
doc_embeddings = embedder.encode(document_chunks, normalize_embeddings=True)
# 5. Retrieval function with float32 fix
def retrieve_context(question, top_k=3):
    """Return the stored chunks most similar to *question*, best match first.

    The question is embedded with the module-level ``embedder`` and scored
    against ``doc_embeddings`` with a dot product; both sides are encoded with
    ``normalize_embeddings=True``.

    Args:
        question: Text to match against ``document_chunks``.
        top_k: Maximum number of chunks to return.

    Returns:
        str: The selected chunks joined by blank lines, or ``""`` when there
        are no chunks to search or ``top_k`` is not positive.
    """
    # Nothing to search / nothing requested: skip the similarity math instead
    # of running it on an empty corpus.
    if top_k <= 0 or not document_chunks:
        return ""

    # Both operands are cast to float32 so the matmul dtypes always agree.
    question_embedding = torch.as_tensor(
        embedder.encode(question, normalize_embeddings=True),
        dtype=torch.float32,
    )
    # as_tensor avoids re-copying the precomputed embeddings on every request
    # when they are already in a compatible float32 buffer.
    doc_embeds = torch.as_tensor(doc_embeddings, dtype=torch.float32)

    scores = doc_embeds @ question_embedding
    # Never ask topk for more entries than there are chunks.
    top_indices = torch.topk(scores, k=min(top_k, len(scores))).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)
# 6. Chat respond function
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream a reply to *message*, with retrieved menu text added as context.

    Builds a chat transcript (system message, prior turns, then the new
    message followed by the chunks from ``retrieve_context``), renders it with
    the tokenizer's chat template, and runs sampling-based generation on a
    worker thread so text can be yielded as it is produced.

    Args:
        message: Latest user message.
        history: Prior turns, either as ``(user, assistant)`` pairs or as
            ``{"role": ..., "content": ...}`` dicts. Empty entries are skipped.
        system_message: Content of the leading system turn.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``generate``.
        top_p: Nucleus-sampling threshold passed to ``generate``.

    Yields:
        str: The reply accumulated so far; each yield extends the previous one.
    """
    context = retrieve_context(message)

    messages = [{"role": "system", "content": system_message}]
    for turn in history:
        if isinstance(turn, dict):
            # Role/content style history: forward plain-text user/assistant
            # turns as they are.
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        # Pair style history: (user message, assistant message).
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append(
        {"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"}
    )

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    # generate() runs on a worker thread while this generator drains the
    # streamer and forwards the growing reply to the caller.
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    # The stream is exhausted, so generation is finishing; reap the worker
    # instead of leaving it unjoined.
    thread.join()
# 7. Gradio ChatInterface
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    # Each example row is the chat message followed by one value per
    # additional input, in the same order as `additional_inputs` below:
    # system message, max new tokens, temperature, top-p.
    examples=[
        [question, SYSTEM_PROMPT.strip(), 512, 0.7, 0.95]
        for question in (
            "What kinds of burgers do you have?",
            "Do you have any gluten-free pastries?",
        )
    ],
    # Extra controls; their values are passed to `respond` after the message
    # and history, in this order.
    additional_inputs=[
        gr.Textbox(label="System message", value=SYSTEM_PROMPT.strip()),
        gr.Slider(
            label="Max new tokens", minimum=1, maximum=2048, step=1, value=512
        ),
        gr.Slider(
            label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)
# 8. Launch
if __name__ == "__main__":
    # Start the UI only when this file is executed as a script, not when it
    # is imported. share=True is forwarded to Gradio's launch().
    demo.launch(share=True)
|