|
import asyncio
import functools
import logging
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import quote

import vecs
|
|
|
class DatabaseService:
    """Vector search over the ``articles`` collection stored through ``vecs``.

    Connection settings are read from environment variables when the service
    is constructed; see ``__init__``.
    """

    # Logger named after the defining module; used for best-effort error
    # reporting in ``semantic_search``.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Connect to Postgres via ``vecs`` and open the ``articles`` collection.

        Reads ``SUPABASE_HOST``, ``DB_PORT``, ``DB_NAME``, ``DB_USER`` and
        ``DB_PASSWORD`` from the environment.

        Raises:
            RuntimeError: if ``DB_PASSWORD`` is unset or empty.
        """
        self.DB_HOST = os.getenv("SUPABASE_HOST", "db.daxquaudqidyeirypexa.supabase.co")
        self.DB_PORT = os.getenv("DB_PORT", "5432")
        self.DB_NAME = os.getenv("DB_NAME", "postgres")
        self.DB_USER = os.getenv("DB_USER", "postgres")

        # Deliberately no default. A Supabase "anon" API JWT was previously
        # hard-coded here: it is not a Postgres password, and credentials must
        # not live in source control.
        self.DB_PASSWORD = os.getenv("DB_PASSWORD")
        if not self.DB_PASSWORD:
            raise RuntimeError("DB_PASSWORD environment variable is not set")

        # Percent-encode the credentials so characters such as '@', ':' or '/'
        # cannot corrupt the connection URL.
        user = quote(self.DB_USER, safe="")
        password = quote(self.DB_PASSWORD, safe="")
        self.vx = vecs.create_client(
            f"postgresql://{user}:{password}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
        )

        # NOTE(review): 384 presumably matches the embedding model's output
        # size; query vectors must have this length — confirm.
        self.articles = self.vx.get_or_create_collection(
            name="articles",
            dimension=384,
        )

    async def semantic_search(
        self,
        query_embedding: List[float],
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        topic: Optional[str] = None,
        entities: Optional[List[str]] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Return the articles nearest to ``query_embedding`` by cosine distance.

        Args:
            query_embedding: Query vector.
            start_date: If given, only articles whose ``date`` metadata is
                ``>= start_date.isoformat()``.
            end_date: If given, only articles whose ``date`` metadata is
                ``<= end_date.isoformat()``.
            topic: If given, only articles whose ``topic`` metadata equals it.
            entities: If non-empty, restricts on ``entities`` metadata via ``$in``.
            limit: Maximum number of results.

        Returns:
            Up to ``limit`` dicts with keys ``id``, ``url``, ``content``,
            ``date``, ``topic``, ``distance`` and ``similarity``
            (``1 - distance``). Returns an empty list if anything fails; the
            error is logged with its traceback.
        """
        try:
            filters = self._build_filters(start_date, end_date, topic, entities)
            # Request distance and metadata in the same query. Without these
            # flags vecs returns bare ids, and fetching metadata per hit would
            # cost one extra round-trip per result.
            run_query = functools.partial(
                self.articles.query,
                data=query_embedding,
                limit=limit,
                filters=filters or None,
                measure="cosine_distance",
                include_value=True,
                include_metadata=True,
            )
            # vecs is synchronous; run it in a worker thread so the event loop
            # is not blocked during the database round-trip.
            loop = asyncio.get_running_loop()
            rows = await loop.run_in_executor(None, run_query)

            formatted_results = []
            for article_id, distance, metadata in rows:
                metadata = metadata or {}
                formatted_results.append({
                    "id": article_id,
                    "url": metadata.get("url"),
                    "content": metadata.get("content"),
                    "date": metadata.get("date"),
                    "topic": metadata.get("topic"),
                    "distance": float(distance),
                    "similarity": 1 - float(distance),
                })
            return formatted_results

        except Exception:
            # Best-effort contract preserved: callers get [] on failure.
            self._logger.exception("Vector search error")
            return []

    def _build_filters(
        self,
        start_date: Optional[datetime],
        end_date: Optional[datetime],
        topic: Optional[str],
        entities: Optional[List[str]] = None,
    ) -> Dict[str, Any]:
        """Build a vecs metadata filter from the optional search constraints.

        vecs filters allow one key per dict and one operator per field, so
        several conditions are combined under ``$and``.

        Returns:
            ``{}`` when no constraint is given, the bare condition when there is
            exactly one, otherwise ``{"$and": [condition, ...]}``.
        """
        conditions: List[Dict[str, Any]] = []

        # NOTE(review): assumes metadata "date" is stored as an isoformat()
        # string of the same layout, so string ordering matches time order.
        if start_date is not None:
            conditions.append({"date": {"$gte": start_date.isoformat()}})
        if end_date is not None:
            conditions.append({"date": {"$lte": end_date.isoformat()}})

        if topic:
            conditions.append({"topic": {"$eq": topic}})

        if entities:
            # NOTE(review): "$in" matches when metadata["entities"] is a scalar
            # equal to one of these values; if it is stored as an array, a
            # "$contains"-style filter is needed instead — confirm the schema.
            conditions.append({"entities": {"$in": list(entities)}})

        if not conditions:
            return {}
        if len(conditions) == 1:
            return conditions[0]
        return {"$and": conditions}

    async def close(self):
        """Disconnect the underlying vecs client (a synchronous call)."""
        self.vx.disconnect()