import os
import gradio as gr
import torch
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Configure GPU settings
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
class CPSLChatbot:
    def __init__(self):
        self.initialize_components()

    def initialize_components(self):
        try:
            # Load and process the document
            doc_loader = TextLoader("dataset.txt")
            docs = doc_loader.load()

            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=100
            )
            split_docs = text_splitter.split_documents(docs)

            # Initialize embeddings and vector store
            self.embeddings = HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                model_kwargs={'device': device}
            )
            self.vectordb = FAISS.from_documents(split_docs, self.embeddings)

            # Load model and tokenizer
            model_name = "01-ai/Yi-Coder-9B-Chat"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                trust_remote_code=True
            )

            # Set up the text-generation pipeline.
            # The model is already placed by device_map="auto", so no explicit
            # device argument is passed to pipeline() (passing one alongside an
            # accelerate-dispatched model raises an error).
            self.qa_pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=750,
                pad_token_id=self.tokenizer.eos_token_id
            )

            # Initialize LangChain components
            llm = HuggingFacePipeline(pipeline=self.qa_pipeline)
            retriever = self.vectordb.as_retriever(search_kwargs={"k": 5})
            self.qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=False
            )

            print("Initialization completed successfully")
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise
    def preprocess_query(self, query):
        if "script" in query.lower() or "code" in query.lower():
            return f"Write a CPSL script: {query}"
        return query

    def clean_response(self, response):
        result = response.get("result", "")
        if "Answer:" in result:
            return result.split("Answer:")[1].strip()
        return result.strip()

    def get_response(self, user_input):
        try:
            processed_query = self.preprocess_query(user_input)
            raw_response = self.qa_chain.invoke({"query": processed_query})
            return self.clean_response(raw_response)
        except Exception as e:
            return f"Error processing query: {str(e)}"
def create_gradio_interface():
    chatbot = CPSLChatbot()

    with gr.Blocks(title="CPSL Chatbot") as chat_interface:
        gr.Markdown("# CPSL Chatbot with GPU Support")
        gr.Markdown("Using the Yi-Coder-9B-Chat model for CPSL script generation and queries")

        chat_history = gr.Chatbot(
            value=[],
            elem_id="chatbot",
            height=600
        )
        with gr.Row():
            user_input = gr.Textbox(
                label="Your Message:",
                placeholder="Type your message here...",
                show_label=True,
                elem_id="user-input"
            )
            send_button = gr.Button("Send", variant="primary")

        def chat_response(user_message, history):
            if not user_message:
                return history
            bot_response = chatbot.get_response(user_message)
            history.append((user_message, bot_response))
            return history

        send_button.click(
            chat_response,
            inputs=[user_input, chat_history],
            outputs=[chat_history],
            api_name="chat"
        )
        # Clear the input textbox after sending
        send_button.click(lambda: "", None, user_input)

        # Also allow the Enter key to send a message
        user_input.submit(
            chat_response,
            inputs=[user_input, chat_history],
            outputs=[chat_history],
        )
        user_input.submit(lambda: "", None, user_input)

    return chat_interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    # enable_queue was removed from launch() in newer Gradio releases;
    # queuing is enabled via queue() instead.
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )