import base64
import json
from typing import Optional

from langchain_community.chat_models import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain_core.pydantic_v1 import Field
from langserve import CustomUserType

from .prompts import (
    AI_REPONSE_DICT,
    FULL_PROMPT,
    USER_EXAMPLE_DICT,
    create_prompt,
)
from .utils import parse_llm_output

llm = ChatOpenAI(temperature=0, model="gpt-4")

prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate.from_template(FULL_PROMPT),
        ("human", "{user_example}"),
        ("ai", "{ai_response}"),
        ("human", "{input}"),
    ],
)


# ATTENTION: Inherit from CustomUserType instead of BaseModel; otherwise,
# the server will decode the request into a dict instead of a pydantic model.
class FileProcessingRequest(CustomUserType):
    """Request including a base64-encoded file."""

    # The extra field is used to specify a widget for the playground UI.
    file: str = Field(..., extra={"widget": {"type": "base64file"}})
    # Optional: stays None when the client does not supply it.
    num_plates: Optional[int] = None
    num_rows: int = 8
    num_cols: int = 12
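

# Sketch (hypothetical client-side helper, not part of the original module):
# FileProcessingRequest.file expects the raw file as a base64 string, so a
# client would encode its file before sending it. The helper name and the
# sample content in the usage comment are made up. The decode shown in the
# usage comment mirrors _load_file and assumes the file is UTF-8 text.
import base64


def build_request_payload(text: str, num_rows: int = 8, num_cols: int = 12) -> dict:
    """Hypothetical helper: base64-encode file text into a request body."""
    encoded = base64.b64encode(text.encode("utf-8")).decode("utf-8")
    # Keys mirror the FileProcessingRequest fields; num_plates is omitted,
    # so the server-side default (None) applies.
    return {"file": encoded, "num_rows": num_rows, "num_cols": num_cols}


# Usage (sample content is made up):
#   payload = build_request_payload("A,1,2\nB,3,4\n")
#   base64.b64decode(payload["file"]).decode("utf-8")  # -> "A,1,2\nB,3,4\n"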


def _load_file(request: FileProcessingRequest) -> str:
    """Decode the base64 file payload and return its contents as UTF-8 text."""
    return base64.b64decode(request.file.encode("utf-8")).decode("utf-8")


def _load_prompt(request: FileProcessingRequest):
    """Build the `hint` prompt variable from the requested plate dimensions."""
    return create_prompt(
        num_plates=request.num_plates,
        num_rows=request.num_rows,
        num_cols=request.num_cols,
    )


def _get_col_range_str(request: FileProcessingRequest) -> str:
    """Describe the valid column range, or return "" when num_cols is unset."""
    if request.num_cols:
        return f"from 1 to {request.num_cols}"
    return ""


def _get_json_format(request: FileProcessingRequest) -> str:
    """Return a one-entry JSON example used as the prompt's `json_format` variable."""
    return json.dumps(
        [
            {
                "row_start": 12,
                "row_end": 12 + request.num_rows - 1,
                "col_start": 1,
                "col_end": 1 + request.num_cols - 1,
                "contents": "Entity ID",
            }
        ]
    )
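

# Sketch: the string _get_json_format produces for the default 8 x 12 plate.
# example_json_format is a hypothetical standalone re-implementation that
# takes plain integers instead of a FileProcessingRequest, so it can run
# without LangServe. The row_start value of 12 is copied from the original.
import json


def example_json_format(num_rows: int = 8, num_cols: int = 12) -> str:
    """Same dict layout as _get_json_format, parameterized by plain ints."""
    return json.dumps(
        [
            {
                "row_start": 12,
                "row_end": 12 + num_rows - 1,  # 19 for 8 rows
                "col_start": 1,
                "col_end": num_cols,  # equals 1 + num_cols - 1
                "contents": "Entity ID",
            }
        ]
    )


# example_json_format() ->
# '[{"row_start": 12, "row_end": 19, "col_start": 1, "col_end": 12, "contents": "Entity ID"}]'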


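chain = (
    # The dict is coerced into a RunnableParallel: each key is computed from
    # the same FileProcessingRequest and passed to the prompt as a variable.
    {
        # TODO: validate that num_rows and num_cols are numeric, since they
        # are used to build the dictionary keys below.
        "input": _load_file,
        "hint": _load_prompt,
        "col_range_str": _get_col_range_str,
        "json_format": _get_json_format,
        # Few-shot examples are keyed by plate size (num_rows * num_cols);
        # a size with no entry raises KeyError.
        "user_example": lambda x: USER_EXAMPLE_DICT[x.num_rows * x.num_cols],
        "ai_response": lambda x: AI_REPONSE_DICT[x.num_rows * x.num_cols],
    }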
    | prompt
    | llm
    | StrOutputParser()
    | parse_llm_output
).with_types(input_type=FileProcessingRequest)
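

# Sketch (hypothetical, not in the original module): the few-shot lookups
# USER_EXAMPLE_DICT[x.num_rows * x.num_cols] raise a bare KeyError for plate
# sizes without an example. A check like this one could run before invoking
# the chain to fail with a clearer message. The supported sizes are passed in
# as an argument because the actual keys of USER_EXAMPLE_DICT live in
# .prompts; validate_plate_dims is an illustrative name.
from typing import Collection


def validate_plate_dims(num_rows: int, num_cols: int, supported_sizes: Collection[int]) -> int:
    """Return num_rows * num_cols, or raise ValueError if it cannot be used."""
    for name, value in (("num_rows", num_rows), ("num_cols", num_cols)):
        # bool is a subclass of int, so reject it explicitly.
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    size = num_rows * num_cols
    if size not in supported_sizes:
        raise ValueError(
            f"No few-shot example for a {num_rows} x {num_cols} plate ({size} wells); "
            f"supported sizes: {sorted(supported_sizes)}"
        )
    return size


# Usage with made-up supported sizes:
#   validate_plate_dims(8, 12, {96, 384})  # -> 96
#   validate_plate_dims(5, 5, {96, 384})   # raises ValueError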