tosin2013 committed on
Commit
b0b776a
·
1 Parent(s): eb872b8

import spaces

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -10,6 +10,7 @@ from typing import List
10
  from langchain_core.output_parsers import StrOutputParser
11
  from langchain_core.prompts import ChatPromptTemplate
12
  import gradio as gr
 
13
  from huggingface_hub import InferenceClient
14
 
15
  # Configuration
@@ -48,8 +49,16 @@ else:
48
  dataset = load_dataset('tosin2013/autogen', streaming=True)
49
  dataset = Dataset.from_list(list(dataset['train']))
50
 
51
- # Initialize embeddings with GPU support if available
52
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
 
53
  embeddings = HuggingFaceEmbeddings(
54
  model_name="sentence-transformers/all-MiniLM-L6-v2",
55
  model_kwargs={"device": device}
@@ -65,6 +74,7 @@ text_embeddings = embeddings.embed_documents(texts)
65
  nn = NearestNeighbors(n_neighbors=5, metric='cosine')
66
  nn.fit(np.array(text_embeddings))
67
 
 
68
  def get_relevant_documents(query, k=5):
69
  """
70
  Retrieves the k most relevant documents to the query.
@@ -74,6 +84,7 @@ def get_relevant_documents(query, k=5):
74
  relevant_docs = [texts[i] for i in indices[0]]
75
  return relevant_docs
76
 
 
77
  def generate_response(question, history):
78
  try:
79
  print(f"\n[LOG] Received question: {question}")
@@ -199,9 +210,10 @@ with gr.Blocks() as demo:
199
 
200
  # Event handlers
201
  submit_btn.click(
202
- generate_response,
203
  inputs=[question, chatbot],
204
- outputs=[chatbot]
 
205
  )
206
 
207
  clear_btn.click(
 
10
  from langchain_core.output_parsers import StrOutputParser
11
  from langchain_core.prompts import ChatPromptTemplate
12
  import gradio as gr
13
+ import spaces
14
  from huggingface_hub import InferenceClient
15
 
16
  # Configuration
 
49
  dataset = load_dataset('tosin2013/autogen', streaming=True)
50
  dataset = Dataset.from_list(list(dataset['train']))
51
 
52
+ # Check GPU availability and initialize embeddings
53
+ if torch.cuda.is_available():
54
+ print("[INFO] GPU is available, using CUDA")
55
+ device = "cuda"
56
+ # Initialize CUDA device
57
+ torch.cuda.init()
58
+ else:
59
+ print("[INFO] GPU not available, using CPU")
60
+ device = "cpu"
61
+
62
  embeddings = HuggingFaceEmbeddings(
63
  model_name="sentence-transformers/all-MiniLM-L6-v2",
64
  model_kwargs={"device": device}
 
74
  nn = NearestNeighbors(n_neighbors=5, metric='cosine')
75
  nn.fit(np.array(text_embeddings))
76
 
77
+ @spaces.GPU(duration=30)
78
  def get_relevant_documents(query, k=5):
79
  """
80
  Retrieves the k most relevant documents to the query.
 
84
  relevant_docs = [texts[i] for i in indices[0]]
85
  return relevant_docs
86
 
87
+ @spaces.GPU(duration=120)
88
  def generate_response(question, history):
89
  try:
90
  print(f"\n[LOG] Received question: {question}")
 
210
 
211
  # Event handlers
212
  submit_btn.click(
213
+ fn=generate_response,
214
  inputs=[question, chatbot],
215
+ outputs=[chatbot],
216
+ queue=True
217
  )
218
 
219
  clear_btn.click(