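"""Gradio app for a CPSL chatbot.

Builds a retrieval-augmented pipeline: dataset.txt is split into chunks,
embedded into a FAISS index, and the top-matching chunks are stuffed into
a prompt for the Yi-Coder-9B-Chat model.
"""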
import os

import gradio as gr
import torch
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Configure GPU settings
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
class CPSLChatbot:
    def __init__(self):
        self.initialize_components()
    def initialize_components(self):
        try:
            # Load and process document
            doc_loader = TextLoader("dataset.txt")
            docs = doc_loader.load()
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=100
            )
            split_docs = text_splitter.split_documents(docs)

            # Initialize embeddings and vector store
            self.embeddings = HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                model_kwargs={'device': device}
            )
            self.vectordb = FAISS.from_documents(split_docs, self.embeddings)

            # Load model and tokenizer
            model_name = "01-ai/Yi-Coder-9B-Chat"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                trust_remote_code=True
            )
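            # Note: float16 halves GPU memory versus float32, but a
            # 9B-parameter model still needs roughly 18 GB in fp16.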
            # Set up QA pipeline. The model was loaded with device_map="auto"
            # (accelerate places the weights), so no explicit device argument
            # is passed here; supplying both raises a ValueError in transformers.
            self.qa_pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                max_new_tokens=750,
                pad_token_id=self.tokenizer.eos_token_id
            )
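            # text-generation pipelines include the prompt in their output by
            # default (return_full_text=True), which is why clean_response()
            # strips everything up to the "Answer:" marker.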
            # Initialize LangChain components
            llm = HuggingFacePipeline(pipeline=self.qa_pipeline)
            retriever = self.vectordb.as_retriever(search_kwargs={"k": 5})
            self.qa_chain = RetrievalQA.from_chain_type(
                retriever=retriever,
                chain_type="stuff",
                llm=llm,
                return_source_documents=False
            )
            print("Initialization completed successfully")
        except Exception as e:
            print(f"Initialization error: {str(e)}")
            raise
    def preprocess_query(self, query):
        if "script" in query.lower() or "code" in query.lower():
            return f"Write a CPSL script: {query}"
        return query
    def clean_response(self, response):
        result = response.get("result", "")
        if "Answer:" in result:
            return result.split("Answer:")[1].strip()
        return result.strip()
    def get_response(self, user_input):
        try:
            processed_query = self.preprocess_query(user_input)
            raw_response = self.qa_chain.invoke({"query": processed_query})
            return self.clean_response(raw_response)
        except Exception as e:
            return f"Error processing query: {str(e)}"
def create_gradio_interface():
    chatbot = CPSLChatbot()
    with gr.Blocks(title="CPSL Chatbot") as chat_interface:
        gr.Markdown("# CPSL Chatbot with GPU Support")
        gr.Markdown("Using Yi-Coder-9B-Chat model for CPSL script generation and queries")
        chat_history = gr.Chatbot(
            value=[],
            elem_id="chatbot",
            height=600
        )
        with gr.Row():
            user_input = gr.Textbox(
                label="Your Message:",
                placeholder="Type your message here...",
                show_label=True,
                elem_id="user-input"
            )
            send_button = gr.Button("Send", variant="primary")
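        # The same handler serves both the Send button and the Enter key;
        # a chained second callback clears the textbox after each send.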
        def chat_response(user_message, history):
            if not user_message:
                return history
            bot_response = chatbot.get_response(user_message)
            history.append((user_message, bot_response))
            return history

        send_button.click(
            chat_response,
            inputs=[user_input, chat_history],
            outputs=[chat_history],
            api_name="chat"
        )
        # Clear the input textbox after sending
        send_button.click(lambda: "", None, user_input)

        # Also allow the Enter key to send the message
        user_input.submit(
            chat_response,
            inputs=[user_input, chat_history],
            outputs=[chat_history],
        )
        user_input.submit(lambda: "", None, user_input)

    return chat_interface
if __name__ == "__main__":
    interface = create_gradio_interface()
    # queue() replaces the old enable_queue launch flag, which newer
    # Gradio releases no longer accept
    interface.queue()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )