Spaces:
Runtime error
Runtime error
File size: 6,537 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
import asyncio
from typing import Any, Dict, Optional
from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
ModerationPiiError,
)
class ComprehendPII:
    """Moderate Personally Identifiable Information (PII) via Amazon Comprehend.

    Two modes are supported, selected by the ``redact`` key of the config
    passed to :meth:`validate`:

    * ``redact`` falsy  -> :meth:`_contains_pii` raises ``ModerationPiiError``
      when a matching PII label is found, otherwise returns the prompt as is.
    * ``redact`` truthy -> :meth:`_detect_pii` masks matching PII entities and
      returns the redacted prompt.

    Note:
        When a callback with a truthy ``pii_callback`` attribute is supplied,
        notifications are scheduled with ``asyncio.create_task``, which
        requires a running event loop in the calling thread.
    """

    # Score cut-off used whenever the config does not provide "threshold".
    _DEFAULT_THRESHOLD = 0.5
    # Mask used whenever the config does not provide "mask_character".
    _DEFAULT_MASK_CHARACTER = "*"

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        """Store the Comprehend client and initialise the moderation beacon.

        Args:
            client: Object exposing ``contains_pii_entities`` and
                ``detect_pii_entities`` (e.g. a boto3 Comprehend client).
            callback: Optional handler; used only when its ``pii_callback``
                attribute is truthy, via its ``on_after_pii`` coroutine.
            unique_id: Identifier forwarded to the callback.
            chain_id: Identifier recorded in the beacon as
                ``moderation_chain_id``.
        """
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "PII",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id

    def validate(self, prompt_value: str, config: Any = None) -> str:
        """Check ``prompt_value`` for PII, redacting or rejecting per ``config``.

        Args:
            prompt_value: Text to moderate.
            config: Mapping with optional keys ``redact``, ``threshold``,
                ``labels`` and ``mask_character``. ``None`` or an empty
                mapping means "do not redact, use defaults".

        Returns:
            The original prompt, or the redacted prompt when ``redact`` is set.

        Raises:
            ModerationPiiError: In non-redact mode, if matching PII is found.
        """
        # Guard against config=None (the declared default) before calling .get.
        if config and config.get("redact"):
            return self._detect_pii(prompt_value=prompt_value, config=config)
        return self._contains_pii(prompt_value=prompt_value, config=config)

    def _callback_enabled(self) -> bool:
        """Return True when a callback with a truthy ``pii_callback`` is set."""
        return bool(self.callback and self.callback.pii_callback)

    def _record_io(self, prompt_value: str, response: Any) -> None:
        """Copy the moderated input and raw service response into the beacon."""
        if self._callback_enabled():
            self.moderation_beacon["moderation_input"] = prompt_value
            self.moderation_beacon["moderation_output"] = response

    def _notify(self, pii_found: bool) -> None:
        """Update the beacon status and schedule the ``on_after_pii`` callback."""
        if not self._callback_enabled():
            return
        if pii_found:
            self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
        asyncio.create_task(
            self.callback.on_after_pii(self.moderation_beacon, self.unique_id)
        )

    def _threshold_from(self, config: Any) -> Any:
        """Return ``config["threshold"]``, or the default if missing or None."""
        threshold = config.get("threshold")
        # Explicit None check: a threshold of 0 is a legitimate value.
        return self._DEFAULT_THRESHOLD if threshold is None else threshold

    @staticmethod
    def _matches(score: Any, kind: Any, threshold: Any, labels: Any) -> bool:
        """Return True if ``score`` meets ``threshold`` and ``kind`` is selected.

        An empty or missing ``labels`` collection selects every kind.
        """
        return score >= threshold and (not labels or kind in labels)

    def _contains_pii(self, prompt_value: str, config: Any = None) -> str:
        """Reject the prompt if it carries PII labels at or above the threshold.

        Uses the Amazon Comprehend ContainsPiiEntities API. See -
        https://docs.aws.amazon.com/comprehend/latest/APIReference/API_ContainsPiiEntities.html

        Args:
            prompt_value: The input text to be checked for PII labels.
            config: Mapping with optional ``threshold`` (default 0.5) and
                ``labels`` (empty or missing means "any label").

        Returns:
            The original prompt when no matching label is found.

        Raises:
            ModerationPiiError: If a matching label is found.

        Note:
            The provided client should be initialized with valid AWS
            credentials.
        """
        config = config or {}
        pii_identified = self.client.contains_pii_entities(
            Text=prompt_value, LanguageCode="en"
        )
        self._record_io(prompt_value, pii_identified)

        threshold = self._threshold_from(config)
        pii_labels = config.get("labels")
        pii_found = any(
            self._matches(label["Score"], label["Name"], threshold, pii_labels)
            for label in pii_identified["Labels"]
        )

        self._notify(pii_found)
        if pii_found:
            raise ModerationPiiError
        return prompt_value

    def _detect_pii(
        self, prompt_value: str, config: Optional[Dict[str, Any]] = None
    ) -> str:
        """Redact (or, without a config, reject) PII entities in the prompt.

        Uses the Amazon Comprehend DetectPiiEntities API. With a non-empty
        ``config`` every entity whose score meets the threshold (and whose
        type is in ``labels``, if given) is overwritten with the mask
        character. With no config, the prompt is rejected when any entity
        scores at least 0.5.

        Args:
            prompt_value: The input text to be checked for PII entities.
            config: Mapping with optional ``threshold`` (default 0.5),
                ``labels`` (empty or missing means "any type") and
                ``mask_character`` (default ``"*"``).

        Returns:
            The prompt with matching entities masked, or the original prompt
            when nothing matched.

        Raises:
            ModerationPiiError: If no config is given and PII is detected.

        Note:
            The client should be initialized with valid AWS credentials.
        """
        pii_identified = self.client.detect_pii_entities(
            Text=prompt_value, LanguageCode="en"
        )
        self._record_io(prompt_value, pii_identified)

        entities = pii_identified["Entities"]
        if not entities:
            self._notify(False)
            return prompt_value

        if not config:
            # No configuration: behave as a detector and refuse the prompt.
            pii_found = any(
                entity["Score"] >= self._DEFAULT_THRESHOLD for entity in entities
            )
            self._notify(pii_found)
            if pii_found:
                raise ModerationPiiError
            return prompt_value

        threshold = self._threshold_from(config)
        pii_labels = config.get("labels")
        mask_marker = config.get("mask_character")
        if mask_marker is None:
            mask_marker = self._DEFAULT_MASK_CHARACTER

        pii_found = False
        for entity in entities:
            if not self._matches(
                entity["Score"], entity["Type"], threshold, pii_labels
            ):
                continue
            pii_found = True
            begin = entity["BeginOffset"]
            # EndOffset points just past the entity's last character (the AWS
            # example reports 6-15 for the 9-character "Zhang Wei"), so it is
            # used directly as an exclusive slice bound.
            end = entity["EndOffset"]
            prompt_value = (
                prompt_value[:begin]
                + mask_marker * (end - begin)
                + prompt_value[end:]
            )

        self._notify(pii_found)
        return prompt_value
|