# File size: 2,336 Bytes
# 06696b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# vectordb_relank_law.py
import faiss
import numpy as np
import os
from chromadb import PersistentClient
from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from retriever.reranker import rerank_documents
from constants.embedding_models import embedding_models
# Chroma vector store configuration (v2).
# Everything below runs at import time: the constants are defined, the
# embedding model is loaded, and the persisted collection is opened.

# Location of the persisted Chroma database and the collection to open.
# The relative path is resolved against the current working directory.
# NOTE(review): the original comment called this "law_db config", but the path
# and collection name both say "exam" -- confirm which data set is intended.
CHROMA_PATH = os.path.abspath("data/index/exam_db")
COLLECTION_NAME = "exam_all"
EMBEDDING_MODEL_NAME = embedding_models[1]  # select the embedding model to use

# 1. Load the embedding model (v2).
# NOTE(review): `embedding_model` is not referenced in the visible part of this
# file, and `embedding_fn` below is built from the model name rather than from
# this instance -- confirm whether another module imports it.
# embedding_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 2. Configure the embedding function handed to the Chroma collection.
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=EMBEDDING_MODEL_NAME)

# 3. Open the Chroma client and load the existing collection.
client = PersistentClient(path=CHROMA_PATH)
collection = client.get_collection(name=COLLECTION_NAME, embedding_function=embedding_fn)

# 4. Search function
def search_documents(query: str, top_k: int = 5):
    """Query the Chroma collection and return the hits in reranked order.

    The query is sent to the module-level ``collection``; the retrieved
    document texts are then passed to ``rerank_documents`` and each reranked
    text is paired back with the metadata and distance Chroma returned for it.

    Args:
        query: Free-text search query.
        top_k: Number of hits requested from Chroma; also forwarded to the
            reranker as its ``top_k``.

    Returns:
        A list of ``(document, metadata, distance)`` tuples in the order
        produced by ``rerank_documents``. Empty when Chroma returns no
        documents.

    Raises:
        ValueError: If the reranker yields an item that is not one of the
            retrieved documents.
    """
    # NOTE(review): the non-ASCII part of this literal looks like mis-decoded
    # UTF-8 (probably a magnifier emoji plus Korean for "search query"). It is
    # runtime output, so it is left untouched here -- confirm the file encoding.
    print(f"\nπŸ” 검색어: '{query}'")
    results = collection.query(
        query_texts=[query],
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    # A single query text was sent, so only the first result row is relevant.
    docs = results["documents"][0]
    metadatas = results["metadatas"][0]
    distances = results["distances"][0]

    # Nothing retrieved: skip the reranker instead of feeding it an empty list.
    if not docs:
        return []

    # Assumes rerank_documents returns a reordered subset of the texts in
    # `docs` -- TODO confirm against retriever.reranker.
    reranked_docs = rerank_documents(query, docs, top_k=top_k)

    reranked_data = []
    claimed = set()  # positions in `docs` already paired with a reranked text
    for doc in reranked_docs:
        # Take the first *unclaimed* position holding this text, so identical
        # texts keep their own metadata/distance instead of all resolving to
        # the first duplicate (which is what a bare docs.index(doc) does).
        idx = next(
            (i for i, candidate in enumerate(docs)
             if i not in claimed and candidate == doc),
            None,
        )
        if idx is None:
            # Either every occurrence is already claimed (reuse the first one)
            # or the text is unknown, in which case list.index raises
            # ValueError.
            idx = docs.index(doc)
        claimed.add(idx)
        reranked_data.append((doc, metadatas[idx], distances[idx]))

    return reranked_data