katuni4ka committed
Commit 90996d5 (verified) · Parent: 10ddc7b

code compatibility with python3.9

Files changed (1)
  1. processing_maira2.py +48 -48
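For context: on Python 3.9 the original module fails at import time, because `TypeAlias` is only exported by `typing` from Python 3.10 onwards, and PEP 604 unions such as `str | None` are evaluated eagerly in the module-level alias assignments and in function signature annotations. A minimal sketch of the failure mode, illustrative only and not part of the repository, where each statement fails when executed on Python 3.9:

# ImportError on Python 3.9: typing.TypeAlias was added in Python 3.10 (PEP 613).
from typing import TypeAlias

# Even without TypeAlias, the PEP 604 union below is evaluated at runtime,
# so Python 3.9 raises TypeError: unsupported operand type(s) for |.
SingleChatMessageType = dict[str, str | int | None]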
processing_maira2.py CHANGED
@@ -3,7 +3,7 @@
 
 
 import re
-from typing import Any, TypeAlias, Union, List
+from typing import Any, Union, List
 
 import numpy as np
 from PIL import Image
@@ -14,9 +14,9 @@ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
 from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
 from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
 
-SingleChatMessageType: TypeAlias = dict[str, str | int | None]
-ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
-BoxType: TypeAlias = tuple[float, float, float, float]
+# SingleChatMessageType: TypeAlias = dict[str, str | int | None]
+# ChatMessageListType: TypeAlias = list[dict[str, str | list[SingleChatMessageType]]]
+# BoxType: TypeAlias = tuple[float, float, float, float]
 
 
 class Maira2Processor(LlavaProcessor):
@@ -55,9 +55,9 @@ class Maira2Processor(LlavaProcessor):
         self,
         image_processor: BaseImageProcessor = None,
         tokenizer: PreTrainedTokenizer = None,
-        patch_size: int | None = None,
-        vision_feature_select_strategy: str | None = None,
-        chat_template: str | None = None,
+        patch_size = None,
+        vision_feature_select_strategy = None,
+        chat_template = None,
         image_token: str = "<image>",
         phrase_start_token: str = "<obj>",
         phrase_end_token: str = "</obj>",
@@ -106,9 +106,9 @@ class Maira2Processor(LlavaProcessor):
     def _normalize_and_stack_images(
         self,
         current_frontal: Image.Image,
-        current_lateral: Image.Image | None,
-        prior_frontal: Image.Image | None,
-    ) -> list[Image.Image]:
+        current_lateral: Image.Image,
+        prior_frontal: Image.Image,
+    ):
         """
         This function normalizes the input images and stacks them together. The images are stacked in the order of
         current_frontal, current_lateral, and prior_frontal. The order of images is important, since it must match the
@@ -133,7 +133,7 @@ class Maira2Processor(LlavaProcessor):
         return images
 
     @staticmethod
-    def _get_section_text_or_missing_text(section: str | None) -> str:
+    def _get_section_text_or_missing_text(section: str) -> str:
         """
         This function returns the input section text if it is not None and not empty, otherwise it returns a missing
         section text "N/A".
@@ -151,7 +151,7 @@ class Maira2Processor(LlavaProcessor):
         return section
 
     @staticmethod
-    def _construct_image_chat_messages_for_reporting(has_prior: bool, has_lateral: bool) -> list[SingleChatMessageType]:
+    def _construct_image_chat_messages_for_reporting(has_prior: bool, has_lateral: bool):
         """
         This function constructs user chat messages based on the presence of the prior and lateral images.
 
@@ -187,7 +187,7 @@ class Maira2Processor(LlavaProcessor):
                 ]
             )
 
-        image_prompt: list[SingleChatMessageType] = []
+        image_prompt = []
         image_index = 0
         if not has_prior and not has_lateral:
             _add_single_image_to_chat_messages("Given the current frontal image only", image_index)
@@ -208,13 +208,13 @@ class Maira2Processor(LlavaProcessor):
         self,
         has_prior: bool,
         has_lateral: bool,
-        indication: str | None,
-        technique: str | None,
-        comparison: str | None,
-        prior_report: str | None,
+        indication: str,
+        technique: str,
+        comparison: str,
+        prior_report: str,
         get_grounding: bool = False,
-        assistant_text: str | None = None,
-    ) -> ChatMessageListType:
+        assistant_text: str = None,
+    ):
         """
         This function constructs the chat messages for reporting used in the grounded and non-grounded reporting tasks.
 
@@ -299,14 +299,14 @@ class Maira2Processor(LlavaProcessor):
                     "type": "text",
                 }
             )
-        messages: ChatMessageListType = [{"content": prompt, "role": "user"}]
+        messages = [{"content": prompt, "role": "user"}]
         if assistant_text is not None:
             messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
         return messages
 
     def _construct_chat_messages_phrase_grounding(
-        self, phrase: str, assistant_text: str | None = None
-    ) -> ChatMessageListType:
+        self, phrase: str, assistant_text: str = None
+    ):
         """
         This function constructs the chat messages for phrase grounding used in the phrase grounding task.
 
@@ -319,7 +319,7 @@ class Maira2Processor(LlavaProcessor):
         Returns:
             ChatMessageListType: The chat messages for phrase grounding in the form of a list of dictionaries.
         """
-        prompt: list[SingleChatMessageType] = [
+        prompt = [
            {"index": None, "text": "Given the current frontal image", "type": "text"},
            {"index": 0, "text": None, "type": "image"},
            {
@@ -329,7 +329,7 @@ class Maira2Processor(LlavaProcessor):
                "type": "text",
            },
        ]
-        messages: ChatMessageListType = [{"content": prompt, "role": "user"}]
+        messages = [{"content": prompt, "role": "user"}]
         if assistant_text is not None:
             messages.append({"content": [{"index": None, "text": assistant_text, "type": "text"}], "role": "assistant"})
         return messages
@@ -337,15 +337,15 @@ class Maira2Processor(LlavaProcessor):
     def format_reporting_input(
         self,
         current_frontal: Image.Image,
-        current_lateral: Image.Image | None,
-        prior_frontal: Image.Image | None,
-        indication: str | None,
-        technique: str | None,
-        comparison: str | None,
-        prior_report: str | None,
+        current_lateral: Image.Image,
+        prior_frontal: Image.Image,
+        indication: str,
+        technique: str,
+        comparison: str,
+        prior_report: str,
         get_grounding: bool = False,
-        assistant_text: str | None = None,
-    ) -> tuple[str, list[Image.Image]]:
+        assistant_text: str,
+    ):
         """
         This function formats the reporting prompt for the grounded and non-grounded reporting tasks from the given
         input images and text sections. The images are normalized and stacked together in the right order.
@@ -395,8 +395,8 @@ class Maira2Processor(LlavaProcessor):
         self,
         frontal_image: Image.Image,
         phrase: str,
-        assistant_text: str | None = None,
-    ) -> tuple[str, list[Image.Image]]:
+        assistant_text: str = None,
+    ):
         """
         This function formats the phrase grounding prompt for the phrase grounding task from the given input
         image and phrase.
@@ -425,14 +425,14 @@ class Maira2Processor(LlavaProcessor):
     def format_and_preprocess_reporting_input(
         self,
         current_frontal: Image.Image,
-        current_lateral: Image.Image | None,
-        prior_frontal: Image.Image | None,
-        indication: str | None,
-        technique: str | None,
-        comparison: str | None,
-        prior_report: str | None,
+        current_lateral: Image.Image,
+        prior_frontal: Image.Image,
+        indication: str,
+        technique: str,
+        comparison: str,
+        prior_report: str,
         get_grounding: bool = False,
-        assistant_text: str | None = None,
+        assistant_text: str = None,
         **kwargs: Any,
     ) -> BatchFeature:
         """
@@ -481,7 +481,7 @@ class Maira2Processor(LlavaProcessor):
         self,
         frontal_image: Image.Image,
         phrase: str,
-        assistant_text: str | None = None,
+        assistant_text: str = None,
         **kwargs: Any,
     ) -> BatchFeature:
         """
@@ -507,7 +507,7 @@ class Maira2Processor(LlavaProcessor):
         )
         return self(text=text, images=images, **kwargs)
 
-    def _get_text_between_delimiters(self, text: str, begin_token: str, end_token: str) -> list[str]:
+    def _get_text_between_delimiters(self, text: str, begin_token: str, end_token: str):
         """
         This function splits the input text into a list of substrings beased on the given begin and end tokens.
 
@@ -544,7 +544,7 @@ class Maira2Processor(LlavaProcessor):
 
     def convert_output_to_plaintext_or_grounded_sequence(
         self, text: str
-    ) -> str | list[tuple[str, list[BoxType] | None]]:
+    ):
         """
         This function converts the input text to a grounded sequence by extracting the grounded phrases and bounding
         boxes from the text. If the text is plaintext without any grounded phrases, it returns the text as is.
@@ -584,7 +584,7 @@ class Maira2Processor(LlavaProcessor):
 
         # One or more grounded phrases
         grounded_phrase_texts = self._get_text_between_delimiters(text, self.phrase_start_token, self.phrase_end_token)
-        grounded_phrases: list[tuple[str, list[BoxType] | None]] = []
+        grounded_phrases = []
        for grounded_phrase_text in grounded_phrase_texts:
            if self.box_start_token in grounded_phrase_text or self.box_end_token in grounded_phrase_text:
                first_box_start_index = grounded_phrase_text.find(self.box_start_token)
@@ -593,14 +593,14 @@ class Maira2Processor(LlavaProcessor):
                boxes_text_list = self._get_text_between_delimiters(
                    boxes_text, self.box_start_token, self.box_end_token
                )
-                boxes: list[BoxType] = []
+                boxes = []
                for box_text in boxes_text_list:
                    # extract from <x_><y_><x_><y_>
                    regex = r"<x(\d+?)><y(\d+?)><x(\d+?)><y(\d+?)>"
                    match = re.search(regex, box_text)
                    if match:
                        x_min, y_min, x_max, y_max = match.groups()
-                        box: BoxType = tuple(  # type: ignore[assignment]
+                        box = tuple(  # type: ignore[assignment]
                            (int(coord) + 0.5) / self.num_box_coord_bins for coord in (x_min, y_min, x_max, y_max)
                        )
                        assert all(0 <= coord <= 1 for coord in box), f"Invalid box coordinates: {box}"
@@ -613,7 +613,7 @@ class Maira2Processor(LlavaProcessor):
        return grounded_phrases
 
     @staticmethod
-    def adjust_box_for_original_image_size(box: BoxType, width: int, height: int) -> BoxType:
+    def adjust_box_for_original_image_size(box, width: int, height: int):
         """
         This function adjusts the bounding boxes from the MAIRA-2 model output to account for the image processor
         cropping the image to be square prior to the model forward pass. The box coordinates are adjusted to be
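
The commit restores Python 3.9 compatibility by dropping the affected parameter and return annotations and commenting out the type aliases. An alternative sketch that keeps equivalent type information while remaining importable on Python 3.9 (hypothetical, not what this commit does) would spell the unions with `typing.Optional`/`Union`; note that PEP 585 builtin generics such as `dict[...]`, `list[...]`, and `tuple[...]` already work on 3.9, so only the `|` unions and the `TypeAlias` marker need replacing:

from typing import Dict, List, Optional, Union

# Python 3.9-compatible spellings of the aliases commented out above (illustrative only).
SingleChatMessageType = Dict[str, Optional[Union[str, int]]]
ChatMessageListType = List[Dict[str, Union[str, List[SingleChatMessageType]]]]
BoxType = tuple[float, float, float, float]  # builtin generics are fine on 3.9


def _get_section_text_or_missing_text(section: Optional[str]) -> str:
    # Sketch of the documented behaviour: fall back to "N/A" for None or empty sections.
    return section if section else "N/A"

Either spelling imports cleanly on Python 3.9; the committed version simply trades the type hints for fewer moving parts.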