Spaces:
Runtime error
Runtime error
File size: 2,490 Bytes
ed4d993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import base64
import json
from langchain_community.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import Field
from langserve import CustomUserType
from .prompts import (
AI_REPONSE_DICT,
FULL_PROMPT,
USER_EXAMPLE_DICT,
create_prompt,
)
from .utils import parse_llm_output
# Chat model used by the chain; temperature 0 is requested from GPT-4.
llm = ChatOpenAI(model="gpt-4", temperature=0)

# Conversation layout sent to the model:
#   1. system message rendered from FULL_PROMPT,
#   2. one worked example (a human turn followed by the AI answer),
#   3. the human turn carrying the actual input.
prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(FULL_PROMPT),
        ("human", "{user_example}"),
        ("ai", "{ai_response}"),
        ("human", "{input}"),
    ],
)
# ATTENTION: Inherit from CustomUserType instead of BaseModel; otherwise
# the server will decode the payload into a dict instead of a pydantic model.
class FileProcessingRequest(CustomUserType):
    """Request including a base64 encoded file."""

    # Base64 text of the uploaded file. The ``extra`` entry is used to
    # specify a widget for the playground UI.
    file: str = Field(..., extra={"widget": {"type": "base64file"}})

    # Defaults to None even though it is annotated as ``int``.
    # NOTE(review): presumably None means "unknown number of plates" --
    # confirm against create_prompt, which receives this value unchanged.
    num_plates: int = None

    # Grid dimensions; the defaults are 8 rows by 12 columns.
    num_rows: int = 8
    num_cols: int = 12
def _load_file(request: FileProcessingRequest):
    """Return the text content of the file carried by ``request``.

    The ``file`` field is turned into UTF-8 bytes, base64-decoded, and the
    decoded bytes are then interpreted as UTF-8 text.

    Raises whatever ``base64.b64decode`` or ``bytes.decode`` raise when the
    payload is not valid base64 or the decoded bytes are not valid UTF-8.
    """
    encoded_payload = request.file.encode("utf-8")
    raw_content = base64.b64decode(encoded_payload)
    return raw_content.decode("utf-8")
def _load_prompt(request: FileProcessingRequest):
    """Return the result of ``create_prompt`` for the request's layout.

    The request's ``num_plates``, ``num_rows`` and ``num_cols`` are passed
    through unchanged as keyword arguments of the same names.
    """
    layout_kwargs = {
        "num_plates": request.num_plates,
        "num_rows": request.num_rows,
        "num_cols": request.num_cols,
    }
    return create_prompt(**layout_kwargs)
def _get_col_range_str(request: FileProcessingRequest):
    """Describe the column range of the request as text.

    Returns ``"from 1 to <num_cols>"``, or an empty string when
    ``num_cols`` is falsy (for example ``0`` or ``None``).
    """
    # Guard clause: nothing to describe without a column count.
    if not request.num_cols:
        return ""
    return f"from 1 to {request.num_cols}"
def _get_json_format(request: FileProcessingRequest):
    """Return a JSON string holding one example region for the request.

    The JSON is a one-element list whose entry starts at row 12 and
    column 1 and spans ``num_rows`` rows and ``num_cols`` columns (both
    end indices are inclusive), with ``"Entity ID"`` as its contents.
    """
    first_row = 12
    first_col = 1
    # Inclusive end indices: start + size - 1.
    last_row = first_row + request.num_rows - 1
    last_col = first_col + request.num_cols - 1

    example_region = {
        "row_start": first_row,
        "row_end": last_row,
        "col_start": first_col,
        "col_end": last_col,
        "contents": "Entity ID",
    }
    return json.dumps([example_region])
# End-to-end pipeline: derive the prompt variables from a
# FileProcessingRequest, render the prompt, call the model, reduce the reply
# to a string and pass that string to parse_llm_output.
chain = (
    {
        # Each value below is called with the incoming request; the results
        # are collected under these keys and handed to ``prompt``.
        # TODO: Should add validation to ensure numeric indices
        # Decoded text of the uploaded file; fills the final human message.
        "input": _load_file,
        # NOTE(review): "hint", "col_range_str" and "json_format" are not
        # referenced by the message tuples above -- presumably they fill
        # placeholders inside FULL_PROMPT; confirm in .prompts.
        "hint": _load_prompt,
        "col_range_str": _get_col_range_str,
        "json_format": _get_json_format,
        # The worked example and its answer are looked up by
        # num_rows * num_cols; a product that is not a key of the
        # dictionary raises KeyError.
        "user_example": lambda x: USER_EXAMPLE_DICT[x.num_rows * x.num_cols],
        "ai_response": lambda x: AI_REPONSE_DICT[x.num_rows * x.num_cols],
    }
    | prompt
    | llm
    | StrOutputParser()
    | parse_llm_output
    # Declare the input type explicitly so the chain is typed as taking a
    # FileProcessingRequest.
).with_types(input_type=FileProcessingRequest)
|