|
from langchain.chains import GraphCypherQAChain
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
from decouple import config
|
|
from typing import List, Optional, Dict, Type
|
|
|
|
from langchain.callbacks.manager import (
|
|
AsyncCallbackManagerForToolRun,
|
|
CallbackManagerForToolRun,
|
|
)
|
|
from langchain.pydantic_v1 import BaseModel, Field
|
|
from langchain.tools import BaseTool
|
|
from langchain_community.graphs import Neo4jGraph
|
|
from neo4j.exceptions import ServiceUnavailable
|
|
import re
|
|
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
|
|
|
# Neo4j connection settings. The URL and password are read through
# python-decouple's ``config``; the user name is fixed.
NEO4J_URL = config("NEO4J_URL")
NEO4J_USER = "neo4j"
NEO4J_PASSWORD = config("NEO4J_PASSWORD")
|
|
|
|
|
|
# Characters that carry meaning in Lucene's classic query syntax. "/" is
# included because it delimits regular expressions, so a stray slash in user
# input would otherwise change how the query is parsed.
_LUCENE_SPECIAL_CHARS = '+-&|!(){}[]^"~*?:\\/'

# Translation table mapping every special character to a single space, built
# once at import time so each call is a single pass over the text.
_LUCENE_TRANSLATION = str.maketrans(dict.fromkeys(_LUCENE_SPECIAL_CHARS, " "))


def remove_lucene_chars(text: str) -> str:
    """Replace Lucene special characters with spaces.

    Each special character is replaced by one space (not deleted), so words
    that were only separated by such a character stay separate, e.g.
    ``"AC/DC"`` becomes ``"AC DC"``.

    Args:
        text: Raw user-supplied text.

    Returns:
        The text with every Lucene special character replaced by a space and
        leading/trailing whitespace stripped. Runs of interior spaces are
        left as they are.
    """
    return text.translate(_LUCENE_TRANSLATION).strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_full_text_query(input: str) -> str:
    """Generate a full-text search query for a given input string.

    The input is stripped of Lucene special characters and split on
    whitespace. Each remaining word gets a ``~0.7`` fuzzy suffix, and the
    words are combined with the ``AND`` operator. This is useful for mapping
    entity names mentioned in user questions to database values while
    allowing for some misspellings.

    Args:
        input: Free-text entity name taken from the user's question.

    Returns:
        A query such as ``"alice~0.7 AND bob~0.7"``, or an empty string when
        the input contains no searchable words (empty, whitespace-only, or
        made up solely of Lucene special characters).
    """
    # str.split() with no arguments never yields empty strings, so no extra
    # filtering is needed. Joining also covers the zero-word case, which the
    # previous last-element indexing could not.
    words = remove_lucene_chars(input).split()
    return " AND ".join(f"{word}~0.7" for word in words)
|
|
|
|
|
|
# Full-text candidate lookup. Parameters:
#   $index         - name of the full-text index to search
#   $fulltextQuery - Lucene query string
#   $limit         - maximum number of nodes to return
# Each row carries the node's name, summary and its labels (as ``label``).
candidate_query = """
CALL db.index.fulltext.queryNodes($index, $fulltextQuery, {limit: $limit})
YIELD node
RETURN node.name AS name,
       node.summary AS summary,
       labels(node) AS label
"""
|
|
|
|
# Detail lookup for one PERSON node whose name equals $name.
# The IN_CATEGORY match is mandatory (not OPTIONAL) and direction-agnostic,
# so a person without any category relationship produces no row at all.
# At most one row is returned (LIMIT 1).
person_description_query = """
MATCH (e:PERSON)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE e.name = $name
RETURN e.name AS name,
       e.gender AS gender,
       e.summary AS summary,
       m.name AS policy_category
LIMIT 1
"""
|
|
|
|
# Detail lookup for one ORGANIZATION node whose name equals $name.
# As with the person query, the IN_CATEGORY match is mandatory and
# direction-agnostic, and at most one row is returned (LIMIT 1).
organization_description_query = """
MATCH (o:ORGANIZATION)-[r:IN_CATEGORY]-(m:CATEGORY)
WHERE o.name = $name
RETURN o.name AS name,
       o.summary AS summary,
       o.description AS description,
       o.twitterUri AS twitter_uri,
       o.homepageUri AS homepage_uri,
       m.name AS policy_category
LIMIT 1
"""
|
|
|
|
# Shared connection, created on first use. Previously a new Neo4jGraph was
# constructed for every query and never closed, so each call paid the full
# connection set-up cost and left a driver behind.
_graph: Optional[Neo4jGraph] = None


def _get_graph() -> Neo4jGraph:
    """Return the shared Neo4jGraph, connecting on first use."""
    global _graph
    if _graph is None:
        _graph = Neo4jGraph(NEO4J_URL, NEO4J_USER, NEO4J_PASSWORD)
    return _graph


@retry(
    retry=retry_if_exception_type(ServiceUnavailable),
    stop=stop_after_attempt(4),
    wait=wait_exponential(multiplier=1, min=4, max=8),
)
def execute_query(query, params):
    """Run a Cypher query against Neo4j, retrying when the service is unavailable.

    Only ``ServiceUnavailable`` triggers a retry; any other exception
    propagates immediately. Up to four attempts are made, with an exponential
    wait bounded between 4 and 8 seconds. The retry decorator does not set
    ``reraise``, so when all attempts fail tenacity raises its own
    ``RetryError`` rather than the original exception.

    Args:
        query: Cypher statement to execute.
        params: Mapping of query parameters referenced by ``query``.

    Returns:
        Whatever ``Neo4jGraph.query`` returns for the statement; callers in
        this module treat it as a list of row dictionaries.
    """
    global _graph
    try:
        return _get_graph().query(query, params)
    except ServiceUnavailable:
        # Discard the cached connection so the next attempt (triggered by the
        # retry decorator) starts from a fresh one.
        _graph = None
        raise
