|
import asyncio
import logging
import os
from contextlib import closing
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import psycopg2
from psycopg2 import sql
|
|
|
class DatabaseService:
    """Semantic (vector) search over ``articles.articles`` in Postgres/pgvector.

    Connection settings are read from environment variables when the service
    is constructed. Every search opens its own short-lived connection and
    closes it before returning, so the service keeps no connection state
    between calls.
    """

    # Logger used instead of ad-hoc print() calls.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        """Read connection settings from the environment.

        ``SUPABASE_HOST``, ``DB_PORT`` and ``DB_NAME`` fall back to the
        defaults below; ``DB_USER`` and ``DB_PASSWORD`` default to ``None``.
        """
        self.DB_HOST = os.getenv("SUPABASE_HOST", "aws-0-eu-west-3.pooler.supabase.com")
        self.DB_PORT = os.getenv("DB_PORT", "6543")
        self.DB_NAME = os.getenv("DB_NAME", "postgres")
        self.DB_USER = os.getenv("DB_USER")
        self.DB_PASSWORD = os.getenv("DB_PASSWORD")

    async def semantic_search(
        self,
        query_embedding: List[float],
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        topic: Optional[str] = None,
        entities: Optional[List[Tuple[str, str]]] = None,
        limit: int = 10,
    ) -> List[Dict[str, Any]]:
        """Return the articles nearest to *query_embedding*, optionally filtered.

        Articles are ordered by the pgvector ``<=>`` operator between
        ``articles.articles.embedding`` and *query_embedding* (smallest
        ``distance`` first).

        Args:
            query_embedding: Sequence of numbers; each is converted to float.
            start_date: Inclusive lower bound on ``date``; may be used alone.
            end_date: Inclusive upper bound on ``date``; may be used alone.
            topic: Exact match on ``topic``; ignored when empty/None.
            entities: ``(word, entity_group)`` pairs. An article qualifies if
                any pair matches a row in ``articles.ner`` (``word`` compared
                case- and accent-insensitively via LOWER/UNACCENT).
            limit: Maximum number of rows returned.

        Returns:
            Up to *limit* dicts with keys ``content``, ``distance``, ``date``,
            ``topic`` and ``url``. If the filtered query matches nothing, an
            unfiltered nearest-neighbour query is used instead, so results may
            ignore the filters. Any error is logged and ``[]`` is returned.
        """
        self._logger.debug("semantic_search entities: %s", entities)
        loop = asyncio.get_running_loop()
        try:
            # psycopg2 is blocking; run it off the event loop so other
            # coroutines keep running while the query is in flight.
            return await loop.run_in_executor(
                None,
                self._semantic_search_sync,
                query_embedding,
                start_date,
                end_date,
                topic,
                entities,
                limit,
            )
        except Exception:
            # Deliberate best-effort contract: callers get [] on any failure.
            self._logger.exception("Database query error")
            return []

    def _semantic_search_sync(self, query_embedding, start_date, end_date, topic, entities, limit):
        """Blocking implementation of ``semantic_search``; raises on failure."""
        # Plain list of floats -> psycopg2 renders ARRAY[...]; tuples or numpy
        # scalars would otherwise adapt incorrectly or not at all.
        embedding = sql.Literal([float(x) for x in query_embedding])
        limit_literal = sql.Literal(limit)
        query = self._build_filtered_query(
            embedding, limit_literal, start_date, end_date, topic, entities
        )

        # ``with connection`` in psycopg2 only ends the transaction; it does
        # not close the connection. ``closing`` guarantees conn.close().
        with closing(
            psycopg2.connect(
                user=self.DB_USER,
                password=self.DB_PASSWORD,
                host=self.DB_HOST,
                port=self.DB_PORT,
                dbname=self.DB_NAME,
            )
        ) as conn, conn.cursor() as cursor:
            cursor.execute(query)
            rows = cursor.fetchall()

            if not rows:
                self._logger.info(
                    "No articles matched the filters; falling back to unfiltered search"
                )
                cursor.execute(self._build_fallback_query(embedding, limit_literal))
                rows = cursor.fetchall()

        return self._format_rows(rows)

    @staticmethod
    def _build_filtered_query(embedding, limit, start_date, end_date, topic, entities):
        """Compose the filtered nearest-neighbour query (all values as literals)."""
        filters = []
        if start_date is not None and end_date is not None:
            filters.append(
                sql.SQL("date BETWEEN {} AND {}").format(
                    sql.Literal(start_date), sql.Literal(end_date)
                )
            )
        elif start_date is not None:
            filters.append(sql.SQL("date >= {}").format(sql.Literal(start_date)))
        elif end_date is not None:
            filters.append(sql.SQL("date <= {}").format(sql.Literal(end_date)))

        if topic:
            filters.append(sql.SQL("topic = {}").format(sql.Literal(topic)))

        where = sql.SQL(" AND ").join(filters) if filters else sql.SQL("TRUE")
        ctes = [
            sql.SQL(
                "filtered_articles AS ("
                "SELECT article_id FROM articles.articles WHERE {})"
            ).format(where)
        ]
        source = "filtered_articles"

        if entities:
            entity_conditions = sql.SQL(" OR ").join(
                sql.SQL(
                    "(LOWER(UNACCENT(word)) = LOWER(UNACCENT({})) AND entity_group = {})"
                ).format(sql.Literal(e[0]), sql.Literal(e[1]))
                for e in entities
            )
            ctes.append(
                sql.SQL(
                    "target_articles AS ("
                    "SELECT DISTINCT article_id FROM articles.ner "
                    "WHERE ({}) "
                    "AND article_id IN (SELECT article_id FROM filtered_articles))"
                ).format(entity_conditions)
            )
            source = "target_articles"

        return sql.SQL(
            "WITH {ctes} "
            "SELECT a.content, a.embedding <=> {embedding}::vector AS distance, "
            "a.date, a.topic, a.url "
            "FROM articles.articles a "
            "JOIN {source} s ON a.article_id = s.article_id "
            "ORDER BY distance "
            "LIMIT {limit}"
        ).format(
            ctes=sql.SQL(", ").join(ctes),
            embedding=embedding,
            source=sql.Identifier(source),
            limit=limit,
        )

    @staticmethod
    def _build_fallback_query(embedding, limit):
        """Compose the unfiltered nearest-neighbour query over all articles."""
        return sql.SQL(
            "SELECT content, embedding <=> {}::vector AS distance, date, topic, url "
            "FROM articles.articles "
            "ORDER BY distance "
            "LIMIT {}"
        ).format(embedding, limit)

    @staticmethod
    def _format_rows(rows):
        """Turn (content, distance, date, topic, url) rows into result dicts."""
        return [
            {
                "content": content,
                "distance": distance,
                "date": art_date,
                "topic": art_topic,
                "url": url,
            }
            for content, distance, art_date, art_topic, url in rows
        ]

    async def close(self):
        """No-op kept for API symmetry: searches close their own connections."""