File size: 6,637 Bytes
256a159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
from typing import List

from mmpretrain.structures import DataSample


class MiniGPT4MMBenchPromptConstructor:
    """Build MiniGPT-4 prompts for the MMBench benchmark.

    The prompt is the space-separated concatenation of the image prompt,
    the sample's context (only when it is not ``None``), its question, its
    options and finally the reply prompt.

    Args:
        image_prompt (str): Text placed before the sample content.
            Defaults to `''`.
        reply_prompt (str): Text placed after the sample content.
            Defaults to `''`.
    """

    def __init__(self, image_prompt: str = '', reply_prompt: str = '') -> None:
        self.image_prompt = image_prompt
        self.reply_prompt = reply_prompt

    def __call__(self, inputs: dict) -> dict:
        """Attach the constructed prompt to ``inputs``.

        Args:
            inputs (dict): Input data; must contain a ``'data_samples'``
                entry.

        Returns:
            dict: The same ``inputs`` dict, updated in place with a
            ``'prompt'`` key.
        """
        inputs['prompt'] = self._process(inputs['data_samples'])
        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Turn a single data sample into a prompt string.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with ``question``, ``options`` and an optional
                ``context``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        sample = data_samples[0]
        question = sample.get('question')
        option = sample.get('options')
        context = sample.get('context')

        # Every piece is separated by exactly one space; the context piece
        # is simply left out when the sample carries none.
        pieces = [self.image_prompt]
        if context is not None:
            pieces.append(context)
        pieces.extend([question, option, self.reply_prompt])
        return ' '.join(pieces)


class MiniGPT4COCOCaotionPromptConstructor(MiniGPT4MMBenchPromptConstructor):
    """Prompt constructor for MiniGPT-4 on COCO Caption.

    The prompt does not depend on the sample contents: it is always
    ``image_prompt + ' a photo of' + reply_prompt``.

    Note:
        The class name keeps its historical misspelling ("Caotion") so
        existing configs keep working. ``MiniGPT4COCOCaptionPromptConstructor``
        is a correctly spelled alias for the same class.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Build the fixed captioning prompt.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample; its contents are not used.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        # No space is inserted before reply_prompt, so the result reads e.g.
        # '... a photo of###Assistant:' when reply_prompt starts with '###'.
        prompt = self.image_prompt + ' ' + 'a photo of' + self.reply_prompt
        return prompt


# Correctly spelled alias; the original (misspelled) name is kept for
# backward compatibility with existing configs.
MiniGPT4COCOCaptionPromptConstructor = MiniGPT4COCOCaotionPromptConstructor


class MiniGPT4ScienceQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
    """Prompt constructor for MiniGPT-4 on ScienceQA.

    The prompt is made of, in order:

    - ``image_prompt``;
    - a ``Context:`` line built from the sample's ``hint``;
    - a ``Question:`` line;
    - a ``Choices:`` line with the options labelled ``(A)``, ``(B)``, ...;
    - ``reply_prompt``, followed by ``' The answer is'``.
    """

    # Labels for the first six choices. Kept as a class attribute so code
    # that reads or overrides it keeps working; indices beyond this mapping
    # fall back to consecutive uppercase letters instead of raising KeyError.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def _choice_label(self, index: int) -> str:
        """Return the letter label used for the choice at ``index``."""
        return self.choice_mapping.get(index, chr(ord('A') + index))

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with ``question``, ``choices`` and ``hint``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        data_sample = data_samples[0]
        question = 'Question: ' + data_sample.get('question') + '\n'
        labelled_choices = [
            f'({self._choice_label(i)}) ' + item
            for i, item in enumerate(data_sample.get('choices'))
        ]
        choice = 'Choices: ' + ' '.join(labelled_choices) + '\n'
        # A missing hint is treated as empty rather than failing on
        # ``str + None``; non-None hints are used unchanged.
        hint = data_sample.get('hint')
        context = 'Context: ' + (hint if hint is not None else '') + '\n'
        prompt = (self.image_prompt + ' ' + context + ' ' + question + ' ' +
                  choice + self.reply_prompt + ' ' + 'The answer is')
        return prompt


class MiniGPT4VQAPromptConstructor(MiniGPT4MMBenchPromptConstructor):
    """Prompt constructor for MiniGPT-4 on VQA.

    The prompt is the image prompt, the question, a single-word answer
    instruction and the reply prompt, joined by single spaces.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Turn a single data sample into a prompt string.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with a ``question``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        query = data_samples[0].get('question')
        instruction = 'Answer this question in a single word.'
        return ' '.join(
            (self.image_prompt, query, instruction, self.reply_prompt))


class MiniGPT4VSRPromptConstructor(MiniGPT4MMBenchPromptConstructor):
    """Prompt constructor for MiniGPT-4 on VSR.

    The sample's ``question`` holds a description of the image; the prompt
    asks the model to judge it with a yes/no answer. Pieces are joined by
    single spaces.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Turn a single data sample into a prompt string.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with a ``question``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        description = data_samples[0].get('question')
        instruction = 'Is the above description correct? Answer yes or no.'
        return ' '.join(
            (self.image_prompt, description, instruction, self.reply_prompt))


class MiniGPT4SEEDBenchPromptConstructor(MiniGPT4MMBenchPromptConstructor):
    """Prompt constructor for MiniGPT-4 on SEED-Bench.

    The prompt is the image prompt, the question and the reply prompt,
    joined by single spaces.
    """

    def _process(self, data_samples: List[DataSample]) -> str:
        """Turn a single data sample into a prompt string.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with a ``question``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        query = data_samples[0].get('question')
        return ' '.join((self.image_prompt, query, self.reply_prompt))


class MiniGPT4MMEPromptConstructor:
    """Prompt constructor for MiniGPT-4 on MME.

    Builds a MiniGPT-4 conversation prompt of the form
    ``<system_prompt>###Human: <question> <Img><ImageHere></Img> ###Assistant: ``.
    The constructor takes no arguments; the system prompt and the ``###``
    separator are fixed.
    """

    def __init__(self) -> None:
        # The sentences are joined by implicit string concatenation, so each
        # one carries its own trailing space; without it the prompt would
        # read '...</Img>.You will...' and '...to you.Please...'.
        self.system_prompt = (
            'Give the following image: <Img>ImageContent</Img>. '
            'You will be able to see the image once I provide it to you. '
            'Please answer my questions.')
        self.sep = '###'

    def __call__(self, inputs: dict) -> dict:
        """Construct prompt.

        Args:
            inputs (dict): Input data; must contain a ``'data_samples'``
                entry.

        Returns:
            dict: The same ``inputs`` dict, updated in place with a
            ``'prompt'`` key.
        """
        data_samples = inputs['data_samples']
        prompt = self._process(data_samples)
        inputs.update({'prompt': prompt})

        return inputs

    def _process(self, data_samples: List[DataSample]) -> str:
        """Process data sample to prompt.

        Args:
            data_samples (List[DataSample]): A list holding exactly one
                data sample with a ``question``.

        Returns:
            str: Prompt.
        """
        assert len(data_samples) == 1, 'Only support batch size 1.'
        question = data_samples[0].get('question')
        prompt = self.system_prompt + self.sep
        prompt += ('Human: ' + question + ' ' + '<Img><ImageHere></Img>' +
                   ' ' + self.sep)
        prompt += 'Assistant: '
        return prompt