# pwcGraphRAG/search_handlers.py
import asyncio
from pathlib import Path
from typing import Optional, Tuple

import pandas as pd

import graphrag.api as api
from graphrag.config import GraphRagConfig, load_config, resolve_paths
from graphrag.index.create_pipeline_config import create_pipeline_config
from graphrag.logging import PrintProgressReporter
from graphrag.utils.storage import _create_storage, _load_table_from_storage


class StreamlitProgressReporter(PrintProgressReporter):
    """Progress reporter that mirrors success messages into a Streamlit placeholder."""

    def __init__(self, placeholder):
        super().__init__("")
        self.placeholder = placeholder

    def success(self, message: str):
        self.placeholder.success(message)
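
# Usage note (assumption): `placeholder` is expected to be a Streamlit container
# such as the one returned by `st.empty()`, so `success()` renders a success box
# directly in the app, e.g. `reporter = StreamlitProgressReporter(st.empty())`.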


def _resolve_parquet_files(
    root_dir: str,
    config: GraphRagConfig,
    parquet_list: list[str],
    optional_list: list[str],
) -> dict[str, Optional[pd.DataFrame]]:
    """Load indexer parquet outputs into a dataframe dict keyed by file stem.

    Files in `parquet_list` are required; files in `optional_list` resolve to
    None when they are absent from storage.
    """
    dataframe_dict = {}
    pipeline_config = create_pipeline_config(config)
    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
    for parquet_file in parquet_list:
        # Key by the file stem, e.g. "create_final_nodes.parquet" -> "create_final_nodes".
        df_key = parquet_file.split(".")[0]
        df_value = asyncio.run(
            _load_table_from_storage(name=parquet_file, storage=storage_obj)
        )
        dataframe_dict[df_key] = df_value
    for optional_file in optional_list:
        file_exists = asyncio.run(storage_obj.has(optional_file))
        df_key = optional_file.split(".")[0]
        if file_exists:
            df_value = asyncio.run(
                _load_table_from_storage(name=optional_file, storage=storage_obj)
            )
            dataframe_dict[df_key] = df_value
        else:
            dataframe_dict[df_key] = None
    return dataframe_dict
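
# Example of the resulting mapping (sketch): with
# parquet_list=["create_final_nodes.parquet"] and
# optional_list=["create_final_covariates.parquet"], the helper returns
# {"create_final_nodes": <DataFrame>, "create_final_covariates": <DataFrame> or None}.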


def run_global_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a global search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)
    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
        ],
        optional_list=[],
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]

    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            try:
                async for stream_chunk in api.global_search_streaming(
                    config=config,
                    nodes=final_nodes,
                    entities=final_entities,
                    community_reports=final_community_reports,
                    community_level=community_level,
                    response_type=response_type,
                    query=query,
                ):
                    # The first chunk carries the context data; later chunks are
                    # response tokens that are rendered incrementally.
                    if get_context_data:
                        context_data = stream_chunk
                        get_context_data = False
                    else:
                        full_response += stream_chunk
                        progress_placeholder.markdown(full_response)
            except Exception as e:
                progress_placeholder.error(f"Error during streaming search: {e}")
                return None, None
            return full_response, context_data

        result = asyncio.run(run_streaming_search())
        if result == (None, None):
            return "", {}  # Graceful fallback after a streaming error
        return result

    # Non-streaming logic
    try:
        response, context_data = asyncio.run(
            api.global_search(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                community_level=community_level,
                response_type=response_type,
                query=query,
            )
        )
        reporter.success(f"Global Search Response:\n{response}")
        return response, context_data
    except Exception as e:
        progress_placeholder.error(f"Error during global search: {e}")
        return "", {}  # Graceful fallback
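
# Local search additionally requires text units, relationships, and optional
# covariates on top of the artifacts used for global search, as reflected in the
# parquet lists below.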


def run_local_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a local search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)
    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
            "create_final_entities.parquet",
        ],
        optional_list=["create_final_covariates.parquet"],
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_covariates: Optional[pd.DataFrame] = dataframe_dict[
        "create_final_covariates"
    ]

    if streaming:

        async def run_streaming_search():
            full_response = ""
            context_data = None
            get_context_data = True
            async for stream_chunk in api.local_search_streaming(
                config=config,
                nodes=final_nodes,
                entities=final_entities,
                community_reports=final_community_reports,
                text_units=final_text_units,
                relationships=final_relationships,
                covariates=final_covariates,
                community_level=community_level,
                response_type=response_type,
                query=query,
            ):
                # The first chunk carries the context data; later chunks are
                # response tokens that are rendered incrementally.
                if get_context_data:
                    context_data = stream_chunk
                    get_context_data = False
                else:
                    full_response += stream_chunk
                    progress_placeholder.markdown(full_response)
            return full_response, context_data

        return asyncio.run(run_streaming_search())

    response, context_data = asyncio.run(
        api.local_search(
            config=config,
            nodes=final_nodes,
            entities=final_entities,
            community_reports=final_community_reports,
            text_units=final_text_units,
            relationships=final_relationships,
            covariates=final_covariates,
            community_level=community_level,
            response_type=response_type,
            query=query,
        )
    )
    reporter.success(f"Local Search Response:\n{response}")
    return response, context_data
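
# DRIFT search reuses the local-search artifacts minus covariates and, in this
# version of the API, has no streaming variant, so the handler falls back to a
# standard call when streaming is requested.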


def run_drift_search(
    config_filepath: Optional[str],
    data_dir: Optional[str],
    root_dir: str,
    community_level: int,
    response_type: str,
    streaming: bool,
    query: str,
    progress_placeholder,
) -> Tuple[str, dict]:
    """Perform a DRIFT search with a given query."""
    root = Path(root_dir).resolve()
    config = load_config(root, config_filepath)
    reporter = StreamlitProgressReporter(progress_placeholder)
    config.storage.base_dir = data_dir or config.storage.base_dir
    resolve_paths(config)
    dataframe_dict = _resolve_parquet_files(
        root_dir=root_dir,
        config=config,
        parquet_list=[
            "create_final_nodes.parquet",
            "create_final_entities.parquet",
            "create_final_community_reports.parquet",
            "create_final_text_units.parquet",
            "create_final_relationships.parquet",
        ],
        optional_list=[],  # Covariates are not supported by DRIFT search
    )
    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
    final_community_reports: pd.DataFrame = dataframe_dict[
        "create_final_community_reports"
    ]
    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]

    # DRIFT search doesn't support streaming; warn and run the standard search.
    if streaming:
        progress_placeholder.warning(
            "Streaming is not supported for DRIFT search. Using standard search instead."
        )

    response, context_data = asyncio.run(
        api.drift_search(
            config=config,
            nodes=final_nodes,
            entities=final_entities,
            community_reports=final_community_reports,
            text_units=final_text_units,
            relationships=final_relationships,
            community_level=community_level,
            query=query,
        )
    )
    reporter.success(f"DRIFT Search Response:\n{response}")
    return response, context_data
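

# --- Usage sketch ---------------------------------------------------------
# A minimal example of how a Streamlit app might wire these handlers up. The
# project root, community level, and response type below are assumptions for
# illustration only; adjust them to match your own GraphRAG index.
if __name__ == "__main__":
    import streamlit as st

    st.title("GraphRAG search")
    user_query = st.text_input("Query")
    placeholder = st.empty()  # shared progress/output container for the handlers

    if st.button("Run global search") and user_query:
        search_response, search_context = run_global_search(
            config_filepath=None,  # discover settings under root_dir
            data_dir=None,  # fall back to storage.base_dir from the config
            root_dir="./ragtest",  # hypothetical project root
            community_level=2,
            response_type="Multiple Paragraphs",
            streaming=False,
            query=user_query,
            progress_placeholder=placeholder,
        )
        st.markdown(search_response)
        st.write(search_context)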