import os
import torch
import gradio as gr
import spaces
from huggingface_hub import InferenceClient
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import PromptTemplate

# Verify PyTorch version compatibility with ZeroGPU
TORCH_VERSION = torch.__version__
SUPPORTED_TORCH_VERSIONS = ['2.0.1', '2.1.2', '2.2.2', '2.4.0']
if TORCH_VERSION.split('+')[0] not in SUPPORTED_TORCH_VERSIONS:
    print(f"Warning: Current PyTorch version {TORCH_VERSION} may not be compatible with ZeroGPU. "
          f"Supported versions are: {', '.join(SUPPORTED_TORCH_VERSIONS)}")

# Initialize components outside of the GPU scope
client = InferenceClient("meta-llama/Llama-3.2-3B-Instruct")
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"}  # Keep embeddings on CPU
)

# Load database
db = Chroma(
    persist_directory="db",
    embedding_function=embeddings
)

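# The app assumes "db" already holds a persisted Chroma index built with the same
# embedding model. A minimal ingestion sketch (run once, offline) might look like
# the commented block below; the source file name and chunking parameters are
# illustrative assumptions, not part of this Space:
#
#   from langchain.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#
#   raw_docs = PyPDFLoader("ros2_docs.pdf").load()  # hypothetical source document
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1000, chunk_overlap=100          # assumed chunking values
#   ).split_documents(raw_docs)
#   Chroma.from_documents(chunks, embeddings, persist_directory="db")
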
# Prompt templates
DEFAULT_SYSTEM_PROMPT = """
Based on the document excerpts provided in the context, answer the question as accurately as possible in 1 or 2 lines. If the information is not in the context,
respond with "I don't know" or a similar acknowledgment that the answer is not available.
""".strip()

def generate_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    # Wraps the prompt in Llama-2-style [INST]/<<SYS>> tags. Note that Llama 3.x
    # models normally use a different chat template; this string is sent to the
    # Inference API as a raw prompt.
    return f"""
[INST] <<SYS>>
{system_prompt}
<</SYS>>
{prompt} [/INST]
""".strip()

template = generate_prompt(
    """
{context}
Question: {question}
""",
    system_prompt="Use the following pieces of context to answer the question at the end. Do not provide commentary or elaboration beyond 1 or 2 lines."
)
prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])

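# After substitution, the template string looks roughly like:
#
#   [INST] <<SYS>>
#   Use the following pieces of context to answer the question at the end. ...
#   <</SYS>>
#
#   {context}
#   Question: {question} [/INST]
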
# Reduced duration for faster queue priority. The @spaces.GPU decorator is
# restored here to match this comment and the unused `spaces` import; the
# duration value is an assumption.
@spaces.GPU(duration=30)
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
"""GPU-accelerated response generation""" | |
try: | |
# Retrieve context (CPU operation) | |
docs = db.similarity_search(message, k=2) | |
context = "\n".join([doc.page_content for doc in docs]) | |
print(f"Retrieved context: {context[:200]}...") | |
# Format prompt | |
formatted_prompt = prompt_template.format( | |
context=context, | |
question=message | |
) | |
print(f"Full prompt: {formatted_prompt}") | |
# Stream response (GPU operation) | |
response = "" | |
for message in client.text_generation( | |
prompt=formatted_prompt, | |
max_new_tokens=max_tokens, | |
stream=True, | |
temperature=temperature, | |
top_p=top_p, | |
): | |
response += message | |
yield response | |
except Exception as e: | |
yield f"An error occurred: {str(e)}" | |
# Create Gradio interface
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value=DEFAULT_SYSTEM_PROMPT,
            label="System Message",
            lines=3,
            visible=False
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=500,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.1,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
    title="ROS2 Expert Assistant",
    description="Ask questions about ROS2, navigation, and robotics. I'll provide concise answers based on the available documentation.",
)

if __name__ == "__main__":
    demo.launch()