import gradio as gr
import torch
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains import RetrievalQA
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_huggingface import HuggingFacePipeline
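# Prefer the GPU when available (used below to pick float16 vs. float32)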
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load and process the document
doc_loader = TextLoader("dataset.txt")
docs = doc_loader.load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
split_docs = text_splitter.split_documents(docs)
# Create embeddings and vector store
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectordb = FAISS.from_documents(split_docs, embeddings)
# Load model and tokenizer
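# Note: Yi-Coder-9B-Chat has ~9B parameters; the float16 weights alone take roughly 18 GB of GPU memory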
model_name = "01-ai/Yi-Coder-9B-Chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
torch_dtype=torch.float16 if device == "cuda" else torch.float32
)
# Set up the QA pipeline
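# max_new_tokens caps the length of each generated answer; pad_token_id avoids the missing-pad-token warning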
qa_pipeline = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer,
max_new_tokens=750,
pad_token_id=tokenizer.eos_token_id
)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
# Set up the retriever and QA chain
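# Retrieve the 5 most similar chunks per query; the "stuff" chain concatenates them into a single prompt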
retriever = vectordb.as_retriever(search_kwargs={"k": 5})
qa_chain = RetrievalQA.from_chain_type(
retriever=retriever,
chain_type="stuff",
llm=llm,
return_source_documents=False
)
def preprocess_query(query):
    # Steer code-related questions toward producing a CPSL script
    if "script" in query.lower() or "code" in query.lower():
        return f"Write a CPSL script: {query}"
    return query
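# The text-generation pipeline echoes the prompt, so keep only the text after the "Answer:" marker when present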
def clean_response(response):
result = response.get("result", "")
if "Answer:" in result:
return result.split("Answer:")[1].strip()
return result.strip()
def chatbot_response(user_input):
processed_query = preprocess_query(user_input)
raw_response = qa_chain.invoke({"query": processed_query})
return clean_response(raw_response)
# Gradio interface
with gr.Blocks() as chat_interface:
gr.Markdown("# CPSL Chatbot")
chat_history = gr.Chatbot(type='messages')
user_input = gr.Textbox(label="Your Message:")
send_button = gr.Button("Send")
    def interact(user_message, history):
        bot_reply = chatbot_response(user_message)
        history.append({"role": "user", "content": user_message})
        history.append({"role": "assistant", "content": bot_reply})
        return history
    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history])
# Launch the interface
if __name__ == "__main__":
    chat_interface.launch()