import os
import torch
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Detect device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
# Step 1: Check if dataset exists
DATASET_FILE = "dataset.txt"
if not os.path.exists(DATASET_FILE):
    raise FileNotFoundError(f"Error: '{DATASET_FILE}' not found. Please create it and add some text.")
# Step 2: Load and split dataset
print("Loading dataset...")
doc_loader = TextLoader(DATASET_FILE)
docs = doc_loader.load()

print("Splitting documents...")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
# Step 3: Initialize FAISS vector store
print("Creating embeddings...")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

print("Creating FAISS vector DB...")
vectordb = FAISS.from_documents(split_docs, embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
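# Optional sketch (not part of the original app): persist the index so a restarted Space can
# reload it instead of re-embedding the dataset on every boot. The "faiss_index" directory
# name is an arbitrary example.
vectordb.save_local("faiss_index")
# To reuse it on a later run:
# vectordb = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)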
# Step 4: Load Hugging Face model
# Note: the Hub repo id carries a version suffix; "mistralai/Mistral-7B-Instruct" alone does not resolve.
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # More memory-efficient than Yi-Coder-9B

print(f"Loading model: {MODEL_NAME} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch_dtype,
)
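# Memory note (sketch, not in the original app): a 7B model in fp16 needs roughly 14 GB of GPU
# memory, which is a common cause of Space runtime errors. If bitsandbytes is installed, 4-bit
# loading is one way to shrink the footprint; it would replace the from_pretrained call above:
#   from transformers import BitsAndBytesConfig
#   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_NAME, device_map="auto", quantization_config=bnb_config
#   )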
# Step 5: Create QA pipeline
print("Setting up pipeline...")
qa_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=750,
    pad_token_id=tokenizer.eos_token_id,
)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
)
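# Optional self-test (sketch, not in the original app): invoking the chain once here makes
# retrieval or model problems show up in the Space logs before the UI starts. The question
# is an arbitrary example.
# print(qa_chain.invoke({"query": "What is CPSL?"})["result"])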
# Step 6: Define helper functions
def preprocess_query(query):
    """Prepares the query text based on the type of request."""
    if "script" in query.lower() or "code" in query.lower():
        return f"Write a CPSL script: {query}"
    return query
def clean_response(response):
    """Cleans and extracts the answer from the model output."""
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:")[1].strip()
    return result.strip()
def chatbot_response(user_input):
    """Processes user input and returns the AI response."""
    processed_query = preprocess_query(user_input)
    raw_response = qa_chain.invoke({"query": processed_query})
    return clean_response(raw_response)
# Step 7: Build Gradio UI
with gr.Blocks() as chat_interface:
    gr.Markdown("# CPSL Chatbot")
    chat_history = gr.Chatbot(label="Chat History", height=300)
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")
    clear_button = gr.Button("Clear")
    def interact(user_message, history):
        """Handles one turn and updates the chat history."""
        bot_reply = chatbot_response(user_message)
        # gr.Chatbot expects a list of (user_message, bot_message) pairs
        history.append((user_message, bot_reply))
        return history, ""

    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
    clear_button.click(lambda: [], outputs=[chat_history])
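# Optional (sketch, not in the original app): generations of up to 750 new tokens can take a
# while on Spaces hardware; enabling Gradio's queue is one way to keep requests from timing out.
chat_interface.queue()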
# Step 8: Run Gradio app
if __name__ == "__main__":
    print("Launching CPSL Chatbot...")
    chat_interface.launch(share=True)  # share=True exposes a public link