from .base import BaseRetriever
from typing import Dict, List
import networkx as nx
# import spacy  # temporarily disabled


class GraphRAG(BaseRetriever):
    def __init__(self, config: Dict):
        self.config = config
        self.graph = nx.Graph()
        # self.nlp = spacy.load("zh_core_web_sm")  # temporarily disabled
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        self.working_dir = config['retrieval_settings']['methods'][2]['model_settings']['working_dir']
        self.graph_file = f"{self.working_dir}/graph.graphml"

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        # Simple implementation: keyword-matching retrieval.
        scored_docs = []
        for doc in context:
            # Score each document by how many query words appear in its passage.
            score = sum(1 for word in query.split() if word in doc['passage'])
            doc_copy = doc.copy()
            doc_copy['graph_score'] = float(score)
            scored_docs.append(doc_copy)
        return sorted(scored_docs, key=lambda x: x['graph_score'], reverse=True)

    def _build_graph(self, context: List[Dict]):
        """Simplified graph construction based on word co-occurrence counts."""
        for doc in context:
            words = doc['passage'].split()
            # Add an edge between each pair of adjacent words, accumulating
            # a co-occurrence weight across repeated sightings.
            for w1, w2 in zip(words, words[1:]):
                if self.graph.has_edge(w1, w2):
                    self.graph[w1][w2]['weight'] += 1
                else:
                    self.graph.add_edge(w1, w2, weight=1)

    def _calculate_graph_score(self, query_words: List[str], doc: Dict) -> float:
        """Simplified graph score: sum of edge weights linking query words to document words."""
        score = 0.0
        doc_words = doc['passage'].split()
        for q_word in query_words:
            for d_word in doc_words:
                if self.graph.has_edge(q_word, d_word):
                    score += self.graph[q_word][d_word]['weight']
        return score
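
# Minimal usage sketch (illustrative only): assumes a config dict shaped like the
# one read in init_retriever and documents carrying a 'passage' key. The working
# directory, passages, and query below are hypothetical. Because this module uses
# a relative import, the block only runs when the file is executed as part of its
# package (e.g. `python -m <package>.graphrag`).
if __name__ == "__main__":
    config = {
        'retrieval_settings': {
            'methods': [
                {}, {},  # slots 0 and 1 are placeholders; GraphRAG reads index 2
                {'model_settings': {'working_dir': './graphrag_workdir'}},
            ]
        }
    }
    retriever = GraphRAG(config)
    docs = [
        {'passage': 'graph based retrieval augments generation'},
        {'passage': 'keyword matching is a simple baseline'},
    ]
    retriever._build_graph(docs)
    for doc in retriever.retrieve('graph retrieval', docs):
        print(doc['graph_score'], doc['passage'])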