import json
import os
import re

from langchain.chains import ConversationChain, LLMChain
from langchain.chains.base import Chain
from langchain.globals import get_debug
from langchain.prompts import PromptTemplate

from app_modules.llm_inference import LLMInference, get_system_prompt_and_user_message
from app_modules.utils import CustomizedConversationSummaryBufferMemory

chat_history_enabled = os.getenv("CHAT_HISTORY_ENABLED", "false").lower() == "true"
B_INST, E_INST = "[INST]", "[/INST]"


def create_llama_2_prompt_template():
    """Build a Llama-2 chat prompt ([INST] <<SYS>> ... <</SYS>> ... [/INST])."""
    B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"

    system_prompt, user_message = get_system_prompt_and_user_message()

    SYSTEM_PROMPT = B_SYS + system_prompt + E_SYS
    prompt_template = B_INST + SYSTEM_PROMPT + user_message + E_INST
    return prompt_template


def create_llama_3_prompt_template():
    """Build a Llama-3 chat prompt using the header/eot special tokens."""
    system_prompt, user_message = get_system_prompt_and_user_message()
    prompt_template = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>

{user_message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

    return prompt_template


def create_phi_3_prompt_template():
    """Build a Phi-3 chat prompt (<|system|>/<|user|>/<|assistant|>)."""
    system_prompt, user_message = get_system_prompt_and_user_message()
    prompt_template = f"""<|system|>
{system_prompt}<|end|>
<|user|>
{user_message}<|end|>
<|assistant|>
"""

    return prompt_template


def create_orca_2_prompt_template():
    """Build an Orca-2 prompt in ChatML format (<|im_start|> ... <|im_end|>)."""
    # Pass orca=True so the Orca-specific system prompt is returned.
    system_prompt, user_message = get_system_prompt_and_user_message(orca=True)

    prompt_template = f"<|im_start|>system\n{system_prompt}<|im_end|>\n<|im_start|>user\n{user_message}<|im_end|>\n<|im_start|>assistant"
    return prompt_template


def create_mistral_prompt_template():
    """Build a Mistral instruct prompt ([INST] ... [/INST], no system tags)."""
    system_prompt, user_message = get_system_prompt_and_user_message()

    prompt_template = B_INST + system_prompt + "\n\n" + user_message + E_INST
    return prompt_template


def create_gemma_prompt_template():
    """Build a Gemma chat prompt; Gemma has no system role, so only the user turn is templated."""
    return "<start_of_turn>user\n{input}<end_of_turn>\n<start_of_turn>model\n"


def create_prompt_template(model_name):
    """Pick the prompt template matching the model name, falling back to a generic Human/Chatbot template."""
    print(f"creating prompt template for model: {model_name}")
    if re.search(r"llama-?2", model_name, re.IGNORECASE):
        return create_llama_2_prompt_template()
    elif re.search(r"llama-?3", model_name, re.IGNORECASE):
        return create_llama_3_prompt_template()
    elif re.search(r"phi-?3", model_name, re.IGNORECASE):
        return create_phi_3_prompt_template()
    elif model_name.lower().startswith("orca"):
        return create_orca_2_prompt_template()
    elif model_name.lower().startswith("mistral"):
        return create_mistral_prompt_template()
    elif model_name.lower().startswith("gemma"):
        return create_gemma_prompt_template()

    return (
        """You are a chatbot having a conversation with a human.

{history}
Human: {input}
Chatbot:"""
        if chat_history_enabled
        else """You are a chatbot having a conversation with a human.

Human: {input}
Chatbot:"""
    )


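# Illustrative sketch of the dispatch above (the model names here are
# hypothetical examples, chosen only because they match the patterns):
#
#   create_prompt_template("Llama-2-7b-chat-hf")        -> Llama-2 [INST]/<<SYS>> format
#   create_prompt_template("Meta-Llama-3-8B-Instruct")  -> Llama-3 header-token format
#   create_prompt_template("Phi-3-mini-4k-instruct")    -> Phi-3 <|system|>/<|user|> format
#   create_prompt_template("some-other-model")          -> generic Human/Chatbot template
#
# Each returned string is assumed to contain an "{input}" placeholder (plus
# "{history}" for the generic template when chat history is enabled), which the
# PromptTemplate built in ChatChain.create_chain() below fills in at run time.

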
class ChatChain(LLMInference):
    def __init__(self, llm_loader):
        super().__init__(llm_loader)

    def create_chain(self) -> Chain:
        template = create_prompt_template(self.llm_loader.model_name)
        print(f"template: {template}")

        if chat_history_enabled:
            prompt = PromptTemplate(
                input_variables=["history", "input"], template=template
            )
            memory = CustomizedConversationSummaryBufferMemory(
                llm=self.llm_loader.llm, max_token_limit=1024, return_messages=False
            )

            llm_chain = ConversationChain(
                llm=self.llm_loader.llm,
                prompt=prompt,
                verbose=False,
                memory=memory,
            )
        else:
            prompt = PromptTemplate(input_variables=["input"], template=template)
            llm_chain = LLMChain(llm=self.llm_loader.llm, prompt=prompt)

        return llm_chain

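    # Usage sketch (an assumption: actual invocation is driven by the
    # LLMInference base class, not shown here). With chat history disabled,
    # the LLMChain built above could be called roughly as
    #
    #   chain = ChatChain(llm_loader).create_chain()
    #   result = chain.invoke({"input": "What is LangChain?"})
    #
    # where "input" matches the single variable of the prompt template and
    # llm_loader is whatever loader object LLMInference was constructed with.
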
    def _process_inputs(self, inputs):
        if not isinstance(inputs, list):
            # Single question: wrap it with the "input" key expected by the prompt.
            inputs = {"input": inputs["question"]}
        elif self.llm_loader.llm_model_type == "huggingface":
            # Batch of questions for a HuggingFace model: apply its chat template.
            inputs = [self.apply_chat_template(item["question"]) for item in inputs]
        else:
            # Batch of questions for any other backend: wrap each one individually.
            inputs = [{"input": i["question"]} for i in inputs]

        if get_debug():
            print("_process_inputs:", json.dumps(inputs, indent=4))

        return inputs
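

if __name__ == "__main__":
    # Minimal smoke-test sketch: print the prompt template selected for a few
    # illustrative model names. This only exercises the string-building code
    # above; no model is loaded, though app_modules must be importable and
    # get_system_prompt_and_user_message() must work in the current environment.
    for name in ["Llama-2-7b-chat-hf", "Meta-Llama-3-8B-Instruct", "gemma-1.1-7b-it"]:
        print(f"=== {name} ===")
        print(create_prompt_template(name))
        print()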