Rajesh3338 committed
Commit 1e98784 · verified · 1 Parent(s): 9b6a55e

Update app.py

Files changed (1)
  1. app.py +72 -137
app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import gradio as gr
 import torch
 from langchain_huggingface import HuggingFaceEmbeddings
@@ -9,142 +8,78 @@ from langchain.chains import RetrievalQA
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from langchain_huggingface import HuggingFacePipeline
 
-# Configure GPU settings
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 device = "cuda" if torch.cuda.is_available() else "cpu"
-print(f"Using device: {device}")
-
-class CPSLChatbot:
-    def __init__(self):
-        self.initialize_components()
-
-    def initialize_components(self):
-        try:
-            # Load and process document
-            doc_loader = TextLoader("dataset.txt")
-            docs = doc_loader.load()
-            text_splitter = RecursiveCharacterTextSplitter(
-                chunk_size=1000,
-                chunk_overlap=100
-            )
-            split_docs = text_splitter.split_documents(docs)
-
-            # Initialize embeddings and vector store
-            self.embeddings = HuggingFaceEmbeddings(
-                model_name="all-MiniLM-L6-v2",
-                model_kwargs={'device': device}
-            )
-            self.vectordb = FAISS.from_documents(split_docs, self.embeddings)
-
-            # Load model and tokenizer
-            model_name = "01-ai/Yi-Coder-9B-Chat"
-            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                device_map="auto",
-                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-                trust_remote_code=True
-            )
-
-            # Set up QA pipeline
-            self.qa_pipeline = pipeline(
-                "text-generation",
-                model=self.model,
-                tokenizer=self.tokenizer,
-                max_new_tokens=750,
-                pad_token_id=self.tokenizer.eos_token_id,
-                device=0 if device == "cuda" else -1
-            )
-
-            # Initialize LangChain components
-            llm = HuggingFacePipeline(pipeline=self.qa_pipeline)
-            retriever = self.vectordb.as_retriever(search_kwargs={"k": 5})
-            self.qa_chain = RetrievalQA.from_chain_type(
-                retriever=retriever,
-                chain_type="stuff",
-                llm=llm,
-                return_source_documents=False
-            )
-            print("Initialization completed successfully")
-
-        except Exception as e:
-            print(f"Initialization error: {str(e)}")
-            raise
-
-    def preprocess_query(self, query):
-        if "script" in query.lower() or "code" in query.lower():
-            return f"Write a CPSL script: {query}"
-        return query
-
-    def clean_response(self, response):
-        result = response.get("result", "")
-        if "Answer:" in result:
-            return result.split("Answer:")[1].strip()
-        return result.strip()
-
-    def get_response(self, user_input):
-        try:
-            processed_query = self.preprocess_query(user_input)
-            raw_response = self.qa_chain.invoke({"query": processed_query})
-            return self.clean_response(raw_response)
-        except Exception as e:
-            return f"Error processing query: {str(e)}"
-
-def create_gradio_interface():
-    chatbot = CPSLChatbot()
-
-    with gr.Blocks(title="CPSL Chatbot") as chat_interface:
-        gr.Markdown("# CPSL Chatbot with GPU Support")
-        gr.Markdown("Using Yi-Coder-9B-Chat model for CPSL script generation and queries")
-
-        chat_history = gr.Chatbot(
-            value=[],
-            elem_id="chatbot",
-            height=600
-        )
-
-        with gr.Row():
-            user_input = gr.Textbox(
-                label="Your Message:",
-                placeholder="Type your message here...",
-                show_label=True,
-                elem_id="user-input"
-            )
-            send_button = gr.Button("Send", variant="primary")
-
-        def chat_response(user_message, history):
-            if not user_message:
-                return history, history
-
-            bot_response = chatbot.get_response(user_message)
-            history.append((user_message, bot_response))
-            return history, history
-
-        send_button.click(
-            chat_response,
-            inputs=[user_input, chat_history],
-            outputs=[chat_history, chat_history],
-            api_name="chat"
-        )
-
-        # Clear the input textbox after sending
-        send_button.click(lambda: "", None, user_input)
-
-        # Also allow Enter key to send message
-        user_input.submit(
-            chat_response,
-            inputs=[user_input, chat_history],
-            outputs=[chat_history, chat_history],
-        )
-        user_input.submit(lambda: "", None, user_input)
-
-    return chat_interface
 
+# Load and process the document
+doc_loader = TextLoader("dataset.txt")
+docs = doc_loader.load()
+text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+split_docs = text_splitter.split_documents(docs)
+
+# Create embeddings and vector store
+embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
+vectordb = FAISS.from_documents(split_docs, embeddings)
+
+# Load model and tokenizer
+model_name = "01-ai/Yi-Coder-9B-Chat"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    device_map="auto",
+    torch_dtype=torch.float16 if device == "cuda" else torch.float32
+)
+
+# Set up the QA pipeline
+qa_pipeline = pipeline(
+    "text-generation",
+    model=model,
+    tokenizer=tokenizer,
+    max_new_tokens=750,
+    pad_token_id=tokenizer.eos_token_id
+)
+
+llm = HuggingFacePipeline(pipeline=qa_pipeline)
+
+# Set up the retriever and QA chain
+retriever = vectordb.as_retriever(search_kwargs={"k": 5})
+qa_chain = RetrievalQA.from_chain_type(
+    retriever=retriever,
+    chain_type="stuff",
+    llm=llm,
+    return_source_documents=False
+)
+
+def preprocess_query(query):
+    if "script" in query.lower() or "code" in query.lower():
+        return f"Write a CPSL script: {query}"
+    return query
+
+def clean_response(response):
+    result = response.get("result", "")
+    if "Answer:" in result:
+        return result.split("Answer:")[1].strip()
+    return result.strip()
+
+def chatbot_response(user_input):
+    processed_query = preprocess_query(user_input)
+    raw_response = qa_chain.invoke({"query": processed_query})
+    return clean_response(raw_response)
+
+# Gradio interface
+with gr.Blocks() as chat_interface:
+    gr.Markdown("# CPSL Chatbot")
+    chat_history = gr.Chatbot(type='messages')
+    user_input = gr.Textbox(label="Your Message:")
+    send_button = gr.Button("Send")
+
+    def interact(user_message, history):
+        bot_reply = chatbot_response(user_message)
+        history.append({"role": "user", "content": user_message})
+        history.append({"role": "assistant", "content": bot_reply})
+        return history, history
+
+    send_button.click(interact, inputs=[user_input, chat_history], outputs=[chat_history, chat_history])
+
+# Launch the interface
 if __name__ == "__main__":
-    interface = create_gradio_interface()
-    interface.launch(
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=True,
-        enable_queue=True
-    )
+    chat_interface.launch()
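
Reviewer note: a minimal local smoke test of the refactored module is sketched below. It is not part of the commit; the file name demo.py and the sample prompt are invented for illustration, and it assumes dataset.txt sits next to app.py and that the Yi-Coder-9B-Chat weights can be loaded on the host. Because this change moves all initialization to module level, simply importing app splits the document, builds the FAISS index, and loads the model, so the import itself is the slow step (and why the Space starts slowly).

# demo.py (hypothetical) — exercise the refactored chain without launching Gradio.
# Importing app runs the module-level setup: document split, FAISS index, model load.
from app import chatbot_response

if __name__ == "__main__":
    # Sample prompt invented for illustration; any CPSL question takes the same path
    # through preprocess_query -> qa_chain.invoke -> clean_response.
    print(chatbot_response("Write a CPSL script that prints a greeting"))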