from .base import BaseRetriever
from typing import Dict, List
import networkx as nx
# import spacy  # temporarily disabled
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import logging
class GraphRAG(BaseRetriever):
    def __init__(self, config: Dict):
        self.config = config
        self.graph = nx.Graph()
        # self.nlp = spacy.load("zh_core_web_sm")  # temporarily disabled
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        self.working_dir = config['retrieval_settings']['methods'][2]['model_settings']['working_dir']
        self.graph_file = f"{self.working_dir}/graph.graphml"
    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        # Simple implementation: keyword-matching retrieval.
        scored_docs = []
        for doc in context:
            # Score is the number of query tokens that appear in the passage.
            score = sum(1 for word in query.split() if word in doc['passage'])
            doc_copy = doc.copy()
            doc_copy['graph_score'] = float(score)
            scored_docs.append(doc_copy)
        return sorted(scored_docs, key=lambda x: x['graph_score'], reverse=True)
    def _build_graph(self, context: List[Dict]):
        """Simplified graph construction."""
        # Use plain word co-occurrence statistics only.
        for doc in context:
            text = doc['passage']
            words = text.split()
            # Add an edge between each pair of adjacent words.
            for i in range(len(words) - 1):
                w1, w2 = words[i], words[i + 1]
                if not self.graph.has_edge(w1, w2):
                    self.graph.add_edge(w1, w2, weight=1)
                else:
                    self.graph[w1][w2]['weight'] += 1
    def _calculate_graph_score(self, query_words: List[str], doc: Dict) -> float:
        """Simplified graph score computation."""
        score = 0.0
        doc_words = doc['passage'].split()
        # Sum the edge weights between query words and document words.
        for q_word in query_words:
            for d_word in doc_words:
                if self.graph.has_edge(q_word, d_word):
                    score += self.graph[q_word][d_word]['weight']
        return score
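

# Minimal usage sketch (not part of the original module): the config layout
# below is an assumption inferred from the nested keys accessed in
# init_retriever; the real project config may differ. Because of the relative
# import of BaseRetriever, this only runs when executed as part of the package
# (e.g. `python -m <package>.graph_rag`) and assumes BaseRetriever imposes no
# further abstract methods.
if __name__ == "__main__":
    example_config = {
        'retrieval_settings': {
            'methods': [
                {},  # placeholder entries; GraphRAG settings are read from index 2
                {},
                {'model_settings': {'working_dir': '/tmp/graphrag'}},
            ]
        }
    }
    retriever = GraphRAG(example_config)
    docs = [
        {'passage': 'graph based retrieval augments generation'},
        {'passage': 'dense retrieval uses vector similarity'},
    ]
    # Build the co-occurrence graph, then rank documents against a query.
    retriever._build_graph(docs)
    ranked = retriever.retrieve('graph retrieval', docs)
    print([(d['passage'], d['graph_score']) for d in ranked])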