Rajesh3338 committed on
Commit 6fe9e32 · verified · 1 Parent(s): 97fd059

Update app.py

Files changed (1)
  1. app.py +70 -0
app.py CHANGED
@@ -0,0 +1,70 @@
+ import gradio as gr
+ from langchain.document_loaders import TextLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
+ from langchain.llms import HuggingFacePipeline
+ from langchain.chains import RetrievalQA
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+
+ # Load and process documents
+ doc_loader = TextLoader("dataset.txt")
+ docs = doc_loader.load()
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+ split_docs = text_splitter.split_documents(docs)
+
+ # Create vector database
+ embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+ vectordb = FAISS.from_documents(split_docs, embeddings)
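+ # The FAISS index lives in memory, so it is rebuilt from dataset.txt on every startup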
+
+ # Load model and create pipeline
+ model_name = "01-ai/Yi-Coder-9B-Chat"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto")
+ qa_pipeline = pipeline(
+     "text-generation",
+     model=model,
+     tokenizer=tokenizer,
+     max_new_tokens=500,
+     pad_token_id=tokenizer.eos_token_id
+ )
+
+ # Set up LangChain ("stuff" packs all retrieved chunks into a single prompt)
+ llm = HuggingFacePipeline(pipeline=qa_pipeline)
+ retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+ qa_chain = RetrievalQA.from_chain_type(
+     llm=llm,
+     chain_type="stuff",
+     retriever=retriever,
+     return_source_documents=False
+ )
+
+ def preprocess_query(query):
+     # Case-insensitive check so "Script" and "Code" also match
+     if "script" in query.lower() or "code" in query.lower():
+         return f"Write a CPSL script: {query}"
+     return query
+
+ def clean_response(response):
+     result = response.get("result", "")
+     if "Answer:" in result:
+         return result.split("Answer:")[1].strip()
+     return result.strip()
+
+ def chatbot_response(user_input):
+     processed_query = preprocess_query(user_input)
+     raw_response = qa_chain.invoke({"query": processed_query})
+     return clean_response(raw_response)
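+
+ # Note: qa_chain.invoke returns a dict like {"query": ..., "result": ...},
+ # which is why clean_response() reads the "result" key. Quick console check
+ # (the sample question is illustrative only):
+ #     print(chatbot_response("Write a CPSL script that tags dates"))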
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# CPSL Chatbot")
+     chat_history = gr.Chatbot()
+     user_input = gr.Textbox(label="Your Message:")
+     send_button = gr.Button("Send")
+
+     def interact(user_message, history):
+         bot_reply = chatbot_response(user_message)
+         history.append((user_message, bot_reply))
+         return history
+
+     send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history])
+
+ demo.launch()
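
Note on import paths: LangChain 0.1 split these modules out of the core package, so the bare langchain.document_loaders / langchain.embeddings / langchain.vectorstores / langchain.llms imports above emit deprecation warnings and may fail outright on newer releases. A minimal sketch of the equivalent imports, assuming the langchain-community and langchain-text-splitters packages are installed:

    # Drop-in replacements for the deprecated langchain.* import paths
    from langchain_community.document_loaders import TextLoader
    from langchain_text_splitters import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS
    from langchain_community.llms import HuggingFacePipeline
    from langchain.chains import RetrievalQA  # unchanged

The rest of the script can stay the same; only the import lines change.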