Rajesh3338 committed
Commit 8658913 · verified · 1 Parent(s): 8eada5b

Update app.py

Files changed (1)
  1. app.py +25 -49
app.py CHANGED
@@ -1,6 +1,5 @@
-import os
-import torch
 import gradio as gr
+import torch
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain_community.document_loaders import TextLoader
 from langchain_community.vectorstores import FAISS
@@ -9,55 +8,40 @@ from langchain.chains import RetrievalQA
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain_huggingface import HuggingFacePipeline
 
-# Detect device (GPU or CPU)
 device = "cuda" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if device == "cuda" else torch.float32
 
-# ✅ Step 1: Check if dataset exists
-DATASET_FILE = "dataset.txt"
-if not os.path.exists(DATASET_FILE):
-    raise FileNotFoundError(f"❌ Error: '{DATASET_FILE}' not found. Please create and add some text.")
-
-# ✅ Step 2: Load and split dataset
-print("📄 Loading dataset...")
-doc_loader = TextLoader(DATASET_FILE)
+# Load and process the document
+doc_loader = TextLoader("dataset.txt")
 docs = doc_loader.load()
-
-print("🔀 Splitting documents...")
 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
 split_docs = text_splitter.split_documents(docs)
 
-# ✅ Step 3: Initialize FAISS Vector Store
-print("🧠 Creating embeddings...")
-embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-
-print("📌 Creating FAISS Vector DB...")
+# Create embeddings and vector store
+embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
 vectordb = FAISS.from_documents(split_docs, embeddings)
-retriever = vectordb.as_retriever(search_kwargs={"k": 5})
 
-# ✅ Step 4: Load Hugging Face Model
-MODEL_NAME = "01-ai/Yi-Coder-9B-Chat"  # More memory-efficient than Yi-Coder-9B
-print(f"🚀 Loading Model: {MODEL_NAME} on {device}...")
-
-tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+# Load model and tokenizer
+model_name = "01-ai/Yi-Coder-9B-Chat"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(
-    MODEL_NAME,
-    device_map="auto",
-    torch_dtype=torch_dtype,
+    model_name,
+    device_map="auto",
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32
 )
 
-# ✅ Step 5: Create QA Pipeline
-print("🔧 Setting up pipeline...")
+# Set up the QA pipeline
 qa_pipeline = pipeline(
     "text-generation",
     model=model,
     tokenizer=tokenizer,
     max_new_tokens=750,
-    pad_token_id=tokenizer.eos_token_id,
+    pad_token_id=tokenizer.eos_token_id
 )
 
 llm = HuggingFacePipeline(pipeline=qa_pipeline)
 
+# Set up the retriever and QA chain
+retriever = vectordb.as_retriever(search_kwargs={"k": 5})
 qa_chain = RetrievalQA.from_chain_type(
     retriever=retriever,
     chain_type="stuff",
@@ -65,45 +49,37 @@ qa_chain = RetrievalQA.from_chain_type(
     return_source_documents=False
 )
 
-# ✅ Step 6: Define Helper Functions
 def preprocess_query(query):
-    """Prepares query text based on type of request"""
-    if "script" in query.lower() or "code" in query.lower():
+    if "script" in query or "code" in query.lower():
         return f"Write a CPSL script: {query}"
     return query
 
 def clean_response(response):
-    """Cleans and extracts the response from model output"""
     result = response.get("result", "")
     if "Answer:" in result:
         return result.split("Answer:")[1].strip()
     return result.strip()
 
 def chatbot_response(user_input):
-    """Processes user input and returns AI response"""
     processed_query = preprocess_query(user_input)
     raw_response = qa_chain.invoke({"query": processed_query})
     return clean_response(raw_response)
 
-# ✅ Step 7: Build Gradio UI
+# Gradio interface
 with gr.Blocks() as chat_interface:
-    gr.Markdown("# 🤖 CPSL Chatbot")
-    chat_history = gr.Chatbot(label="Chat History", height=300)
+    gr.Markdown("# CPSL Chatbot")
+    chat_history = gr.Chatbot(type='messages')
     user_input = gr.Textbox(label="Your Message:")
     send_button = gr.Button("Send")
-    clear_button = gr.Button("Clear")
 
     def interact(user_message, history):
-        """Handles user interaction and updates chat history"""
         bot_reply = chatbot_response(user_message)
-        history.append(("🧑‍💻 You:", user_message))
-        history.append(("🤖 Bot:", bot_reply))
-        return history, ""
+        history.append({"role": "user", "content": user_message})
+        history.append({"role": "assistant", "content": bot_reply})
+        return history, history
 
-    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
-    clear_button.click(lambda: [], outputs=[chat_history])
+    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, chat_history])
 
-# ✅ Step 8: Run Gradio App
+# Launch the interface
 if __name__ == "__main__":
-    print("🚀 Launching CPSL Chatbot...")
-    chat_interface.launch(share=True)  # share=True allows public link
+    chat_interface.launch()
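With the os.path.exists guard removed, a missing dataset.txt now surfaces only when TextLoader reads the file. If an explicit check is still wanted, a minimal sketch in the spirit of the removed code (it assumes os would be re-imported, which this commit also drops):

    import os

    if not os.path.exists("dataset.txt"):
        raise FileNotFoundError("dataset.txt not found. Please create and add some text.")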
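A note on the splitter settings this change keeps: with chunk_size=1000 and chunk_overlap=100, consecutive chunks share 100 characters, so the effective stride is 900 and a document of N characters yields roughly ceil((N - 100) / 900) chunks; a 10,000-character dataset.txt would produce about 11 chunks. Actual counts vary, since RecursiveCharacterTextSplitter prefers to break on separators rather than at exact offsets.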
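The chain can also be exercised without the Gradio UI. A minimal sketch using only names defined in app.py; the query string is illustrative:

    # Direct use of the QA chain built above (no web UI).
    query = preprocess_query("Write code that opens a session")  # contains "code", so the CPSL prefix is added
    raw = qa_chain.invoke({"query": query})  # RetrievalQA returns a dict with a "result" key
    print(clean_response(raw))               # keeps only the text after "Answer:", if present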
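The Chatbot component now uses Gradio's messages format (type='messages'), so each turn is a role/content dict rather than a (user, bot) tuple, which is what the updated interact() appends. A minimal sketch of the resulting history; the message texts are made up:

    history = []  # the structure gr.Chatbot(type='messages') renders
    history.append({"role": "user", "content": "How do I write a loop in CPSL?"})
    history.append({"role": "assistant", "content": "Here is a CPSL script: ..."})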