# rag-ros2/app.py
import os
import torch
import gradio as gr
import spaces
from huggingface_hub import InferenceClient
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate
# Verify PyTorch version compatibility with ZeroGPU
TORCH_VERSION = torch.__version__
SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']
if TORCH_VERSION.split('+')[0] not in SUPPORTED_TORCH_VERSIONS:
    print(f"Warning: Current PyTorch version {TORCH_VERSION} may not be compatible with ZeroGPU. "
          f"Supported versions are: {', '.join(SUPPORTED_TORCH_VERSIONS)}")
# Initialize components outside of GPU scope
client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
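# Note: meta-llama checkpoints are gated, so the Space needs an HF token;
# huggingface_hub reads HF_TOKEN from the environment, or pass it explicitly:
#   client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct", token=os.environ["HF_TOKEN"])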
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"}  # Keep embeddings on CPU
)
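# Quick sanity check (all-MiniLM-L6-v2 produces 384-dimensional vectors):
#   vec = embeddings.embed_query("What is a ROS2 node?")
#   assert len(vec) == 384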
# Load database
db = Chroma(
    persist_directory="db",
    embedding_function=embeddings
)
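# Optional sanity check that the persisted index actually loaded; note that
# `_collection` is a private langchain handle and may change between versions:
#   print(f"Loaded {db._collection.count()} chunks from ./db")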
# Prompt templates
DEFAULT_SYSTEM_PROMPT = """
Based on the information provided in the context, answer the question as accurately as possible in one or two lines. If the information is not in the context,
respond with "I don't know" or a similar acknowledgment that the answer is not available.
""".strip()
def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f"""
[INST] <<SYS>>
{system_prompt}
<</SYS>>
{prompt} [/INST]
""".strip()
template = generate_prompt(
    """
{context}
Question: {question}
""",
    system_prompt="Use the following pieces of context to answer the question at the end. Do not provide commentary or elaboration beyond one or two lines."
)
prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])
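# Example (hypothetical values): the template now carries both slots, so
#   prompt_template.format(context="Nodes talk over topics.",
#                          question="How do ROS2 nodes communicate?")
# returns the full [INST]-wrapped prompt with context and question filled in.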
@spaces.GPU(duration=30) # Reduced duration for faster queue priority
def respond(
    message,
    history,
    system_message,  # accepted for ChatInterface parity; the RAG template above drives the prompt
    max_tokens,
    temperature,
    top_p,
):
"""GPU-accelerated response generation"""
try:
# Retrieve context (CPU operation)
docs = db.similarity_search(message, k=2)
context = "\n".join([doc.page_content for doc in docs])
print(f"Retrieved context: {context[:200]}...")
# Format prompt
formatted_prompt = prompt_template.format(
context=context,
question=message
)
print(f"Full prompt: {formatted_prompt}")
# Stream response (GPU operation)
response = ""
for message in client.text_generation(
prompt=formatted_prompt,
max_new_tokens=max_tokens,
stream=True,
temperature=temperature,
top_p=top_p,
):
response += message
yield response
except Exception as e:
yield f"An error occurred: {str(e)}"
# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.1,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)
if __name__ == "__main__":
    demo.launch()