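"""CPSL chatbot: retrieval-augmented QA over dataset.txt, using a FAISS
vector store, the Yi-Coder-9B-Chat model, and a Gradio chat interface."""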
import gradio as gr
import torch
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load and process the document
doc_loader = TextLoader("dataset.txt")
docs = doc_loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectordb = FAISS.from_documents(split_docs, embeddings)
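# Optional: persist the index so it isn't rebuilt on every restart (assumes a
# writable ./faiss_index directory), e.g.:
#   vectordb.save_local("faiss_index")
#   vectordb = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)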
# Load model and tokenizer
model_name = "01-ai/Yi-Coder-9B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)
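# Note: the 9B checkpoint needs roughly 18 GB of VRAM in fp16. If that does not
# fit, 4-bit loading via transformers' BitsAndBytesConfig is one option, e.g.:
#   from transformers import BitsAndBytesConfig
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, device_map="auto",
#       quantization_config=BitsAndBytesConfig(load_in_4bit=True))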
# Set up the QA pipeline
qa_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=750,
    pad_token_id=tokenizer.eos_token_id,
)
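# Note: by default the HF text-generation pipeline echoes the prompt before the
# completion; clean_response() below strips it after the fact. Passing
# return_full_text=False to pipeline() is an alternative way to get only the
# generated text.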
llm = HuggingFacePipeline(pipeline=qa_pipeline)
# Set up the retriever and QA chain
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
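# The "stuff" chain simply concatenates the k=5 retrieved chunks into one prompt.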
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    return_source_documents=False,
)
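# Example usage (the chain returns a dict; the generated text is under "result"):
#   qa_chain.invoke({"query": "What is CPSL?"})["result"]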
def preprocess_query(query):
    # Nudge the model toward code generation; match both keywords case-insensitively.
    if "script" in query.lower() or "code" in query.lower():
        return f"Write a CPSL script: {query}"
    return query
def clean_response(response):
    # Keep everything after the first "Answer:" marker, if the model emitted one.
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:", 1)[1].strip()
    return result.strip()
def chatbot_response(user_input):
    processed_query = preprocess_query(user_input)
    raw_response = qa_chain.invoke({"query": processed_query})
    return clean_response(raw_response)
# Gradio interface
with gr.Blocks() as chat_interface:
    gr.Markdown("# CPSL Chatbot")
    chat_history = gr.Chatbot(type="messages")
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")

    def interact(user_message, history):
        bot_reply = chatbot_response(user_message)
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": bot_reply})
        return history

    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history])
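    # Optional: let Enter in the textbox trigger the same handler as the Send
    # button (Gradio's Textbox.submit event), e.g.:
    #   user_input.submit(interact, inputs=[user_input, chat_history], outputs=[chat_history])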
# Launch the interface
if __name__ == "__main__":
chat_interface.launch()