# Toursim-Test / src/retrieval/graph_rag.py
# (Hugging Face upload residue preserved as comments:
#  uploader zhuhai111, commit "Upload 43 files", 7cc8bc0 verified)
from .base import BaseRetriever
from typing import Dict, List
import networkx as nx
# import spacy # 暂时注释掉
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import logging
class GraphRAG(BaseRetriever):
    """Graph-assisted keyword retriever.

    Builds a word co-occurrence graph over passages and scores documents by
    keyword overlap with the query. The spaCy NLP pipeline is intentionally
    disabled (see the commented-out import at the top of the file), so only
    the simplified whitespace-token logic below is active.
    """

    def __init__(self, config: Dict):
        """Store the config, create an empty undirected graph, and load paths.

        Args:
            config: full retrieval configuration dict (see init_retriever for
                the keys actually read).
        """
        self.config = config
        self.graph = nx.Graph()
        # self.nlp = spacy.load("zh_core_web_sm")  # temporarily disabled
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        """Resolve the working directory and graph file path from config.

        NOTE(review): hard-codes `methods[2]` as the graph-RAG entry in the
        config — confirm the config schema guarantees this position.
        """
        self.working_dir = config['retrieval_settings']['methods'][2]['model_settings']['working_dir']
        self.graph_file = f"{self.working_dir}/graph.graphml"

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Score and rank documents by keyword overlap with the query.

        Simple implementation: each whitespace token of `query` that occurs
        as a substring of doc['passage'] contributes 1 to that document's
        score (duplicate query tokens are counted per occurrence).

        Args:
            query: the user query string.
            context: list of document dicts; each must contain a 'passage' key.

        Returns:
            Copies of the input dicts with an added float 'graph_score' key,
            sorted by that score in descending order.
        """
        scored_docs = []
        for doc in context:
            # Substring membership test per query token; copies the doc so
            # caller-owned dicts are never mutated.
            score = sum(1 for word in query.split() if word in doc['passage'])
            doc_copy = doc.copy()
            doc_copy['graph_score'] = float(score)
            scored_docs.append(doc_copy)
        return sorted(scored_docs, key=lambda x: x['graph_score'], reverse=True)

    def _build_graph(self, context: List[Dict]):
        """Simplified graph construction from adjacent-word co-occurrence.

        For every passage, adds an edge between each pair of adjacent
        whitespace tokens, accumulating an integer 'weight' across passages.
        """
        for doc in context:
            words = doc['passage'].split()
            # zip(words, words[1:]) yields exactly the adjacent pairs the
            # original index loop produced.
            for w1, w2 in zip(words, words[1:]):
                if self.graph.has_edge(w1, w2):
                    self.graph[w1][w2]['weight'] += 1
                else:
                    self.graph.add_edge(w1, w2, weight=1)

    def _calculate_graph_score(self, query_words: List[str], doc: Dict) -> float:
        """Sum co-occurrence edge weights between query and document words.

        Every (query word, document word) pair connected in the graph adds
        that edge's 'weight'; repeated document words contribute repeatedly,
        matching the original nested-loop semantics.

        Returns:
            The accumulated weight as a float (0.0 when nothing matches).
        """
        score = 0.0
        doc_words = doc['passage'].split()
        for q_word in query_words:
            # Hoist the adjacency lookup: one neighbor-dict fetch per query
            # word instead of a has_edge() call for every word pair.
            if q_word not in self.graph:
                continue
            neighbors = self.graph[q_word]
            for d_word in doc_words:
                if d_word in neighbors:
                    score += neighbors[d_word]['weight']
        # score starts at 0.0 and only ever grows, so the original
        # `score if score > 0 else 0.0` clamp was dead code.
        return score