eli02 commited on
Commit
3d69062
·
1 Parent(s): 2abc9f5

Refactor embedding model integration and update API documentation for search response format

Browse files
[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet → [embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ced650f23166f55939fb6dfec6df2fd7d83995a9db362a1a7460d36e6f3ab510
3
- size 3118786
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca01a279b52f21c7e7d8441f8145f20201a255d8c015f3059920b1b957726a61
3
+ size 4232361
main.py CHANGED
@@ -51,7 +51,7 @@ class QueryInput(BaseModel):
51
  class SearchResult(BaseModel):
52
  text: str
53
  similarity: float
54
- model_type: str
55
 
56
  class TokenResponse(BaseModel):
57
  access_token: str
@@ -73,10 +73,13 @@ class RefreshRequest(BaseModel):
73
  refresh_token: str
74
 
75
  # Cache management
76
- @lru_cache(maxsize=1)
77
- def get_sentence_transformer():
78
- """Load and cache the SentenceTransformer model with lru_cache"""
79
- return SentenceTransformer(model_name_or_path="all-mpnet-base-v2", device="cpu")
 
 
 
80
 
81
  def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
82
  """Try to get embeddings from cache"""
@@ -91,7 +94,7 @@ def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
91
  @lru_cache(maxsize=1)
92
  def load_dataframe():
93
  """Load and cache the parquet dataframe"""
94
- database_file = Path(__file__).parent / "[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet"
95
  return pd.read_parquet(database_file)
96
 
97
  # Utility functions
@@ -102,61 +105,53 @@ def cosine_similarity(embedding_0, embedding_1):
102
  return dot_product / (norm_0 * norm_1)
103
 
104
  def generate_embedding(model, text: str, model_type: str) -> List[float]:
105
- # Try to get from cache first
106
  cached_embedding = get_cached_embeddings(text, model_type)
107
  if cached_embedding is not None:
108
  return cached_embedding
109
 
110
- # Generate new embedding if not in cache
111
- if model_type == "all-mpnet-base-v2":
112
- chunk_embedding = model.encode(
113
- text,
114
- convert_to_tensor=True
115
- )
116
- embedding = np.array(t.Tensor.cpu(chunk_embedding)).tolist()
117
- elif model_type == "text-embedding-3-small":
118
- response = model.embeddings.create(
119
- input=text,
120
- model="text-embedding-3-small"
121
- )
122
- embedding = response.data[0].embedding
123
-
124
- # Cache the new embedding
125
  set_cached_embeddings(text, model_type, embedding)
126
  return embedding
127
 
128
- def search_query(client, st_model, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
129
- # Generate embeddings for both models
130
- mpnet_embedding = generate_embedding(st_model, query, "all-mpnet-base-v2")
131
- openai_embedding = generate_embedding(client, query, "text-embedding-3-small")
132
 
133
  # Calculate similarities
134
- df['mpnet_similarities'] = df.all_mpnet_embedding.apply(
135
- lambda x: cosine_similarity(x, mpnet_embedding)
136
  )
137
- df['openai_similarities'] = df.openai_embedding.apply(
138
- lambda x: cosine_similarity(x, openai_embedding)
139
  )
140
 
141
  # Get top results for each model
142
- mpnet_results = df.nlargest(n, 'mpnet_similarities')
143
- openai_results = df.nlargest(n, 'openai_similarities')
144
 
145
  # Format results
146
  results = []
147
 
148
- for _, row in mpnet_results.iterrows():
149
  results.append({
150
  "text": row["ext"],
151
- "similarity": float(row["mpnet_similarities"]),
152
- "model_type": "all-mpnet-base-v2"
153
  })
154
 
155
- for _, row in openai_results.iterrows():
156
  results.append({
157
  "text": row["ext"],
158
- "similarity": float(row["openai_similarities"]),
159
- "model_type": "text-embedding-3-small"
160
  })
161
 
162
  return results
@@ -309,17 +304,14 @@ def logout(
309
 
310
  @app.post("/search", response_model=List[SearchResult])
311
  async def search(
312
- query_input: QueryInput,
313
- username: str = Depends(verify_access_token),
314
- ):
315
  try:
316
- # Initialize clients using cached functions
317
- client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
318
- st_model = get_sentence_transformer()
319
  df = load_dataframe()
320
 
321
- # Perform search with both models
322
- results = search_query(client, st_model, query_input.query, df, n=1)
323
  return [SearchResult(**result) for result in results]
324
 
325
  except Exception as e:
 
51
class SearchResult(BaseModel):
    """One search hit returned by the /search endpoint.

    Attributes:
        text: the matched text chunk from the source document.
        similarity: cosine similarity between the query and the chunk.
        model_type: which embedding model produced this hit.
    """
    text: str
    similarity: float
    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]
55
 
56
  class TokenResponse(BaseModel):
57
  access_token: str
 
73
  refresh_token: str
74
 
75
  # Cache management
76
@lru_cache(maxsize=1)
def get_embedding_models():
    """Load and memoize the two sentence-transformer embedding models.

    This function takes no arguments, so ``lru_cache`` can only ever hold a
    single entry — the previous ``maxsize=2  # Cache both models`` implied
    one slot per model, but the cache keys on call arguments, not on the
    dict's contents. ``maxsize=1`` states the real behavior.

    Returns:
        dict: {"uae-large": SentenceTransformer, "bge-large": SentenceTransformer},
        both loaded on CPU.
    """
    return {
        "uae-large": SentenceTransformer("WhereIsAI/UAE-Large-V1", device="cpu"),
        "bge-large": SentenceTransformer("BAAI/bge-large-en-v1.5", device="cpu"),
    }
83
 
84
  def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
85
  """Try to get embeddings from cache"""
 
94
@lru_cache(maxsize=1)
def load_dataframe():
    """Read the embeddings parquet file once and memoize the result."""
    parquet_path = Path(__file__).parent / "[embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet"
    return pd.read_parquet(parquet_path)
99
 
100
  # Utility functions
 
105
  return dot_product / (norm_0 * norm_1)
106
 
107
def generate_embedding(model, text: str, model_type: str) -> List[float]:
    """Embed ``text`` with ``model``, consulting the embedding cache first.

    Args:
        model: a SentenceTransformer instance.
        text: the string to embed.
        model_type: cache key namespace identifying the model.

    Returns:
        The embedding as a plain list of floats.
    """
    hit = get_cached_embeddings(text, model_type)
    if hit is not None:
        return hit

    # Cache miss: encode on the model. Normalized embeddings matter for the
    # UAE / BGE model family used here.
    tensor = model.encode(
        text,
        convert_to_tensor=True,
        normalize_embeddings=True,
    )
    # Equivalent to the original np.array(t.Tensor.cpu(tensor)).tolist():
    # move to CPU, then convert to a nested Python list.
    embedding = tensor.cpu().numpy().tolist()

    set_cached_embeddings(text, model_type, embedding)
    return embedding
122
 
123
def search_query(st_models, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
    """Search ``df`` with both embedding models and merge the top-n hits.

    Args:
        st_models: dict with "uae-large" and "bge-large" SentenceTransformers
            (as returned by get_embedding_models()).
        query: the user's query string.
        df: dataframe holding text chunks plus one embedding column per model.
        n: number of hits to return per model.

    Returns:
        A list of dicts matching the SearchResult schema — UAE hits first,
        then BGE hits, each sorted by descending similarity.
    """
    uae_embedding = generate_embedding(st_models["uae-large"], query, "uae-large")
    bge_embedding = generate_embedding(st_models["bge-large"], query, "bge-large")

    # Compute similarities into local Series rather than assigning new
    # columns on df: df comes from an lru_cache'd loader, so writing
    # columns would mutate shared state across concurrent requests.
    uae_sims = df["WhereIsAI_UAE_Large_V1"].apply(
        lambda emb: cosine_similarity(emb, uae_embedding)
    )
    bge_sims = df["BAAI_bge_large_en_v1.5"].apply(
        lambda emb: cosine_similarity(emb, bge_embedding)
    )

    results = []
    for sims, model_type in (
        (uae_sims, "WhereIsAI_UAE_Large_V1"),
        (bge_sims, "BAAI_bge_large_en_v1.5"),
    ):
        for idx in sims.nlargest(n).index:
            results.append({
                # NOTE(review): the original read row["ext"], but the result
                # key, SearchResult.text, and the API docs all use "text" —
                # "ext" looked like a typo. Confirm the parquet column name.
                "text": df.loc[idx, "text"],
                "similarity": float(sims.loc[idx]),
                "model_type": model_type,
            })

    return results
 
304
 
305
  @app.post("/search", response_model=List[SearchResult])
306
  async def search(
307
+ query_input: QueryInput,
308
+ username: str = Depends(verify_access_token),
309
+ ):
310
  try:
311
+ st_models = get_embedding_models()
 
 
312
  df = load_dataframe()
313
 
314
+ results = search_query(st_models, query_input.query, df, n=1)
 
315
  return [SearchResult(**result) for result in results]
316
 
317
  except Exception as e:
static/index.html CHANGED
@@ -138,18 +138,18 @@ class LoginResponse {
138
  <p>This endpoint is used to send a search query and retrieve results. It requires a valid access token.</p>
139
 
140
  <h4>Response:</h4>
141
- <pre><code class="language-json">[
142
- {
143
- "text": "Result 1 text",
144
- "similarity": 0.95,
145
- "model_type": "all-mpnet-base-v2"
146
- },
147
- {
148
- "text": "Result 2 text",
149
- "similarity": 0.92,
150
- "model_type": "text-embedding-3-small"
151
- }
152
- ]</code></pre>
153
  </div>
154
 
155
  <div class="endpoint" id="save">
 
138
  <p>This endpoint is used to send a search query and retrieve results. It requires a valid access token.</p>
139
 
140
  <h4>Response:</h4>
141
+ <pre><code class="language-json">[
142
+ {
143
+ "text": "Result 1 text",
144
+ "similarity": 0.95,
145
+ "model_type": "WhereIsAI_UAE_Large_V1"
146
+ },
147
+ {
148
+ "text": "Result 2 text",
149
+ "similarity": 0.92,
150
+ "model_type": "BAAI_bge_large_en_v1.5"
151
+ }
152
+ ]</code></pre>
153
  </div>
154
 
155
  <div class="endpoint" id="save">