Copain22 committed on
Commit 83362fe · verified · Parent(s): 61d97ef

Update app.py

Files changed (1):
  1. app.py +112 -65
app.py CHANGED
@@ -1,18 +1,21 @@
- # ---------- 0. Imports & constants ----------
  import os
  import torch
- import gradio as gr
- from pathlib import Path
- from huggingface_hub import login
-
- from llama_index.core import (
-     VectorStoreIndex, SimpleDirectoryReader, Settings, PromptTemplate
  )
- from llama_index.core.memory import ChatMemoryBuffer
- from llama_index.llms.huggingface import HuggingFaceLLM
- from llama_index.embeddings.langchain import LangchainEmbedding
- from langchain_huggingface import HuggingFaceEmbeddings
 
  SYSTEM_PROMPT = """
  You are a friendly café assistant for Café Eleven. Your job is to:
  1. Greet the customer warmly
@@ -24,70 +27,114 @@ You are a friendly café assistant for Café Eleven. Your job is to:
  Always be polite and helpful!
  """
 
- WRAPPER_PROMPT = PromptTemplate(
-     "[INST]<<SYS>>\n" + SYSTEM_PROMPT + "\n<</SYS>>\n\n{query_str} [/INST]"
- )
 
- # ---------- 1. Login & Load Data ----------
- login(token=os.environ["HF_TOKEN"])
-
- docs = SimpleDirectoryReader(
-     input_files=[str(p) for p in Path(".").glob("*.pdf")]
- ).load_data()
-
- embed_model = LangchainEmbedding(
-     HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
  )
- Settings.embed_model = embed_model
- Settings.chunk_size = 512
-
- index = VectorStoreIndex.from_documents(docs)
 
- # ---------- 2. Initialize Chat Engine ----------
- _state = {"chat_engine": None}
 
- def get_chat_engine():
-     if _state["chat_engine"] is None:
-         llm = HuggingFaceLLM(
-             tokenizer_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-             model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
-             context_window=2048,
-             max_new_tokens=256,
-             generate_kwargs={"temperature": 0.2, "do_sample": True},
-             device_map="auto",
-             model_kwargs={"use_auth_token": os.environ["HF_TOKEN"]},
-             system_prompt=SYSTEM_PROMPT,
-             query_wrapper_prompt=WRAPPER_PROMPT,
-         )
-         Settings.llm = llm
 
-         memory = ChatMemoryBuffer.from_defaults(token_limit=2000)
-         _state["chat_engine"] = index.as_chat_engine(
-             chat_mode="condense_plus_context",
-             memory=memory,
-             system_prompt=SYSTEM_PROMPT,
-         )
-     return _state["chat_engine"]
 
- # ---------- 3. Simple Chat Function ----------
- def chat_with_cafe_eleven(message: str) -> str:
-     if message.lower().strip() in {"quit", "exit", "done"}:
-         return "Thank you for your order! We'll see you soon."
 
-     engine = get_chat_engine()
-     response = engine.chat(message).response
-     return response
 
- # ---------- 4. Gradio UI ----------
- iface = gr.Interface(
-     fn=chat_with_cafe_eleven,
-     inputs=gr.Textbox(lines=2, placeholder="Ask about menu items, orders, etc..."),
-     outputs="text",
      title="Café Eleven Assistant",
-     description="A friendly café assistant to help you with orders and questions!"
  )
 
- # ---------- 5. Launch App ----------
  if __name__ == "__main__":
-     iface.launch(server_name="0.0.0.0", server_port=7860)
 
+ # 0. Install custom transformers and imports
  import os
+ os.system("pip install git+https://github.com/shumingma/transformers.git")
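+ # (Installs shumingma's transformers fork at startup; presumably the stock
+ # pip release did not yet ship the bitnet architecture used below.)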
+
+ import threading
  import torch
+ import torch._dynamo
+ torch._dynamo.config.suppress_errors = True
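+ # (With suppress_errors set, torch._dynamo falls back to eager execution
+ # instead of raising if compilation fails on the quantized ops.)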
 
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     TextIteratorStreamer,
  )
+ import gradio as gr
+ import spaces
 
+ # 1. System prompt (your original one)
  SYSTEM_PROMPT = """
  You are a friendly café assistant for Café Eleven. Your job is to:
  1. Greet the customer warmly
  ...
  Always be polite and helpful!
  """
 
+ # 2. Model info
+ MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
 
+ # 3. Load model and tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
  )
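+ # (bfloat16 halves the memory footprint of fp32 weights; device_map="auto"
+ # lets accelerate place the model on a GPU when one is available.)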
+ print(f"Model loaded on device: {model.device}")
 
+ # 4. Respond function
+ @spaces.GPU
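+ # (The spaces.GPU decorator requests a GPU from HF Spaces ZeroGPU for the
+ # duration of each call; it is a no-op when running outside Spaces.)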
+ def respond(
+     message: str,
+     history: list[tuple[str, str]],
+     system_message: str,
+     max_tokens: int,
+     temperature: float,
+     top_p: float,
+ ):
+     """
+     Generate a chat response using streaming with TextIteratorStreamer.
+     """
+     messages = [{"role": "system", "content": system_message}]
+     for user_msg, bot_msg in history:
+         if user_msg:
+             messages.append({"role": "user", "content": user_msg})
+         if bot_msg:
+             messages.append({"role": "assistant", "content": bot_msg})
+     messages.append({"role": "user", "content": message})
 
+     prompt = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
+     streamer = TextIteratorStreamer(
+         tokenizer, skip_prompt=True, skip_special_tokens=True
+     )
+     generate_kwargs = dict(
+         **inputs,
+         streamer=streamer,
+         max_new_tokens=max_tokens,
+         temperature=temperature,
+         top_p=top_p,
+         do_sample=True,
+     )
+     thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
+     thread.start()
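+     # (model.generate blocks, so it runs in a worker thread while the
+     # streamer feeds decoded tokens back to this generator as they arrive.)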
 
+     response = ""
+     for new_text in streamer:
+         response += new_text
+         yield response
 
+ # 5. Gradio UI
+ demo = gr.ChatInterface(
+     fn=respond,
      title="Café Eleven Assistant",
+     description="A friendly café chatbot to help you with orders and menu questions!",
+     examples=[
+         [
+             "Can I get a recommendation for breakfast?",
+             SYSTEM_PROMPT.strip(),
+             512,
+             0.7,
+             0.95,
+         ],
+         [
+             "Do you have vegan menu options?",
+             SYSTEM_PROMPT.strip(),
+             512,
+             0.7,
+             0.95,
+         ],
+     ],
+     additional_inputs=[
+         gr.Textbox(
+             value=SYSTEM_PROMPT.strip(),
+             label="System message"
+         ),
+         gr.Slider(
+             minimum=1,
+             maximum=2048,
+             value=512,
+             step=1,
+             label="Max new tokens"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=4.0,
+             value=0.7,
+             step=0.1,
+             label="Temperature"
+         ),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)"
+         ),
+     ],
  )
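+ # (Each example row supplies the user message followed by values for the
+ # four additional_inputs, in order.)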
 
+ # 6. Launch
  if __name__ == "__main__":
+     demo.launch(server_name="0.0.0.0", server_port=7860)
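
A quick way to sanity-check the new respond() generator without launching the Gradio UI is to drive it directly. This is a minimal sketch, not part of the commit, assuming the model load above succeeded and enough memory is available; the smoke_test name, message, and history are made up for illustration:

    # Hypothetical smoke test; run in the same module as app.py's globals.
    def smoke_test():
        last = ""
        for partial in respond(
            "Do you have oat milk?",               # message
            [("Hi!", "Welcome to Café Eleven!")],  # history of (user, assistant) turns
            SYSTEM_PROMPT.strip(),                 # system_message
            128,                                   # max_tokens
            0.7,                                   # temperature
            0.95,                                  # top_p
        ):
            last = partial                         # respond yields the cumulative text
        print(last)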