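"""CPSL Chatbot: a Gradio app that answers questions over a local dataset.

dataset.txt is split into chunks, embedded with a sentence-transformers
model, indexed in FAISS, and queried through a LangChain RetrievalQA chain
backed by a Hugging Face causal language model.
"""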
import os

import torch
import gradio as gr
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

# Detect device (GPU or CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32
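# float16 roughly halves the model's memory footprint on GPU; full float32
# precision is kept on CPU, where half-precision inference is much slower.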

# ✅ Step 1: Check if dataset exists
DATASET_FILE = "dataset.txt"
if not os.path.exists(DATASET_FILE):
    raise FileNotFoundError(f"❌ Error: '{DATASET_FILE}' not found. Please create it and add some text.")

# ✅ Step 2: Load and split dataset
print("📄 Loading dataset...")
doc_loader = TextLoader(DATASET_FILE)
docs = doc_loader.load()

print("🔀 Splitting documents...")
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
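# The 100-character overlap keeps text that straddles a chunk boundary
# retrievable from either neighbouring chunk.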

# ✅ Step 3: Initialize FAISS Vector Store
print("🧠 Creating embeddings...")
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

print("📌 Creating FAISS Vector DB...")
vectordb = FAISS.from_documents(split_docs, embeddings)
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
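# k=5: each query retrieves the five most similar chunks for the prompt.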

# ✅ Step 4: Load Hugging Face Model
MODEL_NAME = "mistralai/Mistral-7B-Instruct-v0.2"  # versioned repo id; more memory-efficient than Yi-Coder-9B
print(f"🚀 Loading Model: {MODEL_NAME} on {device}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map="auto",
    torch_dtype=torch_dtype,
)
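# device_map="auto" (via accelerate) places the weights on the available
# GPU(s) and offloads the remainder to CPU RAM if the model does not fit.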

# ✅ Step 5: Create QA Pipeline
print("🔧 Setting up pipeline...")
qa_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=750,
pad_token_id=tokenizer.eos_token_id,
)
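# The generation pipeline warns when no pad token is defined; reusing the
# EOS token as padding silences that for models without a dedicated pad token.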
llm = HuggingFacePipeline(pipeline=qa_pipeline)
qa_chain = RetrievalQA.from_chain_type(
    retriever=retriever,
    chain_type="stuff",
    llm=llm,
    return_source_documents=False,
)
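# The "stuff" chain type concatenates all retrieved chunks into a single
# prompt; with k=5 chunks of ~1000 characters this should fit comfortably
# in the model's context window.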

# ✅ Step 6: Define Helper Functions
def preprocess_query(query):
    """Prepares query text based on type of request."""
    if "script" in query.lower() or "code" in query.lower():
        return f"Write a CPSL script: {query}"
    return query


def clean_response(response):
    """Cleans and extracts the response from model output."""
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:")[1].strip()
    return result.strip()


def chatbot_response(user_input):
    """Processes user input and returns AI response."""
    processed_query = preprocess_query(user_input)
    raw_response = qa_chain.invoke({"query": processed_query})
    return clean_response(raw_response)

# ✅ Step 7: Build Gradio UI
with gr.Blocks() as chat_interface:
    gr.Markdown("# 🤖 CPSL Chatbot")

    chat_history = gr.Chatbot(label="Chat History", height=300)
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")
    clear_button = gr.Button("Clear")

    def interact(user_message, history):
        """Handles user interaction and updates chat history."""
        bot_reply = chatbot_response(user_message)
        # gr.Chatbot expects one (user, assistant) tuple per exchange.
        history.append((f"🧑‍💻 {user_message}", f"🤖 {bot_reply}"))
        return history, ""

    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
    clear_button.click(lambda: [], outputs=[chat_history])

# ✅ Step 8: Run Gradio App
if __name__ == "__main__":
    print("🚀 Launching CPSL Chatbot...")
    chat_interface.launch(share=True)  # share=True allows public link