# Standard library imports
import re
import json
import asyncio
from typing import Annotated, Optional, Dict, Any, List

# Related third-party imports
import yaml

# Local imports
from src.text.model import LanguageModelManager
from src.audio.utils import Formatter


class LLMOrchestrator:
    """
    A handler to perform specific LLM tasks such as classification or sentiment analysis.

    This class uses a language model to perform different tasks by dynamically
    changing the prompt.

    Parameters
    ----------
    config_path : str
        Path to the configuration file for the language model manager.
    prompt_config_path : str
        Path to the configuration file containing prompts for different tasks.
    model_id : str, optional
        Identifier of the model to use. Defaults to "llama".
    cache_size : int, optional
        Cache size for the language model manager. Defaults to 2.

    Attributes
    ----------
    manager : LanguageModelManager
        An instance of LanguageModelManager for interacting with the model.
    model_id : str
        The identifier of the language model in use.
    prompts : Dict[str, Dict[str, str]]
        A dictionary containing prompts for different tasks.
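
    Examples
    --------
    Construction only; the paths mirror the demo at the bottom of this module
    and assume matching config files exist:

    >>> orchestrator = LLMOrchestrator(  # doctest: +SKIP
    ...     config_path="config/config.yaml",
    ...     prompt_config_path="config/prompt.yaml",
    ... )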
""" | |

    def __init__(
        self,
        config_path: Annotated[str, "Path to the configuration file"],
        prompt_config_path: Annotated[str, "Path to the prompt configuration file"],
        model_id: Annotated[str, "Language model identifier"] = "llama",
        cache_size: Annotated[int, "Cache size for the language model manager"] = 2,
    ):
        """
        Initializes the LLMOrchestrator with a language model manager and loads prompts.

        Parameters
        ----------
        config_path : str
            Path to the configuration file for the language model manager.
        prompt_config_path : str
            Path to the configuration file containing prompts for different tasks.
        model_id : str, optional
            Identifier of the model to use. Defaults to "llama".
        cache_size : int, optional
            Cache size for the language model manager. Defaults to 2.
        """
        self.manager = LanguageModelManager(config_path=config_path, cache_size=cache_size)
        self.model_id = model_id
        self.prompts = self._load_prompts(prompt_config_path)

    @staticmethod
    def _load_prompts(prompt_config_path: str) -> Dict[str, Dict[str, str]]:
        """
        Loads prompts from the prompt configuration file.

        Parameters
        ----------
        prompt_config_path : str
            Path to the prompt configuration file.

        Returns
        -------
        Dict[str, Dict[str, str]]
            A dictionary containing prompts for different tasks.
        """
        with open(prompt_config_path, encoding='utf-8') as f:
            prompts = yaml.safe_load(f)
        return prompts
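
    # Illustrative layout of the prompt YAML that `_load_prompts` returns. The
    # wording is an assumption; only the task names (used in the demo below),
    # the nested system/user keys, and the {system_context}/{user_context}
    # placeholders are implied by how `generate` consumes the config:
    #
    #   Classification:
    #     system: "You identify call-center roles. {system_context}"
    #     user: "Assign a role to each speaker in:\n{user_context}"
    #   SentimentAnalysis:
    #     system: "You rate conversation sentiment."
    #     user: "Analyze the sentiment of:\n{user_context}"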

    @staticmethod
    def extract_json(
        response: Annotated[str, "The response string to extract JSON from"]
    ) -> Annotated[Optional[Dict[str, Any]], "Extracted JSON as a dictionary or None if not found"]:
        """
        Extracts the last valid JSON object from a given response string.

        Parameters
        ----------
        response : str
            The response string to extract JSON from.

        Returns
        -------
        Optional[Dict[str, Any]]
            The last valid JSON dictionary if successfully extracted and parsed, otherwise None.
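
        Examples
        --------
        A quick sketch of the intended behavior (the last parseable object wins):

        >>> LLMOrchestrator.extract_json('noise {"a": 1} noise {"b": 2}')
        {'b': 2}
        >>> LLMOrchestrator.extract_json('no json here') is None
        True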
""" | |
        # Matches JSON objects with at most one level of nested braces; deeper
        # nesting would require a real parser.
        json_pattern = r'\{(?:[^{}]|(?:\{[^{}]*\}))*\}'
        matches = re.findall(json_pattern, response)
        for match in reversed(matches):
            try:
                return json.loads(match)
            except json.JSONDecodeError:
                continue
        return None

    async def generate(
        self,
        prompt_name: Annotated[str, "The name of the prompt to use (e.g., 'Classification', 'SentimentAnalysis')"],
        user_input: Annotated[Any, "The user's context or input data"],
        system_input: Annotated[Optional[Any], "The system's context or input data"] = None
    ) -> Annotated[Dict[str, Any], "Task results or error dictionary"]:
        """
        Performs the specified LLM task using the selected prompt, supporting
        both user and optional system contexts.

        Parameters
        ----------
        prompt_name : str
            The name of the prompt to use (e.g., 'Classification', 'SentimentAnalysis').
        user_input : Any
            The user's context or input data.
        system_input : Any, optional
            The system's context or input data. Defaults to None.

        Returns
        -------
        Dict[str, Any]
            The parsed task result, or a dictionary with an "error" key on failure.
        """
        if prompt_name not in self.prompts:
            return {"error": f"Prompt '{prompt_name}' is not defined in the prompt configuration."}
        system_prompt_template = self.prompts[prompt_name].get('system', '')
        user_prompt_template = self.prompts[prompt_name].get('user', '')
        if not system_prompt_template or not user_prompt_template:
            return {"error": f"Prompts for '{prompt_name}' are incomplete."}
        formatted_user_input = Formatter.format_ssm_as_dialogue(user_input)
        if system_input:
            system_prompt = system_prompt_template.format(system_context=system_input)
        else:
            system_prompt = system_prompt_template
        user_prompt = user_prompt_template.format(user_context=formatted_user_input)
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ]
        response = await self.manager.generate(
            model_id=self.model_id,
            messages=messages,
            max_new_tokens=10000,
        )
        print(response)
        dict_obj = self.extract_json(response)
        # Compare against None so an empty-but-valid JSON object ({}) is not
        # mistaken for a failure.
        if dict_obj is not None:
            return dict_obj
        return {"error": "No valid JSON object found in the response."}


class LLMResultHandler:
    """
    A handler class to process and validate the output from a large language model (LLM)
    and format structured data.

    This class ensures that the input data conforms to expected formats and applies
    fallback mechanisms to maintain data integrity.

    Methods
    -------
    validate_and_fallback(llm_result, ssm)
        Validates the LLM result against structured speaker metadata and applies fallback.
    _fallback(ssm)
        Applies fallback formatting to the speaker data.
    log_result(ssm, llm_result)
        Logs the final processed data and the original LLM result.
    """

    def __init__(self):
        """
        Initializes the LLMResultHandler class.
        """

    def validate_and_fallback(
        self,
        llm_result: Annotated[Dict[str, str], "LLM result with customer and CSR speaker identifiers"],
        ssm: Annotated[List[Dict[str, Any]], "List of sentences with speaker metadata"]
    ) -> Annotated[List[Dict[str, Any]], "Processed speaker metadata"]:
""" | |
Validates the LLM result and applies corrections to the speaker metadata. | |
Parameters | |
---------- | |
llm_result : dict | |
A dictionary containing speaker identifiers for 'Customer' and 'CSR'. | |
ssm : list of dict | |
A list of dictionaries where each dictionary represents a sentence with | |
metadata, including the 'speaker'. | |
Returns | |
------- | |
list of dict | |
The processed speaker metadata with standardized speaker labels. | |
Examples | |
-------- | |
>>> result = {"Customer": "Speaker 1", "CSR": "Speaker 2"} | |
>>> ssm_ = [{"speaker": "Speaker 1", "text": "Hello!"}, {"speaker": "Speaker 2", "text": "Hi!"}] | |
>>> handler = LLMResultHandler() | |
>>> handler.validate_and_fallback(llm_result, ssm) | |
[{'speaker': 'Customer', 'text': 'Hello!'}, {'speaker': 'CSR', 'text': 'Hi!'}] | |
""" | |
        if not isinstance(llm_result, dict):
            return self._fallback(ssm)
        if "Customer" not in llm_result or "CSR" not in llm_result:
            return self._fallback(ssm)
        customer_speaker = llm_result["Customer"]
        csr_speaker = llm_result["CSR"]
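        # Diarization labels are expected to look like "Speaker 1", "Speaker 12", etc.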
        speaker_pattern = r"^Speaker\s+\d+$"
        if (not re.match(speaker_pattern, customer_speaker)) or (not re.match(speaker_pattern, csr_speaker)):
            return self._fallback(ssm)
        ssm_speakers = {sentence["speaker"] for sentence in ssm}
        if customer_speaker not in ssm_speakers or csr_speaker not in ssm_speakers:
            return self._fallback(ssm)
        for sentence in ssm:
            if sentence["speaker"] == csr_speaker:
                sentence["speaker"] = "CSR"
            else:
                # Any speaker other than the CSR, including unmatched extras,
                # is treated as the customer.
                sentence["speaker"] = "Customer"
        return ssm

    @staticmethod
    def _fallback(
        ssm: Annotated[List[Dict[str, Any]], "List of sentences with speaker metadata"]
    ) -> Annotated[List[Dict[str, Any]], "Fallback speaker metadata"]:
""" | |
Applies fallback formatting to speaker metadata when validation fails. | |
Parameters | |
---------- | |
ssm : list of dict | |
A list of dictionaries representing sentences with speaker metadata. | |
Returns | |
------- | |
list of dict | |
The speaker metadata with fallback formatting applied. | |
Examples | |
-------- | |
>>> ssm_ = [{"speaker": "Speaker 1", "text": "Hello!"}, {"speaker": "Speaker 2", "text": "Hi!"}] | |
>>> handler = LLMResultHandler() | |
>>> handler._fallback(ssm) | |
[{'speaker': 'CSR', 'text': 'Hello!'}, {'speaker': 'Customer', 'text': 'Hi!'}] | |
""" | |
        if not ssm:
            return ssm
        first_speaker = ssm[0]["speaker"]
        for sentence in ssm:
            if sentence["speaker"] == first_speaker:
                sentence["speaker"] = "CSR"
            else:
                sentence["speaker"] = "Customer"
        return ssm

    @staticmethod
    def log_result(
        ssm: Annotated[List[Dict[str, Any]], "Final processed speaker metadata"],
        llm_result: Annotated[Dict[str, str], "Original LLM result"]
    ) -> None:
""" | |
Logs the final processed speaker metadata and the original LLM result. | |
Parameters | |
---------- | |
ssm : list of dict | |
The processed speaker metadata. | |
llm_result : dict | |
The original LLM result. | |
Returns | |
------- | |
None | |
Examples | |
-------- | |
>>> ssm_ = [{"speaker": "CSR", "text": "Hello!"}, {"speaker": "Customer", "text": "Hi!"}] | |
>>> result = {"Customer": "Speaker 1", "CSR": "Speaker 2"} | |
>>> handler = LLMResultHandler() | |
>>> handler.log_result(ssm, llm_result) | |
Final SSM: [{'speaker': 'CSR', 'text': 'Hello!'}, {'speaker': 'Customer', 'text': 'Hi!'}] | |
LLM Result: {'Customer': 'Speaker 1', 'CSR': 'Speaker 2'} | |
""" | |
print("Final SSM:", ssm) | |
print("LLM Result:", llm_result) | |


if __name__ == "__main__":
    # noinspection PyMissingOrEmptyDocstring
    async def main():
        handler = LLMOrchestrator(
            config_path="config/config.yaml",
            prompt_config_path="config/prompt.yaml",
            model_id="openai",
        )
        conversation = [
            {"speaker": "Speaker 1", "text": "Hello, I need help with my order."},
            {"speaker": "Speaker 0", "text": "Sure, I'd be happy to assist you."},
            {"speaker": "Speaker 1", "text": "I haven't received it yet."},
            {"speaker": "Speaker 0", "text": "Let me check the status for you."}
        ]
        speaker_roles = await handler.generate("Classification", conversation)
        print("Speaker Roles:", speaker_roles)
        print("Type:", type(speaker_roles))

        # Uses the default model_id ("llama") since none is specified.
        sentiment_analyzer = LLMOrchestrator(
            config_path="config/config.yaml",
            prompt_config_path="config/prompt.yaml"
        )
        sentiment = await sentiment_analyzer.generate("SentimentAnalysis", conversation)
        print("\nSentiment Analysis:", sentiment)

    asyncio.run(main())