import asyncio
import importlib
from typing import Any, List, Optional

from langchain_experimental.comprehend_moderation.base_moderation_exceptions import (
    ModerationToxicityError,
)

class ComprehendToxicity:
    """Class to handle toxicity moderation."""

    def __init__(
        self,
        client: Any,
        callback: Optional[Any] = None,
        unique_id: Optional[str] = None,
        chain_id: Optional[str] = None,
    ) -> None:
        self.client = client
        self.moderation_beacon = {
            "moderation_chain_id": chain_id,
            "moderation_type": "Toxicity",
            "moderation_status": "LABELS_NOT_FOUND",
        }
        self.callback = callback
        self.unique_id = unique_id
    def _toxicity_init_validate(self, max_size: int) -> Any:
        """
        Validate and initialize toxicity processing configuration.

        Args:
            max_size (int): Maximum sentence size defined in the
                configuration object.

        Raises:
            Exception: If the maximum sentence size exceeds the 5KB limit.

        Note:
            This function ensures that the NLTK punkt tokenizer is downloaded
            if not already present.

        Returns:
            The imported ``nltk`` module, ready for sentence tokenization.
        """
        if max_size > 1024 * 5:
            raise Exception("The sentence length should not exceed 5KB.")
        try:
            nltk = importlib.import_module("nltk")
            nltk.data.find("tokenizers/punkt")
            return nltk
        except ImportError:
            raise ModuleNotFoundError(
                "Could not import nltk python package. "
                "Please install it with `pip install nltk`."
            )
        except LookupError:
            # The punkt tokenizer is missing: download it, then return the
            # module so callers always receive a usable nltk reference.
            nltk.download("punkt")
            return nltk
    def _split_paragraph(
        self, prompt_value: str, max_size: int = 1024 * 4
    ) -> List[List[str]]:
        """
        Split a paragraph into chunks of sentences, respecting the maximum size limit.

        Args:
            prompt_value (str): The input paragraph to be split into chunks.
            max_size (int, optional): The maximum size limit in bytes for
                each chunk. Defaults to 1024 * 4 (4KB).

        Returns:
            List[List[str]]: A list of chunks, where each chunk is a list
            of sentences.

        Note:
            This function validates the maximum sentence size based on service
            limits using the '_toxicity_init_validate' function. It uses the
            NLTK sentence tokenizer to split the paragraph into sentences.

        Example:
            paragraph = "This is a sample paragraph. It contains multiple sentences."
            chunks = self._split_paragraph(paragraph, max_size=2048)
        """
        # Validate the max. sentence size against the service limits and get
        # an initialized nltk module back.
        nltk = self._toxicity_init_validate(max_size)
        sentences = nltk.sent_tokenize(prompt_value)
        chunks: List[List[str]] = []
        current_chunk: List[str] = []
        current_size = 0
        for sentence in sentences:
            sentence_size = len(sentence.encode("utf-8"))
            # If adding a new sentence exceeds max_size
            # or current_chunk has 10 sentences, start a new chunk
            if (current_size + sentence_size > max_size) or (len(current_chunk) >= 10):
                if current_chunk:  # Avoid appending empty chunks
                    chunks.append(current_chunk)
                current_chunk = []
                current_size = 0
            current_chunk.append(sentence)
            current_size += sentence_size
        # Add any remaining sentences
        if current_chunk:
            chunks.append(current_chunk)
        return chunks
    def validate(self, prompt_value: str, config: Any = None) -> str:
        """
        Check the toxicity of a given text prompt using the AWS
        Comprehend service and apply actions based on configuration.

        Args:
            prompt_value (str): The text content to be checked for toxicity.
            config (Dict[str, Any]): Configuration for toxicity checks and
                actions. Expected keys are ``threshold`` (minimum score for a
                label to count as toxic) and ``labels`` (an optional list of
                label names to restrict the check to).

        Returns:
            str: The original prompt_value if allowed or no toxicity found.

        Raises:
            ModerationToxicityError: If the prompt contains toxic labels and
                cannot be processed based on the configuration.
        """
        chunks = self._split_paragraph(prompt_value=prompt_value)
        for sentence_list in chunks:
            segments = [{"Text": sentence} for sentence in sentence_list]
            # Run Comprehend toxicity detection on this batch of sentences.
            response = self.client.detect_toxic_content(
                TextSegments=segments, LanguageCode="en"
            )
            if self.callback and self.callback.toxicity_callback:
                self.moderation_beacon["moderation_input"] = segments  # type: ignore
                self.moderation_beacon["moderation_output"] = response
            toxicity_found = False
            threshold = config.get("threshold")
            toxicity_labels = config.get("labels")
            if not toxicity_labels:
                # No label filter configured: any label above the threshold counts.
                for item in response["ResultList"]:
                    for label in item["Labels"]:
                        if label["Score"] >= threshold:
                            toxicity_found = True
                            break
            else:
                # Only the configured labels count, and only above the threshold.
                for item in response["ResultList"]:
                    for label in item["Labels"]:
                        if (
                            label["Name"] in toxicity_labels
                            and label["Score"] >= threshold
                        ):
                            toxicity_found = True
                            break
            if self.callback and self.callback.toxicity_callback:
                if toxicity_found:
                    self.moderation_beacon["moderation_status"] = "LABELS_FOUND"
                asyncio.create_task(
                    self.callback.on_after_toxicity(
                        self.moderation_beacon, self.unique_id
                    )
                )
            if toxicity_found:
                raise ModerationToxicityError
        return prompt_value
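

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this class might be wired up, assuming a boto3
# Comprehend client and no async callback. The region, prompt text, and
# config values below are placeholders, not values mandated by this module.
if __name__ == "__main__":
    import boto3

    comprehend_client = boto3.client("comprehend", region_name="us-east-1")
    toxicity = ComprehendToxicity(client=comprehend_client)
    try:
        checked_prompt = toxicity.validate(
            "Some user-provided text to screen before it reaches the LLM.",
            config={"threshold": 0.5, "labels": []},
        )
        print("Prompt passed toxicity check:", checked_prompt)
    except ModerationToxicityError:
        print("Prompt was blocked: toxic content detected.")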