Update app.py

app.py (CHANGED)
```diff
@@ -2,41 +2,31 @@ import spaces
 import os
 import gradio as gr
 import torch
-…
 from transformers import AutoTokenizer, TextStreamer, pipeline, AutoModelForCausalLM
-from …
-from …
+from langchain_community.embeddings import HuggingFaceInstructEmbeddings
+from langchain_community.vectorstores import Chroma
 from langchain.prompts import PromptTemplate
 from langchain.chains import RetrievalQA
-from …
+from langchain_community.llms import HuggingFacePipeline

 # System prompts
 DEFAULT_SYSTEM_PROMPT = """
-…
-… respond with "I don't …
+You are a ROS2 expert assistant. Based on the context provided, give direct and concise answers.
+If the information is not in the context, respond with "I don't find that information in the available documentation."
+Keep responses to 1-2 lines maximum.
 """.strip()

-…
-…
-def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
+def generate_prompt(context: str, question: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
     return f"""
 [INST] <<SYS>>
 {system_prompt}
 <</SYS>>
-{…
-""".strip()
-
-template = generate_prompt(
-"""
-{context}
+Context: {context}
 Question: {question}
-…
-…
-)
-
-prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])
+Answer: [/INST]
+""".strip()

-# Initialize embeddings and database
+# Initialize embeddings and database
 embeddings = HuggingFaceInstructEmbeddings(
     model_name="hkunlp/instructor-base",
     model_kwargs={"device": "cpu"}
```
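The hunk above ends at the embeddings setup, but the Chroma store that later code queries as `db` is built in lines outside the visible hunks. A minimal sketch of what that construction typically looks like with the imports added here; the persist directory name is an assumption, not something shown in the diff:

```python
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
from langchain_community.vectorstores import Chroma

embeddings = HuggingFaceInstructEmbeddings(
    model_name="hkunlp/instructor-base",
    model_kwargs={"device": "cpu"},
)

# Load a persisted Chroma index; "db" as the directory name is assumed.
db = Chroma(
    persist_directory="db",
    embedding_function=embeddings,
)
```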
```diff
@@ -55,21 +45,37 @@ def initialize_model():
     model = AutoModelForCausalLM.from_pretrained(
         model_id,
         token=token,
-        device_map="cuda"
+        device_map="cuda" if torch.cuda.is_available() else "cpu"
     )
-    # if torch.cuda.is_available():
-    #     model.device = "cuda"
-    # else:
-    #     print("CUDA is not available")

     return model, tokenizer

+class CustomTextStreamer(TextStreamer):
+    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
+        super().__init__(tokenizer, skip_prompt=skip_prompt, skip_special_tokens=skip_special_tokens)
+        self.output_text = ""
+
+    def put(self, value):
+        self.output_text += value
+        super().put(value)
+
 @spaces.GPU
 def respond(message, history, system_message, max_tokens, temperature, top_p):
     try:
         model, tokenizer = initialize_model()

-        …
+        # Get relevant context from the database
+        retriever = db.as_retriever(search_kwargs={"k": 2})
+        docs = retriever.get_relevant_documents(message)
+        context = "\n".join([doc.page_content for doc in docs])
+
+        # Generate the complete prompt
+        prompt = generate_prompt(context=context, question=message, system_prompt=system_message)
+
+        # Set up the streamer
+        streamer = CustomTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+        # Set up the pipeline
         text_pipeline = pipeline(
             "text-generation",
             model=model,
```
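For reference, the reworked generate_prompt() renders a Llama-2 style instruction block. An illustrative call (the context and question strings are made up for the example):

```python
print(generate_prompt(
    context="Use `ros2 topic list` to see active topics.",
    question="How do I list active topics?",
))
# Output shape (whitespace approximate):
# [INST] <<SYS>>
# You are a ROS2 expert assistant. Based on the context provided, give direct and concise answers.
# If the information is not in the context, respond with "I don't find that information in the available documentation."
# Keep responses to 1-2 lines maximum.
# <</SYS>>
# Context: Use `ros2 topic list` to see active topics.
# Question: How do I list active topics?
# Answer: [/INST]
```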
```diff
@@ -81,18 +87,11 @@ def respond(message, history, system_message, max_tokens, temperature, top_p):
             streamer=streamer,
         )

-        …
-        …
-        qa_chain = RetrievalQA.from_chain_type(
-            llm=llm,
-            chain_type="stuff",
-            retriever=db.as_retriever(search_kwargs={"k": 2}),
-            return_source_documents=False,
-            chain_type_kwargs={"prompt": prompt_template}
-        )
+        # Generate response
+        _ = text_pipeline(prompt, max_new_tokens=max_tokens)

-        …
-        yield …
+        # Return only the generated response
+        yield streamer.output_text.strip()

     except Exception as e:
         yield f"An error occurred: {str(e)}"
```
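A note on how the reply text is captured: transformers' TextStreamer.put() is fed token-id tensors during generation, while the decoded text is delivered through on_finalized_text(). If the put() override above receives tensors rather than strings, the += concatenation would fail, so a capture subclass is more commonly written against on_finalized_text(). A sketch under that assumption (the class name is illustrative):

```python
from transformers import TextStreamer

class CapturingTextStreamer(TextStreamer):
    """Streams tokens to stdout while also collecting the decoded text."""

    def __init__(self, tokenizer, skip_prompt=True, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, **decode_kwargs)
        self.output_text = ""

    def on_finalized_text(self, text: str, stream_end: bool = False):
        # on_finalized_text receives already-decoded strings.
        self.output_text += text
        super().on_finalized_text(text, stream_end=stream_end)
```

An alternative that avoids a custom streamer entirely is to call the pipeline with return_full_text=False and read result[0]["generated_text"].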
```diff
@@ -134,4 +133,4 @@ demo = gr.ChatInterface(
 )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch(share=True)
```
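The gr.ChatInterface block named in the last hunk header is not part of the diff; based on respond()'s signature, its wiring is presumably along these lines, with labels and default values being assumptions:

```python
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value=DEFAULT_SYSTEM_PROMPT, label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)
```

ChatInterface passes (message, history) first and then the additional inputs in order, which matches the (system_message, max_tokens, temperature, top_p) parameters of respond().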
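For context, only the tail of initialize_model() appears in the second hunk. A plausible shape consistent with the visible lines is sketched below; the model id and the way the token is obtained are assumptions, since the real values sit outside the hunks:

```python
def initialize_model():
    model_id = "meta-llama/Llama-2-7b-chat-hf"   # assumed; not shown in the diff
    token = os.environ.get("HF_TOKEN")           # assumed; not shown in the diff

    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        device_map="cuda" if torch.cuda.is_available() else "cpu"
    )

    return model, tokenizer
```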
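One follow-up on the retrieval call added in the second hunk: recent LangChain releases deprecate Retriever.get_relevant_documents() in favor of the runnable interface, so the same lookup can also be written as:

```python
# Equivalent lookup via the runnable API (newer langchain-core versions).
docs = retriever.invoke(message)
context = "\n".join(doc.page_content for doc in docs)
```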