|
|
|
|
|
|
def get_candidates(input: str, index: str, limit: int = 3) -> List[Dict[str, str]]:
    """Retrieve candidate entities from the database for the input string.

    The input is turned into a full-text query and run against the given
    full-text index. Each returned candidate is a dictionary with the keys
    ``name``, ``summary`` and ``label`` (the value of ``labels(node)``).

    Args:
        input: Free-text entity name taken from the user's question.
        index: Name of the Neo4j full-text index to search.
        limit: Maximum number of candidates to return.

    Returns:
        The rows produced by the candidate query, or an empty list when the
        input contains nothing searchable.
    """
    # If nothing is left once Lucene syntax characters are removed, there is
    # no query to build. Return early instead of failing or sending an empty
    # query string to the database.
    if not remove_lucene_chars(input):
        return []

    ft_query = generate_full_text_query(input)
    # NOTE(review): debug output kept for parity with existing behaviour.
    print(ft_query)
    return execute_query(
        candidate_query,
        {"fulltextQuery": ft_query, "index": index, "limit": limit},
    )
|
|
|
|
|
|
def get_information(entity: str, index: str) -> Optional[dict]:
    """Look up an entity by name and return its details with display-style keys.

    The first candidate returned by the full-text search is used. Its name is
    fed to the person query when ``index`` is ``"dangerous_individuals"`` and
    to the organization query for any other index value.

    Args:
        entity: Free-text entity name taken from the user's question.
        index: Full-text index to search. It also selects which description
            query is run.

    Returns:
        A dictionary whose keys are the query's column aliases with
        underscores replaced by spaces and title-cased (for example
        ``policy_category`` becomes ``"Policy Category"``). Returns ``None``
        when no candidate matches or the description query yields no rows.
    """
    candidates = get_candidates(entity, index)
    if not candidates:
        return None

    # Only the first candidate is described; any further matches are ignored.
    best_match = candidates[0]
    if index == "dangerous_individuals":
        description_query = person_description_query
    else:
        description_query = organization_description_query

    rows = execute_query(description_query, params={"name": best_match["name"]})
    if not rows:
        return None

    return {
        key.replace("_", " ").title(): value
        for key, value in rows[0].items()
    }
|
|
|
|
|
|
# Argument schema for DIOInformationTool. Documented with comments rather than
# a class docstring so the schema generated from this model is unchanged apart
# from the corrected field description below.
class InformationInput(BaseModel):
    # Name of the entity to search for. The example in the description now has
    # its closing quote; it previously read 'Alice Bob with no terminator.
    entity: str = Field(
        description="full-text search query of the name of a given entity which are mentioned in the question. Example: 'Alice Bob'"
    )
    # Name of the full-text index to search; it also determines which
    # description query get_information runs.
    entity_type: str = Field(
        description="indexed list to search for membership by the entity. Available options are 'dangerous_organizations' and 'dangerous_individuals'"
    )
|
|
|
|
|
|
class DIOInformationTool(BaseTool):
    """LangChain tool that looks up an entity via ``get_information``.

    Both the sync and async entry points return what ``get_information``
    returns: a dictionary of details, or ``None`` when nothing was found.
    """

    name: str = "Dangerous_Individuals_And_Organizations_Information"
    # NOTE(review): this text talks about elected officials and candidates,
    # while the tool name and index names refer to dangerous individuals and
    # organizations. Left untouched because it is prompt text; please confirm
    # it is intended.
    description: str = (
        "useful for when you need to answer questions about various elected officials or persons running for office. "
        "Never generate a final answer to the question if multiple candidates were matched by name; in those cases, "
        "always present the candidate options as a bulleted list and ask for disambiguation in your <avi_answer> tag. Never embellish your descriptions of "
        "the candidate in question or assume their motivations from their professional activities under any circumstances; "
        "remain 100% strictly fact-focused at all times with the outputs of this tool."
    )
    args_schema: Type[BaseModel] = InformationInput

    def _run(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Use the tool.

        Args:
            entity: Name of the entity to look up.
            entity_type: Full-text index to search.
            run_manager: Callback manager supplied by LangChain; unused here.

        Returns:
            The dictionary produced by ``get_information``, or ``None``.
        """
        return get_information(entity, entity_type)

    async def _arun(
        self,
        entity: str,
        entity_type: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Optional[dict]:
        """Use the tool asynchronously.

        ``get_information`` does blocking database I/O and may sleep between
        retries, so it is run in the default executor rather than directly on
        the event loop.

        Args:
            entity: Name of the entity to look up.
            entity_type: Full-text index to search.
            run_manager: Callback manager supplied by LangChain; unused here.

        Returns:
            The dictionary produced by ``get_information``, or ``None``.
        """
        # Imported locally so this class does not depend on a change to the
        # module's import block.
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, get_information, entity, entity_type
        )
|
|
|
|
|
|
|
|
|
|
|