File size: 2,043 Bytes
ed4d993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# flake8: noqa
from __future__ import annotations

import re
from typing import List

from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.exceptions import OutputParserException

# One-shot prompt text: states the task, shows a single worked example whose
# answer ends in a fenced ```bash block, then leaves a ``{question}``
# placeholder for the caller's request. The example's fenced format is the
# same marker BashOutputParser.parse looks for in the model's reply.
_PROMPT_TEMPLATE = """If someone asks you to perform a task, your job is to come up with a series of bash commands that will perform the task. There is no need to put "#!/bin/bash" in your answer. Make sure to reason step by step, using this format:

Question: "copy the files in the directory named 'target' into a new directory at the same level as target called 'myNewDirectory'"

I need to take the following actions:
- List all files in the directory
- Create a new directory
- Copy the files from the first directory into the second directory
```bash
ls
mkdir myNewDirectory
cp -r target/* myNewDirectory
```

That is the format. Begin!

Question: {question}"""


class BashOutputParser(BaseOutputParser):
    """Extract bash commands from fenced ```bash blocks in model output."""

    def parse(self, text: str) -> List[str]:
        """Return the command lines found in the ```bash blocks of ``text``.

        Args:
            text: Raw text to scan for fenced bash code blocks.

        Returns:
            The non-blank lines of every matched block, in order of
            appearance. The list is empty when the opening marker is present
            but no block is closed by a fence on its own line.

        Raises:
            OutputParserException: If ``text`` contains no "```bash" marker.
        """
        # Guard clause: without an opening fence there is nothing to extract.
        if "```bash" not in text:
            raise OutputParserException(
                f"Failed to parse bash output. Got: {text}",
            )
        return self.get_code_blocks(text)

    @staticmethod
    def get_code_blocks(t: str) -> List[str]:
        """Collect the non-blank lines of every ```bash block in ``t``.

        Args:
            t: Text that may contain zero or more fenced bash code blocks.

        Returns:
            A flat list with one entry per non-blank line across all blocks.
            Leading indentation is kept; a trailing carriage return left over
            from CRLF line endings is removed so it never reaches the shell.
        """
        # Non-greedy body, closed by a fence that follows a newline (optionally
        # preceded by whitespace); DOTALL lets the body span multiple lines.
        fence = re.compile(r"```bash(.*?)(?:\n\s*)```", re.DOTALL)
        commands: List[str] = []
        for found in fence.finditer(t):
            body = found.group(1).strip()
            # An empty body splits into [""], which the blank-line filter drops.
            commands.extend(
                line.rstrip("\r") for line in body.split("\n") if line.strip()
            )
        return commands

    @property
    def _type(self) -> str:
        """Return the identifier string for this parser, ``"bash"``."""
        return "bash"


# Module-level prompt object: wraps the template text, declares its single
# ``question`` input, and attaches a BashOutputParser as the output parser.
PROMPT = PromptTemplate(
    template=_PROMPT_TEMPLATE,
    input_variables=["question"],
    output_parser=BashOutputParser(),
)