Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
-
# 0. Install custom transformers and
|
2 |
import os
|
3 |
os.system("pip install git+https://github.com/shumingma/transformers.git")
|
|
|
4 |
|
5 |
import threading
|
6 |
import torch
|
@@ -12,11 +13,15 @@ from transformers import (
|
|
12 |
AutoTokenizer,
|
13 |
TextIteratorStreamer,
|
14 |
)
|
|
|
15 |
import gradio as gr
|
16 |
import spaces
|
17 |
|
18 |
-
|
|
|
|
|
19 |
SYSTEM_PROMPT = """
|
|
|
20 |
1. Greet the customer warmly.
|
21 |
2. Help them order food and drinks from our menu.
|
22 |
3. Ask the customer for their desired pickup time.
|
@@ -39,7 +44,37 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
39 |
|
40 |
print(f"Model loaded on device: {model.device}")
|
41 |
|
42 |
-
# 3.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
@spaces.GPU
|
44 |
def respond(
|
45 |
message: str,
|
@@ -49,13 +84,15 @@ def respond(
|
|
49 |
temperature: float,
|
50 |
top_p: float,
|
51 |
):
|
|
|
|
|
52 |
messages = [{"role": "system", "content": system_message}]
|
53 |
for user_msg, bot_msg in history:
|
54 |
if user_msg:
|
55 |
messages.append({"role": "user", "content": user_msg})
|
56 |
if bot_msg:
|
57 |
messages.append({"role": "assistant", "content": bot_msg})
|
58 |
-
messages.append({"role": "user", "content": message})
|
59 |
|
60 |
prompt = tokenizer.apply_chat_template(
|
61 |
messages, tokenize=False, add_generation_prompt=True
|
@@ -81,21 +118,21 @@ def respond(
|
|
81 |
response += new_text
|
82 |
yield response
|
83 |
|
84 |
-
#
|
85 |
demo = gr.ChatInterface(
|
86 |
fn=respond,
|
87 |
title="Café Eleven Assistant",
|
88 |
-
description="
|
89 |
examples=[
|
90 |
[
|
91 |
-
"
|
92 |
SYSTEM_PROMPT.strip(),
|
93 |
512,
|
94 |
0.7,
|
95 |
0.95,
|
96 |
],
|
97 |
[
|
98 |
-
"Do you have
|
99 |
SYSTEM_PROMPT.strip(),
|
100 |
512,
|
101 |
0.7,
|
@@ -131,6 +168,6 @@ demo = gr.ChatInterface(
|
|
131 |
],
|
132 |
)
|
133 |
|
134 |
-
#
|
135 |
if __name__ == "__main__":
|
136 |
demo.launch()
|
|
|
1 |
+
# 0. Install custom transformers and imports
|
2 |
import os
|
3 |
os.system("pip install git+https://github.com/shumingma/transformers.git")
|
4 |
+
os.system("pip install sentence-transformers")
|
5 |
|
6 |
import threading
|
7 |
import torch
|
|
|
13 |
AutoTokenizer,
|
14 |
TextIteratorStreamer,
|
15 |
)
|
16 |
+
from sentence_transformers import SentenceTransformer
|
17 |
import gradio as gr
|
18 |
import spaces
|
19 |
|
20 |
+
from pathlib import Path
|
21 |
+
|
22 |
+
# 1. System prompt
|
23 |
SYSTEM_PROMPT = """
|
24 |
+
You are a friendly café assistant for Café Eleven. Your job is to:
|
25 |
1. Greet the customer warmly.
|
26 |
2. Help them order food and drinks from our menu.
|
27 |
3. Ask the customer for their desired pickup time.
|
|
|
44 |
|
45 |
print(f"Model loaded on device: {model.device}")
|
46 |
|
47 |
+
# 3. Load PDF files and create simple document store
|
48 |
+
from PyPDF2 import PdfReader
|
49 |
+
|
50 |
+
# Read all PDFs into a list of small chunks
def load_pdfs(folder_path="."):
    """Collect paragraph-sized text chunks from every PDF in *folder_path*.

    Each page's extracted text is split on blank lines ("\n\n"); only
    fragments longer than 20 characters after stripping are kept, so
    headers and stray line noise are dropped.
    """
    chunks = []
    for pdf_path in Path(folder_path).glob("*.pdf"):
        for page in PdfReader(str(pdf_path)).pages:
            page_text = page.extract_text()
            if not page_text:
                continue  # some pages have no extractable text layer
            paragraphs = (part.strip() for part in page_text.split("\n\n"))
            chunks.extend(p for p in paragraphs if len(p) > 20)
    return chunks
