# Check / app.py
# Last commit 800c7fe ("Update app.py", verified) by Rajesh3338 — 2.47 kB
import gradio as gr
from langchain_community.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
# Load and process documents
# Runs at import time: reads dataset.txt (resolved relative to the current
# working directory) and splits it into overlapping chunks for retrieval.
# NOTE(review): no encoding is passed to TextLoader, so it uses its default —
# confirm dataset.txt's encoding matches.
doc_loader = TextLoader("dataset.txt")
docs = doc_loader.load()
# chunk_size / chunk_overlap are in the splitter's length units (characters
# with the default length function); the overlap keeps some context shared
# across adjacent chunk boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
# Create vector database
# Embeds every chunk with the sentence-transformers model "all-MiniLM-L6-v2"
# and builds a FAISS index over the result. Nothing in this file saves the
# index, so it is rebuilt from dataset.txt on every start.
embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
vectordb = FAISS.from_documents(split_docs, embeddings)
# Load model and create pipeline
# Runs at import time: loads the tokenizer and causal-LM weights for the
# model below. device_map="auto" lets transformers choose device placement;
# torch_dtype="auto" takes the dtype from the checkpoint's config.
model_name = "01-ai/Yi-Coder-9B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
# NOTE(review): text-generation pipelines return prompt + completion by default,
# which is presumably why clean_response searches the output for an "Answer:"
# marker — confirm against the installed transformers/LangChain versions.
qa_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
# Limits only the number of newly generated tokens, not the prompt length.
max_new_tokens=500,
# Reuse EOS as the padding id — presumably because the tokenizer defines no
# dedicated pad token; confirm.
pad_token_id=tokenizer.eos_token_id
)
# Set up LangChain
# Wrap the transformers pipeline so LangChain can call it as an LLM.
llm = HuggingFacePipeline(pipeline=qa_pipeline)
# Return the 5 highest-scoring chunks from the FAISS index for each query.
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
# chain_type="stuff" inserts all retrieved chunks into a single prompt, so
# roughly 5 x 1000-character chunks plus the question must fit the model's
# context window.
qa_chain = RetrievalQA.from_chain_type(
retriever=retriever,
chain_type="stuff",
llm=llm,
# Only the generated answer is returned; source documents are omitted.
return_source_documents=False
)
def preprocess_query(query):
    """Prefix *query* with a CPSL-script instruction when it asks for code.

    A query counts as a code request when it contains "script" or "code",
    matched case-insensitively (plain substring match). Such queries are
    rewritten to ``"Write a CPSL script: <query>"``; all others are returned
    unchanged.

    Args:
        query: The raw user message.

    Returns:
        The possibly-prefixed query string.
    """
    # Lower-case once so BOTH keywords are matched case-insensitively; the
    # previous check only lower-cased for "code", so "Script ..." was missed.
    lowered = query.lower()
    if "script" in lowered or "code" in lowered:
        return f"Write a CPSL script: {query}"
    return query
def clean_response(response):
    """Extract the answer text from a RetrievalQA result dict.

    The chain output may contain the full prompt (retrieved context, the
    question, and a trailing "...Answer:" marker) followed by the model's
    completion. This returns the text after the LAST "Answer:" marker,
    stripped of surrounding whitespace. Text without any marker is returned
    stripped as-is.

    Using the last marker (rather than the first) means an "Answer:" that
    happens to occur inside the retrieved context is skipped. The text
    following the marker is also no longer cut off at a second marker. If
    the model itself emits a further "Answer:", only the text after that
    final marker is kept.

    Args:
        response: Mapping returned by ``qa_chain.invoke``; the answer is
            read from its ``"result"`` key.

    Returns:
        The cleaned answer string ("" when no result is present).
    """
    # `or ""` also covers an explicit None value, not just a missing key.
    result = response.get("result") or ""
    _, marker, answer = result.rpartition("Answer:")
    if marker:
        return answer.strip()
    return result.strip()
def chatbot_response(user_input):
    """Answer *user_input* through the retrieval-QA chain.

    The message is rewritten by ``preprocess_query``, sent to ``qa_chain``
    as the ``"query"`` input, and the chain's result dict is reduced to
    plain answer text by ``clean_response``.
    """
    return clean_response(
        qa_chain.invoke({"query": preprocess_query(user_input)})
    )
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# CPSL Chatbot")
    chat_history = gr.Chatbot()
    user_input = gr.Textbox(label="Your Message:")
    send_button = gr.Button("Send")

    def interact(user_message, history):
        """Append one (user, bot) exchange to the chat and clear the textbox.

        Args:
            user_message: Current contents of the message textbox.
            history: Chat transcript as a list of (user, bot) pairs; treated
                as empty when falsy (e.g. None).

        Returns:
            ``(new_history, "")``: the updated transcript for the Chatbot and
            an empty string that clears the message box.
        """
        # Copy instead of mutating the list Gradio passed in.
        history = list(history or [])
        if not (user_message or "").strip():
            # Nothing to ask: skip the expensive model call.
            return history, ""
        bot_reply = chatbot_response(user_message)
        history.append((user_message, bot_reply))
        return history, ""

    # Previously outputs=[chat_history, chat_history] wrote the same component
    # twice; the second slot now clears the input box after sending.
    send_button.click(
        interact,
        inputs=[user_input, chat_history],
        outputs=[chat_history, user_input],
    )
# Note: No launch() call here. Hugging Face will handle this.