# 0. Install custom transformers and imports
import os
import subprocess
import sys


def _pip_install(*packages):
    """Best-effort ``pip install`` of *packages* into the running interpreter.

    Runs ``sys.executable -m pip`` rather than a bare ``pip`` so the packages
    land in the same environment this script imports from (a ``pip`` found on
    PATH may belong to a different Python). A failed install is reported but
    does not stop the script, matching the previous fire-and-forget behavior.
    """
    cmd = [sys.executable, "-m", "pip", "install", *packages]
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(f"WARNING: {' '.join(cmd)} failed with exit code {result.returncode}")


# Installed before the imports below so those imports resolve to the freshly
# installed packages.
_pip_install("git+https://github.com/shumingma/transformers.git")
_pip_install("sentence-transformers")

import threading

import torch
import torch._dynamo

# Don't let TorchDynamo compilation errors abort the app.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
from sentence_transformers import SentenceTransformer
import gradio as gr
import spaces
import pdfplumber
from pathlib import Path
from PyPDF2 import PdfReader  # NOTE(review): appears unused; PDFs are parsed with pdfplumber.

# 1. System prompt
SYSTEM_PROMPT = """
You are a friendly café assistant for Café Eleven.
Your job is to:
1. Greet the customer warmly.
2. Help them order food and drinks from our menu.
3. Ask the customer for their desired pickup time.
4. Confirm the pickup time before ending the conversation.
5. Answer questions about ingredients, preparation, etc.
6. Handle special requests (allergies, modifications) politely.
7. Provide calorie information if asked.
Always be polite, helpful, and ensure the customer feels welcomed and cared for!
"""

MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# 2. Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
print(f"Model loaded on device: {model.device}")
# 3. Load PDF files
def _iter_pdf_lines(pdf_file):
    """Yield the stripped, non-empty text lines of every page of *pdf_file*, in order."""
    with pdfplumber.open(str(pdf_file)) as pdf:
        for page in pdf.pages:
            text = page.extract_text()
            if not text:
                # Skip pages where extract_text() returned None or "".
                continue
            for raw_line in text.split("\n"):
                line = raw_line.strip()
                if line:
                    yield line


def load_pdfs(folder_path="."):
    """Split every ``*.pdf`` directly inside *folder_path* into section chunks.

    A line that is entirely upper-case and has at most six words is treated as
    a section heading: it closes the current chunk and starts a new one. Every
    other line is appended to the current chunk, joined with `` | ``. Chunks
    never span two files. Files are visited in sorted path order so the chunk
    list is reproducible across machines.

    Returns:
        list[str]: the section chunks, in file/page order.
    """
    docs = []
    for pdf_file in sorted(Path(folder_path).glob("*.pdf")):
        current_section = None
        for line in _iter_pdf_lines(pdf_file):
            if line.isupper() and len(line.split()) <= 6:
                if current_section:
                    docs.append(current_section)
                current_section = line
            elif current_section:
                current_section += f" | {line}"
            else:
                current_section = line
        if current_section:
            docs.append(current_section)
    return docs


document_chunks = load_pdfs(".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")

# 4. Create embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # Fast small model
doc_embeddings = embedder.encode(document_chunks, normalize_embeddings=True)
# Converted to a float32 tensor once here instead of on every query.
_doc_embeds = torch.tensor(doc_embeddings, dtype=torch.float32)


# 5. Retrieval function (float32 tensors for the similarity product)
def retrieve_context(question, top_k=3):
    """Return the ``top_k`` chunks most similar to *question*, joined by blank lines.

    Query and document embeddings are both L2-normalised, so the dot product is
    cosine similarity. Returns ``""`` when no PDF chunks were loaded (or
    ``top_k <= 0``) instead of failing on an empty embedding matrix.
    """
    if not document_chunks or top_k <= 0:
        return ""
    query = torch.tensor(
        embedder.encode(question, normalize_embeddings=True), dtype=torch.float32
    )
    scores = _doc_embeds @ query
    k = min(top_k, len(document_chunks))
    top_indices = torch.topk(scores, k=k).indices.tolist()
    return "\n\n".join(document_chunks[idx] for idx in top_indices)
# 6. Chat respond function
def _history_to_messages(history):
    """Convert Gradio chat history into chat-template ``{"role", "content"}`` dicts.

    Accepts both the ``[(user, assistant), ...]`` pair format and Gradio's
    ``[{"role": ..., "content": ...}, ...]`` messages format. Empty turns are
    dropped; in the messages format, non-text content (e.g. file attachments)
    is skipped.
    """
    messages = []
    for turn in history:
        if isinstance(turn, dict):
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
            continue
        user_msg, bot_msg = turn
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    return messages


@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """Stream the assistant's reply to *message*, grounded in retrieved PDF text.

    Args:
        message: the customer's latest message.
        history: prior turns from Gradio, as ``(user, assistant)`` pairs or
            ``{"role", "content"}`` dicts.
        system_message: system prompt placed first in the conversation.
        max_tokens: cap on newly generated tokens.
        temperature: sampling temperature passed to ``model.generate``.
        top_p: nucleus-sampling threshold passed to ``model.generate``.

    Yields:
        The accumulated response text after each streamed chunk.

    Raises:
        Whatever ``model.generate`` raised in the worker thread, re-raised
        once streaming stops.
    """
    context = retrieve_context(message)

    messages = [{"role": "system", "content": system_message}]
    messages.extend(_history_to_messages(history))
    messages.append(
        {"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"}
    )

    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already carries its special tokens (e.g. BOS);
    # letting the tokenizer add them again would duplicate them.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
        model.device
    )

    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )

    errors = []

    def _generate():
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:  # re-raised in the caller after streaming stops
            errors.append(exc)
            # Push the stop signal so the consumer loop below doesn't block forever.
            streamer.end()

    thread = threading.Thread(target=_generate)
    thread.start()

    response = ""
    for new_text in streamer:
        response += new_text
        yield response

    thread.join()
    if errors:
        raise errors[0]


# 7. Gradio ChatInterface
demo = gr.ChatInterface(
    fn=respond,
    title="Café Eleven Assistant",
    description="Friendly café assistant with real menu knowledge!",
    examples=[
        [
            "What kinds of burgers do you have?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
        [
            "Do you have any gluten-free pastries?",
            SYSTEM_PROMPT.strip(),
            512,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(value=SYSTEM_PROMPT.strip(), label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

# 8. Launch
if __name__ == "__main__":
    demo.launch(share=True)