# File size: 4,558 Bytes
# Revision: 256a159
from typing import Optional
from mmpretrain.structures import DataSample
class OpenFlamingoMMBenchPromptConstructor:
    """MMBench prompt constructor for OpenFlamingo.

    The generated prompt has the layout
    ``[<context> ]<image><question> <options> Answer:``.
    """

    def __init__(self) -> None:
        """Create the constructor; it takes no configuration."""

    def __call__(self, data_samples: 'list[DataSample]') -> 'list[str]':
        """Construct the prompt for a single-sample batch.

        Args:
            data_samples (list[DataSample]): Batch holding exactly one
                sample. The sample is read through ``get`` for the keys
                ``question``, ``options`` and, optionally, ``context``.

        Returns:
            list[str]: One-element list with the raw text prompt.

        Raises:
            AssertionError: If the batch does not hold exactly one sample.
        """
        assert len(data_samples) == 1, (
            'only a batch size of 1 is supported, '
            f'got {len(data_samples)} samples')
        sample = data_samples[0]
        # Plain concatenation is kept on purpose: a missing question or
        # options entry fails loudly instead of rendering as "None".
        prompt = ('<image>' + sample.get('question') + ' ' +
                  sample.get('options') + ' Answer:')
        # Look the context up once; it is prepended only when present.
        context = sample.get('context')
        if context is not None:
            prompt = context + ' ' + prompt
        return [prompt]
class OpenFlamingoCaptionPromptConstructor:
    """Caption prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): Text-only exemplars placed in front of
            the query. ``None`` selects the two built-in exemplars; any
            string, including ``''`` for no exemplars, is used verbatim.
            Defaults to None.
    """

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string is honoured instead of silently replaced.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Output:A child holding a flowered umbrella and petting a yak.<|endofchunk|>'  # noqa
                'Output:The child is holding a brush close to his mouth.<|endofchunk|>'  # noqa
            )

    def __call__(self, data_samples: 'list[DataSample]') -> 'list[str]':
        """Construct the prompt for a single-sample batch.

        The prompt does not read anything from the sample itself; only the
        batch length is checked.

        Args:
            data_samples (list[DataSample]): Batch holding exactly one
                sample.

        Returns:
            list[str]: One-element list with the raw text prompt, i.e. the
            shot prompt followed by ``<image>Output:``.

        Raises:
            AssertionError: If the batch does not hold exactly one sample.
        """
        assert len(data_samples) == 1, (
            'only a batch size of 1 is supported, '
            f'got {len(data_samples)} samples')
        return [self.shot_prompt + '<image>Output:']
class OpenFlamingoVQAPromptConstructor:
    """VQA prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): Text-only exemplars placed in front of
            every query. ``None`` selects the two built-in exemplars; any
            string, including ``''`` for no exemplars, is used verbatim.
            Defaults to None.
    """

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string is honoured instead of silently replaced.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            self.shot_prompt = (
                'Question:Is the sky dark? Short Answer:yes<|endofchunk|>'  # noqa: E501
                'Question:What is on the white wall? Short Answer:pipe<|endofchunk|>'  # noqa: E501
            )

    def __call__(self, data_samples: 'list[DataSample]') -> 'list[str]':
        """Construct one prompt per sample in the batch.

        Args:
            data_samples (list[DataSample]): Input samples. Each sample is
                read through ``get`` for the key ``question``.

        Returns:
            list[str]: Raw text prompts, in the same order as the samples.
            An empty batch yields an empty list.
        """
        return [
            self.shot_prompt + '<image>Question:{} Short Answer:'.format(
                sample.get('question'))
            for sample in data_samples
        ]
class OpenFlamingoScienceQAPromptConstructor:
    """ScienceQA prompt constructor for OpenFlamingo.

    Args:
        shot_prompt (str, optional): Text-only exemplars placed in front of
            the query. ``None`` selects the two built-in exemplars; any
            string, including ``''`` for no exemplars, is used verbatim.
            Defaults to None.
    """

    # Maps a zero-based choice index to its option letter. Only six options
    # are covered; a longer choice list raises ``KeyError``.
    choice_mapping = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E', 5: 'F'}

    def __init__(self, shot_prompt: Optional[str] = None) -> None:
        # Compare against None rather than truthiness so that an explicit
        # empty string is honoured instead of silently replaced.
        if shot_prompt is not None:
            self.shot_prompt = shot_prompt
        else:
            # NOTE(review): the second exemplar's choices open with `'[`
            # instead of `['`. It is kept byte-identical because editing it
            # would change the text fed to the model - confirm whether the
            # quirk is intentional.
            self.shot_prompt = (
                "Context:Question:Which of these states is farthest north? Choices:['(A) West Virginia' '(B) Louisiana' '(C) Arizona' '(D) Oklahoma'] Answer with a single character: A<|endofchunk|>"  # noqa
                'Context:The diagrams below show two pure samples of gas in identical closed, rigid containers. Each colored ball represents one gas particle. Both samples have the same number of particles.'  # noqa
                "Question:Compare the average kinetic energies of the particles in each sample. Which sample has the higher temperature? Choices:'[(A) neither' '(B) sample A' '(C) sample B'] Answer with a single character: C<|endofchunk|>"  # noqa
            )

    def __call__(self, data_samples: 'list[DataSample]') -> 'list[str]':
        """Construct the prompt for a single-sample batch.

        Args:
            data_samples (list[DataSample]): Batch holding exactly one
                sample. The sample is read through ``get`` for the keys
                ``question``, ``choices`` and ``hint``.

        Returns:
            list[str]: One-element list with the raw text prompt.

        Raises:
            AssertionError: If the batch does not hold exactly one sample.
            KeyError: If the sample has more choices than
                ``choice_mapping`` has letters.
        """
        assert len(data_samples) == 1, (
            'only a batch size of 1 is supported, '
            f'got {len(data_samples)} samples')
        sample = data_samples[0]
        question = sample.get('question')
        labelled_choices = [
            f'({self.choice_mapping[idx]}) ' + text
            for idx, text in enumerate(sample.get('choices'))
        ]
        # A missing hint is rendered as an empty context, matching the first
        # built-in exemplar, rather than as the literal text "None".
        hint = sample.get('hint')
        if hint is None:
            hint = ''
        # The choices are interpolated through the list's own string form.
        query = ('<image>Context:{} Question:{} Choices:{}'
                 ' Answer with a single character:').format(
                     hint, question, labelled_choices)
        return [self.shot_prompt + query]
|