mitre-attack / workflows /chat_workflow.py
nyasukun's picture
.
fa3b0ec
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import MessagesState
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from services.llm_service import LLMService
from services.file_service import save_response, create_file_element
from models.chat_state import AttackState, get_initial_state
import chainlit as cl
import json
import logging
import os
from urllib.parse import quote
# Initialize services
llm_service = LLMService()
logger = logging.getLogger(__name__)
CHAINLIT_URL = os.environ.get("SPACE_HOST")
if not CHAINLIT_URL:
CHAINLIT_URL = "http://localhost:8080"
if not CHAINLIT_URL.startswith("https://"):
CHAINLIT_URL = "https://" + CHAINLIT_URL
@cl.step(name="コンテキスト評価", type="evaluation")
async def evaluate_context_node(state: AttackState) -> AttackState:
"""Node for evaluating if the user input is valid for ATT&CK context"""
msg = cl.Message(content="")
# Get the last user message
user_messages = [msg for msg in state['messages'] if isinstance(msg, HumanMessage)]
user_message = user_messages[-1].content if user_messages else ""
try:
evaluation_result = await llm_service.evaluate_context(user_message)
state['is_valid_context'] = evaluation_result.is_valid
state['extracted_user_scenario'] = evaluation_result.extracted_scenario
state['extracted_user_layer_operation'] = evaluation_result.extracted_layer_operation
if state['is_valid_context']:
response_text = "入力はATT&CKフレームワークのコンテキストに合致します。シナリオの評価を続けます。"
else:
response_text = "申し訳ありませんが、この入力はサイバー攻撃の分析やATT&CKフレームワークのレイヤーに関する指示として認識できませんでした。適切な指示を入力してください。"
await msg.stream_token(response_text)
await msg.send()
state['messages'].append(AIMessage(content=response_text))
except Exception as e:
error_msg = f"コンテキスト評価中にエラーが発生しました: {str(e)}"
await msg.stream_token(error_msg)
await msg.send()
state['messages'].append(AIMessage(content=error_msg))
state['is_valid_context'] = False
return state
@cl.step(name="シナリオ更新", type="update")
async def update_scenario_node(state: AttackState) -> AttackState:
"""Node for updating the scenario based on user input"""
msg = cl.Message(content="")
# Get the last user message
user_message = state.get('extracted_user_scenario', None)
current_scenario = state.get('scenario', None)
if not user_message and not current_scenario:
raise ValueError("シナリオの更新に必要な情報がありません。")
try:
updated_scenario = await llm_service.generate_scenario(user_message, current_scenario)
state['scenario'] = updated_scenario
message = "新しいシナリオを作成しました。" if not current_scenario else "シナリオを更新しました。"
await msg.stream_token(message)
await msg.send()
state['messages'].append(AIMessage(content=message))
except Exception as e:
error_msg = f"シナリオの{'作成' if not current_scenario else '更新'}中にエラーが発生しました: {str(e)}"
await msg.stream_token(error_msg)
await msg.send()
state['messages'].append(AIMessage(content=error_msg))
return state
@cl.step(name="JSON生成/更新", type="generation", language="json")
async def generate_json_node(state: AttackState) -> AttackState:
"""Node for generating or updating ATT&CK Navigator JSON"""
user_message = state.get('extracted_user_layer_operation')
current_scenario = state.get('scenario')
existing_json = state.get('attack_json')
try:
json_content = await llm_service.generate_attack_json(user_message, current_scenario, existing_json)
# Save JSON to file
filename, filepath = save_response(json_content)
file_element = create_file_element(filename, filepath)
json_url = CHAINLIT_URL + "/" + filepath
json_url = quote(json_url)
# Prepare and send the response message
response = "MITRE ATT&CK Navigatorレイヤーを更新しました。" if existing_json else "MITRE ATT&CK Navigatorレイヤーを生成しました。"
response += " ファイルをダウンロードしてインポートできます。"
response += f"ATT&CK Navigator : https://mitre-attack.github.io/attack-navigator//#layerURL={json_url}"
msg = cl.Message(content=response, elements=[file_element])
await msg.send()
# Update state
state['messages'].append(AIMessage(content=response))
state['attack_json'] = json.loads(json_content)
except Exception as e:
error_msg = f"ATT&CK Navigatorレイヤーの生成中にエラーが発生しました: {str(e)}"
msg = cl.Message(content=error_msg)
await msg.send()
state['messages'].append(AIMessage(content=error_msg))
return state
async def display_state_node(state: AttackState) -> AttackState:
"""Node for displaying the current state before ending"""
async with cl.Step(name="状態表示", type="display") as step:
# コンテキストの評価結果
if state.get('is_valid_context') is not None:
status = "有効" if state['is_valid_context'] else "無効"
step.input = "コンテキスト評価"
step.output = f"コンテキストの評価結果: {status}"
# 現在のシナリオ
if state.get('scenario'):
async with cl.Step(name="シナリオ情報", type="display") as scenario_step:
scenario_step.input = "現在のシナリオ"
scenario_step.output = state['scenario']
# JSONの状態
if state.get('attack_json'):
async with cl.Step(name="ATT&CK情報", type="display") as attack_step:
techniques = state['attack_json'].get('techniques', [])
technique_count = len(techniques)
attack_step.input = "登録済みテクニック"
if technique_count > 0:
technique_list = "\n".join([
f"- {t.get('techniqueID', 'Unknown')}: {t.get('comment', '説明なし')}"
for t in techniques[:5]
])
if technique_count > 5:
technique_list += f"\n... 他 {technique_count - 5} 件"
attack_step.output = f"登録済みテクニック数: {technique_count}\n\n{technique_list}"
else:
attack_step.output = "登録済みテクニックはありません"
msg = cl.Message(content="処理が完了しました。")
await msg.send()
return state
# Create the graph
workflow = StateGraph(AttackState)
# Add nodes
workflow.add_node("evaluate_context", evaluate_context_node)
workflow.add_node("update_scenario", update_scenario_node)
workflow.add_node("generate_json", generate_json_node)
workflow.add_node("display_state", display_state_node)
# Add edges
workflow.add_edge(START, "evaluate_context")
workflow.add_conditional_edges(
"evaluate_context",
lambda state: state.get('is_valid_context', False),
{
True: "update_scenario",
False: "display_state"
}
)
workflow.add_edge("update_scenario", "generate_json")
workflow.add_edge("generate_json", "display_state")
workflow.add_edge("display_state", END)
# Compile the graph
chainlit_app = workflow.compile()
async def test_evaluate_context_node(state: AttackState) -> AttackState:
"""テスト用のコンテキスト評価ノード"""
user_messages = [msg for msg in state['messages'] if isinstance(msg, HumanMessage)]
user_message = user_messages[-1].content if user_messages else ""
try:
evaluation_result = await llm_service.evaluate_context(user_message)
state['is_valid_context'] = evaluation_result.is_valid
state['extracted_user_scenario'] = evaluation_result.extracted_scenario
state['extracted_user_layer_operation'] = evaluation_result.extracted_layer_operation
response_text = "入力はATT&CKフレームワークのコンテキストに合致します。シナリオの評価を続けます。" if state['is_valid_context'] else "申し訳ありませんが、この入力はサイバー攻撃の分析やATT&CKフレームワークのレイヤーに関する指示として認識できませんでした。適切な指示を入力してください。"
state['messages'].append(AIMessage(content=response_text))
except Exception as e:
error_msg = f"コンテキスト評価中にエラーが発生しました: {str(e)}"
state['messages'].append(AIMessage(content=error_msg))
state['is_valid_context'] = False
return state
async def test_update_scenario_node(state: AttackState) -> AttackState:
"""テスト用のシナリオ更新ノード"""
# Get the last user message
user_message = state.get('extracted_user_scenario')
current_scenario = state.get('scenario')
try:
updated_scenario = await llm_service.generate_scenario(user_message, current_scenario)
state['scenario'] = updated_scenario
message = "新しいシナリオを作成しました。" if not current_scenario else "シナリオを更新しました。"
state['messages'].append(AIMessage(content=message))
except Exception as e:
error_msg = f"シナリオの{'作成' if not current_scenario else '更新'}中にエラーが発生しました: {str(e)}"
state['messages'].append(AIMessage(content=error_msg))
return state
async def test_generate_json_node(state: AttackState) -> AttackState:
"""テスト用のJSON生成ノード"""
user_message = state.get('extracted_user_layer_operation')
current_scenario = state.get('scenario')
existing_json = state.get('attack_json')
try:
json_content = await llm_service.generate_attack_json(user_message, current_scenario, existing_json)
response = "MITRE ATT&CK Navigatorレイヤーを更新しました。" if existing_json else "MITRE ATT&CK Navigatorレイヤーを生成しました。"
response += " ファイルをダウンロードしてインポートできます。"
state['messages'].append(AIMessage(content=response))
state['attack_json'] = json.loads(json_content)
except Exception as e:
error_msg = f"ATT&CK Navigatorレイヤーの生成中にエラーが発生しました: {str(e)}"
state['messages'].append(AIMessage(content=error_msg))
return state
async def test_display_state_node(state: AttackState) -> AttackState:
"""テスト用の状態表示ノード"""
summary = []
if state.get('is_valid_context') is not None:
status = "有効" if state['is_valid_context'] else "無効"
summary.append(f"コンテキストの評価結果: {status}")
if state.get('scenario'):
summary.append(f"現在のシナリオ:\n{state['scenario']}")
if state.get('attack_json'):
techniques = state['attack_json'].get('techniques', [])
technique_count = len(techniques)
summary.append(f"登録済みテクニック数: {technique_count}")
if technique_count > 0:
technique_list = "\n".join([
f"- {t.get('techniqueID', 'Unknown')}: {t.get('comment', '説明なし')}"
for t in techniques[:5]
])
if technique_count > 5:
technique_list += f"\n... 他 {technique_count - 5} 件"
summary.append(f"\n登録済みテクニック:\n{technique_list}")
if summary:
state_summary = "\n\n".join(summary)
state['messages'].append(AIMessage(content=f"現在の状態:\n{state_summary}"))
return state
async def main():
"""テスト用のメイン関数"""
try:
# 初期状態の作成
initial_state = get_initial_state()
# テスト用の既存シナリオ
existing_scenario = """
標的システムへの不正アクセスシナリオ
概要:
攻撃者は、標的のシステムに不正アクセスを試み、機密情報を窃取します。
攻撃フェーズ:
1. 初期アクセス
- パスワードスプレー攻撃による認証情報の取得
- 有効なアカウントの特定
2. 実行
- 取得した認証情報を使用してシステムにログイン
- 不正なコマンドの実行
3. 権限昇格
- 管理者権限の取得
- システム設定の変更
4. 防御回避
- ログの削除
- 攻撃痕跡の隠蔽
"""
# テスト用の既存JSON
existing_json = {
"name": "Test Layer",
"versions": {
"attack": "16.0",
"navigator": "4.9.0",
"layer": "4.5"
},
"domain": "enterprise-attack",
"description": "Test layer for development",
"filters": {
"platforms": ["Windows", "Linux", "macOS"]
},
"gradient": {
"colors": ["#ffffff", "#ff6666"],
"minValue": 0,
"maxValue": 100
},
"techniques": [
{
"techniqueID": "T1110",
"score": 50,
"color": "#ff6666",
"comment": "パスワードスプレー攻撃による認証情報の取得",
"enabled": True
},
{
"techniqueID": "T1078",
"score": 50,
"color": "#ff6666",
"comment": "有効なアカウントを使用した不正アクセス",
"enabled": True
}
]
}
# 初期状態に既存データを設定
initial_state['scenario'] = existing_scenario
initial_state['attack_json'] = existing_json
# テスト用のユーザーメッセージ
test_message = """
以下の攻撃シナリオを分析してください:
攻撃者は、標的のシステムに不正アクセスを試みます。
まず、パスワードスプレー攻撃を実行して、有効なアカウントの認証情報を取得します。
取得した認証情報を使用して、システムにログインし、機密情報を窃取します。
最後に、攻撃の痕跡を隠蔽するために、ログを削除します。
"""
test_message = "ATTACKのバージョンを16にして、テクニックは青にして。"
# ユーザーメッセージを状態に追加
initial_state['messages'].append(HumanMessage(content=test_message))
# テスト用ワークフローの実行
state = await test_evaluate_context_node(initial_state)
if state.get('is_valid_context', False):
# シナリオ更新
state = await test_update_scenario_node(state)
# JSON生成
state = await test_generate_json_node(state)
# 状態表示
state = await test_display_state_node(state)
# 結果の表示
for msg in state['messages']:
role = "User" if isinstance(msg, HumanMessage) else "Assistant"
print(f"\n{role}:")
print(msg.content)
except Exception as e:
print(f"エラーが発生しました: {str(e)}")
raise
if __name__ == "__main__":
import asyncio
asyncio.run(main())