Refactor embedding model integration and update API documentation for search response format
[all_embedded] The Alchemy of Happiness (Ghazzālī, Claud Field) (Z-Library).parquet → [embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet
RENAMED

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:ca01a279b52f21c7e7d8441f8145f20201a255d8c015f3059920b1b957726a61
+size 4232361
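The parquet file is tracked with Git LFS, so the diff above shows only the updated pointer (object hash and size), not the embeddings themselves. Below is a minimal sketch of inspecting the renamed file; the column layout is an assumption inferred from the main.py changes (a text column plus one precomputed embedding column per model), not something the pointer diff confirms.

```python
# Hypothetical inspection of the renamed parquet file. Column names are
# assumptions inferred from main.py, not confirmed by this LFS pointer diff.
from pathlib import Path

import pandas as pd

database_file = Path(__file__).parent / "[embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet"
df = pd.read_parquet(database_file)

print(df.columns.tolist())  # expected: a text column plus the two embedding columns
print(len(df))              # number of embedded passages
print(len(df["WhereIsAI_UAE_Large_V1"].iloc[0]))  # embedding dimension; 1024 for both models, if stored as lists
```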
main.py
CHANGED

@@ -51,7 +51,7 @@ class QueryInput(BaseModel):
 class SearchResult(BaseModel):
     text: str
     similarity: float
-    model_type:
+    model_type: Literal["WhereIsAI_UAE_Large_V1", "BAAI_bge_large_en_v1.5"]
 
 class TokenResponse(BaseModel):
     access_token: str
@@ -73,10 +73,13 @@ class RefreshRequest(BaseModel):
     refresh_token: str
 
 # Cache management
-@lru_cache(maxsize=
-def
-    """Load and cache
-    return
+@lru_cache(maxsize=2)  # Cache both models
+def get_embedding_models():
+    """Load and cache both embedding models"""
+    return {
+        "uae-large": SentenceTransformer("WhereIsAI/UAE-Large-V1", device="cpu"),
+        "bge-large": SentenceTransformer("BAAI/bge-large-en-v1.5", device="cpu")
+    }
 
 def get_cached_embeddings(text: str, model_type: str) -> Optional[List[float]]:
     """Try to get embeddings from cache"""
@@ -91,7 +94,7 @@ def set_cached_embeddings(text: str, model_type: str, embeddings: List[float]):
 @lru_cache(maxsize=1)
 def load_dataframe():
     """Load and cache the parquet dataframe"""
-    database_file = Path(__file__).parent / "[
+    database_file = Path(__file__).parent / "[embed] The Alchemy of Happiness (Ghazzālī, Claud Field).parquet"
     return pd.read_parquet(database_file)
 
 # Utility functions
@@ -102,61 +105,53 @@ def cosine_similarity(embedding_0, embedding_1):
     return dot_product / (norm_0 * norm_1)
 
 def generate_embedding(model, text: str, model_type: str) -> List[float]:
-    # Try to get from cache first
     cached_embedding = get_cached_embeddings(text, model_type)
     if cached_embedding is not None:
         return cached_embedding
 
-    # Generate new embedding
-
-
-
-
-
-
-
-    response = model.embeddings.create(
-        input=text,
-        model="text-embedding-3-small"
-    )
-    embedding = response.data[0].embedding
-
-    # Cache the new embedding
+    # Generate new embedding
+    embedding = model.encode(
+        text,
+        convert_to_tensor=True,
+        normalize_embeddings=True  # Important for UAE and BGE models
+    )
+    embedding = np.array(t.Tensor.cpu(embedding)).tolist()
+
     set_cached_embeddings(text, model_type, embedding)
     return embedding
 
-def search_query(
-    # Generate embeddings
-
-
+def search_query(st_models, query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
+    # Generate embeddings with both models
+    uae_embedding = generate_embedding(st_models["uae-large"], query, "uae-large")
+    bge_embedding = generate_embedding(st_models["bge-large"], query, "bge-large")
 
     # Calculate similarities
-    df['
-        lambda x: cosine_similarity(x,
+    df['uae_similarities'] = df["WhereIsAI_UAE_Large_V1"].apply(
+        lambda x: cosine_similarity(x, uae_embedding)
     )
-    df['
-        lambda x: cosine_similarity(x,
+    df['bge_similarities'] = df["BAAI_bge_large_en_v1.5"].apply(
+        lambda x: cosine_similarity(x, bge_embedding)
     )
 
     # Get top results for each model
-
-
+    uae_results = df.nlargest(n, 'uae_similarities')
+    bge_results = df.nlargest(n, 'bge_similarities')
 
     # Format results
     results = []
 
-    for _, row in
+    for _, row in uae_results.iterrows():
         results.append({
             "text": row["ext"],
-            "similarity": float(row["
-            "model_type": "
+            "similarity": float(row["uae_similarities"]),
+            "model_type": "WhereIsAI_UAE_Large_V1"
        })
 
-    for _, row in
+    for _, row in bge_results.iterrows():
         results.append({
             "text": row["ext"],
-            "similarity": float(row["
-            "model_type": "
+            "similarity": float(row["bge_similarities"]),
+            "model_type": "BAAI_bge_large_en_v1.5"
        })
 
     return results
@@ -309,17 +304,14 @@ def logout(
 
 @app.post("/search", response_model=List[SearchResult])
 async def search(
-
-
-):
+    query_input: QueryInput,
+    username: str = Depends(verify_access_token),
+):
     try:
-
-        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
-        st_model = get_sentence_transformer()
+        st_models = get_embedding_models()
         df = load_dataframe()
 
-
-        results = search_query(client, st_model, query_input.query, df, n=1)
+        results = search_query(st_models, query_input.query, df, n=1)
         return [SearchResult(**result) for result in results]
 
     except Exception as e:
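The main.py changes swap the OpenAI `text-embedding-3-small` call for two locally cached SentenceTransformer models and rank passages by cosine similarity against the precomputed embedding columns. Below is a self-contained sketch of that flow; the model names and column names match the diff, while the toy corpus and helper names are illustrative only, not part of the commit.

```python
# Sketch of the dual-model search flow this commit introduces. Model names
# and dataframe columns match the diff; toy data and helpers are illustrative.
from typing import Dict, List

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer

COLUMNS = {
    "uae-large": ("WhereIsAI/UAE-Large-V1", "WhereIsAI_UAE_Large_V1"),
    "bge-large": ("BAAI/bge-large-en-v1.5", "BAAI_bge_large_en_v1.5"),
}

models = {key: SentenceTransformer(name, device="cpu") for key, (name, _) in COLUMNS.items()}

# Toy corpus standing in for the parquet file.
passages = [
    "Knowledge of self is the key to knowledge of God.",
    "The heart is like a mirror that must be polished.",
]
df = pd.DataFrame({"text": passages})
for key, (_, column) in COLUMNS.items():
    df[column] = list(models[key].encode(passages, normalize_embeddings=True))

def search_query(query: str, df: pd.DataFrame, n: int = 1) -> List[Dict]:
    results = []
    for key, (_, column) in COLUMNS.items():
        # With normalize_embeddings=True, cosine similarity reduces to a dot product.
        q = models[key].encode(query, normalize_embeddings=True)
        sims = df[column].apply(lambda emb: float(np.dot(emb, q)))
        for idx in sims.nlargest(n).index:
            results.append({
                "text": df.loc[idx, "text"],
                "similarity": float(sims[idx]),
                "model_type": column,
            })
    return results

print(search_query("How does one come to know God?"))
```

Because both models emit unit-normalized vectors here, the dot product and the full cosine formula agree; the `cosine_similarity` helper in main.py computes the general form, which also covers unnormalized stored embeddings.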
static/index.html
CHANGED

@@ -138,18 +138,18 @@ class LoginResponse {
     <p>This endpoint is used to send a search query and retrieve results. It requires a valid access token.</p>
 
     <h4>Response:</h4>
-
-
-
-
-
-
-
-
-
-
-
-
+    <pre><code class="language-json">[
+    {
+        "text": "Result 1 text",
+        "similarity": 0.95,
+        "model_type": "UAE-Large-V1"
+    },
+    {
+        "text": "Result 2 text",
+        "similarity": 0.92,
+        "model_type": "BGE-Large-V1.5"
+    }
+]</code></pre>
     </div>
 
     <div class="endpoint" id="save">
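Taken together with the main.py changes, a client call against the documented endpoint would look roughly like the sketch below; the base URL, bearer scheme, and token value are placeholders not shown in the diff. Note also that the HTML example uses shortened model_type labels ("UAE-Large-V1", "BGE-Large-V1.5"), while the SearchResult model in main.py constrains the field to the full column names.

```python
# Hypothetical client for the documented /search endpoint. The base URL,
# token value, and Bearer scheme are assumptions, not shown in the diff.
import requests

BASE_URL = "http://localhost:8000"
access_token = "<token from the auth endpoints>"

response = requests.post(
    f"{BASE_URL}/search",
    json={"query": "What is the alchemy of happiness?"},  # QueryInput.query, per main.py
    headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
for result in response.json():
    print(result["model_type"], round(result["similarity"], 3), result["text"][:60])
```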