|
62 |
+
|
63 |
+
# Build the in-memory document store once at startup: every PDF next to
# this script becomes a list of paragraph chunks used for retrieval.
document_chunks = load_pdfs(".")
print(f"Loaded {len(document_chunks)} text chunks from PDFs.")

# 4. Create embeddings — one vector per chunk. normalize_embeddings=True
# L2-normalises them, so a plain dot product equals cosine similarity.
embedder = SentenceTransformer("all-MiniLM-L6-v2")  # small, fast CPU-friendly model
doc_embeddings = embedder.encode(document_chunks, normalize_embeddings=True)
|
69 |
+
|
70 |
+
# 5. Retrieval function
def retrieve_context(question, top_k=3):
    """Return the ``top_k`` chunks most similar to *question*, joined by blank lines.

    Similarity is the dot product of L2-normalised embeddings, i.e.
    cosine similarity against the precomputed ``doc_embeddings``.
    """
    query_vec = embedder.encode(question, normalize_embeddings=True)
    similarity = torch.tensor(doc_embeddings) @ torch.tensor(query_vec)
    k = min(top_k, len(similarity))  # never ask for more chunks than exist
    best = torch.topk(similarity, k=k).indices.tolist()
    return "\n\n".join(document_chunks[i] for i in best)
|
76 |
+
|
77 |
+
# 6. Chat respond function
|
78 |
@spaces.GPU
|
79 |
def respond(
|
80 |
message: str,
|
|
|
84 |
temperature: float,
|
85 |
top_p: float,
|
86 |
):
|
87 |
+
context = retrieve_context(message)
|
88 |
+
|
89 |
messages = [{"role": "system", "content": system_message}]
|
90 |
for user_msg, bot_msg in history:
|
91 |
if user_msg:
|
92 |
messages.append({"role": "user", "content": user_msg})
|
93 |
if bot_msg:
|
94 |
messages.append({"role": "assistant", "content": bot_msg})
|
95 |
+
messages.append({"role": "user", "content": f"{message}\n\nRelevant menu info:\n{context}"})
|
96 |
|
97 |
prompt = tokenizer.apply_chat_template(
|
98 |
messages, tokenize=False, add_generation_prompt=True
|
|
|
118 |
response += new_text
|
119 |
yield response
|
120 |
|
121 |
+
# 7. Gradio UI
|
122 |
demo = gr.ChatInterface(
|
123 |
fn=respond,
|
124 |
title="Café Eleven Assistant",
|
125 |
+
description="Friendly café assistant with real menu knowledge!",
|
126 |
examples=[
|
127 |
[
|
128 |
+
"What kinds of burgers do you have?",
|
129 |
SYSTEM_PROMPT.strip(),
|
130 |
512,
|
131 |
0.7,
|
132 |
0.95,
|
133 |
],
|
134 |
[
|
135 |
+
"Do you have any gluten-free pastries?",
|
136 |
SYSTEM_PROMPT.strip(),
|
137 |
512,
|
138 |
0.7,
|
|
|
168 |
],
|
169 |
)
|
170 |
|
171 |
+
# 8. Launch the Gradio app only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()
|