" for i in range(NUM_SENTINEL_TOKENS)])
_sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
tokenizer.sentinel_token_ids = _sentinel_token_ids
+
class AutoTokenizerForMOD(AutoTokenizer):
"""AutoTokenizer + Adaptation for MOD.
@@ -38,4 +43,4 @@ class AutoTokenizerForMOD(AutoTokenizer):
"""See `AutoTokenizer.from_pretrained` docstring."""
tokenizer = super().from_pretrained(*args, **kwargs)
adapt_tokenizer_for_denoising(tokenizer)
- return tokenizer
\ No newline at end of file
+ return tokenizer
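A minimal usage sketch of the AutoTokenizerForMOD class touched above. The module path (model/llava/model/mpt/adapt_tokenizer.py) and the checkpoint name are assumptions for illustration, not taken from this diff.

from model.llava.model.mpt.adapt_tokenizer import AutoTokenizerForMOD  # path assumed

# from_pretrained() defers to AutoTokenizer and then runs
# adapt_tokenizer_for_denoising(), which registers the sentinel tokens and
# stores their ids on the tokenizer instance.
tokenizer = AutoTokenizerForMOD.from_pretrained("mosaicml/mpt-7b")  # checkpoint assumed
print(len(tokenizer.sentinel_token_ids))  # ids of the sentinel tokens registered for denoising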
diff --git a/model/llava/model/mpt/attention.py b/model/llava/model/mpt/attention.py
index 2ca1069cd14ca055d918fa623d7da5efb4c5fd89..b4dad928098484ef7c287b5d7da7d95d5ff5ffee 100644
--- a/model/llava/model/mpt/attention.py
+++ b/model/llava/model/mpt/attention.py
@@ -2,24 +2,45 @@
import math
import warnings
from typing import Optional
+
import torch
import torch.nn as nn
from einops import rearrange
from torch import nn
+
from .norm import LPLayerNorm
-def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
+
+def _reset_is_causal(
+ num_query_tokens: int, num_key_tokens: int, original_is_causal: bool
+):
if original_is_causal and num_query_tokens != num_key_tokens:
if num_query_tokens != 1:
- raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
+ raise NotImplementedError(
+ "MPT does not support query and key with different number of tokens, unless number of query tokens is 1."
+ )
else:
return False
return original_is_causal
-def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
- q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
+
+def scaled_multihead_dot_product_attention(
+ query,
+ key,
+ value,
+ n_heads,
+ softmax_scale=None,
+ attn_bias=None,
+ key_padding_mask=None,
+ is_causal=False,
+ dropout_p=0.0,
+ training=False,
+ needs_weights=False,
+ multiquery=False,
+):
+ q = rearrange(query, "b s (h d) -> b h s d", h=n_heads)
+ k = rearrange(key, "b s (h d) -> b h d s", h=1 if multiquery else n_heads)
+ v = rearrange(value, "b s (h d) -> b h s d", h=1 if multiquery else n_heads)
min_val = torch.finfo(q.dtype).min
(b, _, s_q, d) = q.shape
s_k = k.size(-1)
@@ -27,13 +48,27 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
softmax_scale = 1 / math.sqrt(d)
attn_weight = q.matmul(k) * softmax_scale
if attn_bias is not None:
- if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
- raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
+ if (
+ attn_bias.size(-1) != 1
+ and attn_bias.size(-1) != s_k
+ or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q)
+ ):
+ raise RuntimeError(
+ f"attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}."
+ )
attn_weight = attn_weight + attn_bias
if key_padding_mask is not None:
if attn_bias is not None:
- warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
- attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
+ warnings.warn(
+                "Propagating key_padding_mask to the attention module "
+                + "and applying it within the attention module can cause "
+                + "unnecessary computation/memory usage. Consider integrating "
+ + "into attn_bias once and passing that to each attention "
+ + "module instead."
+ )
+ attn_weight = attn_weight.masked_fill(
+ ~key_padding_mask.view((b, 1, 1, s_k)), min_val
+ )
if is_causal:
s = max(s_q, s_k)
causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
@@ -44,74 +79,146 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
attn_weight = torch.softmax(attn_weight, dim=-1)
if dropout_p:
- attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
+ attn_weight = torch.nn.functional.dropout(
+ attn_weight, p=dropout_p, training=training, inplace=True
+ )
out = attn_weight.matmul(v)
- out = rearrange(out, 'b h s d -> b s (h d)')
+ out = rearrange(out, "b h s d -> b s (h d)")
if needs_weights:
return (out, attn_weight)
return (out, None)
+
def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
for tensor in tensors:
if tensor.dtype not in valid_dtypes:
- raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
+ raise TypeError(
+ f"tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}."
+ )
if not tensor.is_cuda:
- raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
+ raise TypeError(
+ f"Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r})."
+ )
+
-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def flash_attn_fn(
+ query,
+ key,
+ value,
+ n_heads,
+ softmax_scale=None,
+ attn_bias=None,
+ key_padding_mask=None,
+ is_causal=False,
+ dropout_p=0.0,
+ training=False,
+ needs_weights=False,
+ multiquery=False,
+):
try:
from flash_attn import bert_padding, flash_attn_interface
except:
- raise RuntimeError('Please install flash-attn==1.0.3.post0')
+ raise RuntimeError("Please install flash-attn==1.0.3.post0")
check_valid_inputs(query, key, value)
if attn_bias is not None:
- raise NotImplementedError(f'attn_bias not implemented for flash attn.')
+ raise NotImplementedError(f"attn_bias not implemented for flash attn.")
(batch_size, seqlen) = query.shape[:2]
if key_padding_mask is None:
key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
- query_padding_mask = key_padding_mask[:, -query.size(1):]
- (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
- query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
- (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
+ query_padding_mask = key_padding_mask[:, -query.size(1) :]
+ (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(
+ query, query_padding_mask
+ )
+ query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=n_heads)
+ (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(
+ key, key_padding_mask
+ )
+ key_unpad = rearrange(
+ key_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
+ )
(value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
+ value_unpad = rearrange(
+ value_unpad, "nnz (h d) -> nnz h d", h=1 if multiquery else n_heads
+ )
if multiquery:
key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
+ value_unpad = value_unpad.expand(
+ value_unpad.size(0), n_heads, value_unpad.size(-1)
+ )
dropout_p = dropout_p if training else 0.0
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
- output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
- output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
+ output_unpad = flash_attn_interface.flash_attn_unpadded_func(
+ query_unpad,
+ key_unpad,
+ value_unpad,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ max_seqlen_q,
+ max_seqlen_k,
+ dropout_p,
+ softmax_scale=softmax_scale,
+ causal=reset_is_causal,
+ return_attn_probs=needs_weights,
+ )
+ output = bert_padding.pad_input(
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices_q, batch_size, seqlen
+ )
return (output, None)
-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+
+def triton_flash_attn_fn(
+ query,
+ key,
+ value,
+ n_heads,
+ softmax_scale=None,
+ attn_bias=None,
+ key_padding_mask=None,
+ is_causal=False,
+ dropout_p=0.0,
+ training=False,
+ needs_weights=False,
+ multiquery=False,
+):
try:
from flash_attn import flash_attn_triton
except:
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
+ raise RuntimeError(
+ "Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202"
+ )
check_valid_inputs(query, key, value)
if dropout_p:
- raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
+ raise NotImplementedError(f"Dropout not implemented for attn_impl: triton.")
if needs_weights:
- raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
+ raise NotImplementedError(f"attn_impl: triton cannot return attn weights.")
if key_padding_mask is not None:
- warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
+ warnings.warn(
+ "Propagating key_padding_mask to the attention module "
+ + "and applying it within the attention module can cause "
+ + "unnecessary computation/memory usage. Consider integrating "
+ + "into attn_bias once and passing that to each attention "
+ + "module instead."
+ )
(b_size, s_k) = key_padding_mask.shape[:2]
if attn_bias is None:
attn_bias = query.new_zeros(b_size, 1, 1, s_k)
- attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
- query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
- key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
- value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
+ attn_bias = attn_bias.masked_fill(
+ ~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min
+ )
+ query = rearrange(query, "b s (h d) -> b s h d", h=n_heads)
+ key = rearrange(key, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
+ value = rearrange(value, "b s (h d) -> b s h d", h=1 if multiquery else n_heads)
if multiquery:
key = key.expand(*key.shape[:2], n_heads, key.size(-1))
value = value.expand(*value.shape[:2], n_heads, value.size(-1))
reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
+ attn_output = flash_attn_triton.flash_attn_func(
+ query, key, value, attn_bias, reset_is_causal, softmax_scale
+ )
output = attn_output.view(*attn_output.shape[:2], -1)
return (output, None)
+
class MultiheadAttention(nn.Module):
"""Multi-head self attention.
@@ -119,7 +226,18 @@ class MultiheadAttention(nn.Module):
additive bias.
"""
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+ def __init__(
+ self,
+ d_model: int,
+ n_heads: int,
+ attn_impl: str = "triton",
+ clip_qkv: Optional[float] = None,
+ qk_ln: bool = False,
+ softmax_scale: Optional[float] = None,
+ attn_pdrop: float = 0.0,
+ low_precision_layernorm: bool = False,
+ device: Optional[str] = None,
+ ):
super().__init__()
self.attn_impl = attn_impl
self.clip_qkv = clip_qkv
@@ -137,21 +255,38 @@ class MultiheadAttention(nn.Module):
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
self.q_ln = layernorm_class(self.d_model, device=device)
self.k_ln = layernorm_class(self.d_model, device=device)
- if self.attn_impl == 'flash':
+ if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
- elif self.attn_impl == 'triton':
+ elif self.attn_impl == "triton":
self.attn_fn = triton_flash_attn_fn
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
- elif self.attn_impl == 'torch':
+ warnings.warn(
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
+ + "it uses more memory. When training larger models this can trigger "
+ + "alloc retries which hurts performance. If encountered, we recommend "
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
+ )
+ elif self.attn_impl == "torch":
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available():
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
+ warnings.warn(
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
+ + "we recommend using `attn_impl: triton`."
+ )
else:
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj._is_residual = True
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
+ def forward(
+ self,
+ x,
+ past_key_value=None,
+ attn_bias=None,
+ attention_mask=None,
+ is_causal=True,
+ needs_weights=False,
+ ):
qkv = self.Wqkv(x)
if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
@@ -167,10 +302,23 @@ class MultiheadAttention(nn.Module):
value = torch.cat([past_key_value[1], value], dim=1)
past_key_value = (key, value)
if attn_bias is not None:
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
+ (context, attn_weights) = self.attn_fn(
+ query,
+ key,
+ value,
+ self.n_heads,
+ softmax_scale=self.softmax_scale,
+ attn_bias=attn_bias,
+ key_padding_mask=key_padding_mask,
+ is_causal=is_causal,
+ dropout_p=self.attn_dropout_p,
+ training=self.training,
+ needs_weights=needs_weights,
+ )
return (self.out_proj(context), attn_weights, past_key_value)
+
class MultiQueryAttention(nn.Module):
"""Multi-Query self attention.
@@ -178,7 +326,18 @@ class MultiQueryAttention(nn.Module):
additive bias.
"""
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
+ def __init__(
+ self,
+ d_model: int,
+ n_heads: int,
+ attn_impl: str = "triton",
+ clip_qkv: Optional[float] = None,
+ qk_ln: bool = False,
+ softmax_scale: Optional[float] = None,
+ attn_pdrop: float = 0.0,
+ low_precision_layernorm: bool = False,
+ device: Optional[str] = None,
+ ):
super().__init__()
self.attn_impl = attn_impl
self.clip_qkv = clip_qkv
@@ -197,25 +356,44 @@ class MultiQueryAttention(nn.Module):
layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
self.q_ln = layernorm_class(d_model, device=device)
self.k_ln = layernorm_class(self.head_dim, device=device)
- if self.attn_impl == 'flash':
+ if self.attn_impl == "flash":
self.attn_fn = flash_attn_fn
- elif self.attn_impl == 'triton':
+ elif self.attn_impl == "triton":
self.attn_fn = triton_flash_attn_fn
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
- elif self.attn_impl == 'torch':
+ warnings.warn(
+ "While `attn_impl: triton` can be faster than `attn_impl: flash` "
+ + "it uses more memory. When training larger models this can trigger "
+ + "alloc retries which hurts performance. If encountered, we recommend "
+ + "using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`."
+ )
+ elif self.attn_impl == "torch":
self.attn_fn = scaled_multihead_dot_product_attention
if torch.cuda.is_available():
- warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
+ warnings.warn(
+ "Using `attn_impl: torch`. If your model does not use `alibi` or "
+ + "`prefix_lm` we recommend using `attn_impl: flash` otherwise "
+ + "we recommend using `attn_impl: triton`."
+ )
else:
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
self.out_proj._is_residual = True
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
+ def forward(
+ self,
+ x,
+ past_key_value=None,
+ attn_bias=None,
+ attention_mask=None,
+ is_causal=True,
+ needs_weights=False,
+ ):
qkv = self.Wqkv(x)
if self.clip_qkv:
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
- (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
+ (query, key, value) = qkv.split(
+ [self.d_model, self.head_dim, self.head_dim], dim=2
+ )
key_padding_mask = attention_mask
if self.qk_ln:
dtype = query.dtype
@@ -227,14 +405,30 @@ class MultiQueryAttention(nn.Module):
value = torch.cat([past_key_value[1], value], dim=1)
past_key_value = (key, value)
if attn_bias is not None:
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
+ attn_bias = attn_bias[:, :, -query.size(1) :, -key.size(1) :]
+ (context, attn_weights) = self.attn_fn(
+ query,
+ key,
+ value,
+ self.n_heads,
+ softmax_scale=self.softmax_scale,
+ attn_bias=attn_bias,
+ key_padding_mask=key_padding_mask,
+ is_causal=is_causal,
+ dropout_p=self.attn_dropout_p,
+ training=self.training,
+ needs_weights=needs_weights,
+ multiquery=True,
+ )
return (self.out_proj(context), attn_weights, past_key_value)
-def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
- if attn_impl == 'flash':
+
+def attn_bias_shape(
+ attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id
+):
+ if attn_impl == "flash":
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl in ["torch", "triton"]:
if alibi:
if (prefix_lm or not causal) or use_sequence_id:
return (1, n_heads, seq_len, seq_len)
@@ -243,18 +437,31 @@ def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_s
return (1, 1, seq_len, seq_len)
return None
else:
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
-def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
- if attn_impl == 'flash':
+
+def build_attn_bias(
+ attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8
+):
+ if attn_impl == "flash":
return None
- elif attn_impl in ['torch', 'triton']:
+ elif attn_impl in ["torch", "triton"]:
if alibi:
(device, dtype) = (attn_bias.device, attn_bias.dtype)
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
+ attn_bias = attn_bias.add(
+ build_alibi_bias(
+ n_heads,
+ seq_len,
+ full=not causal,
+ alibi_bias_max=alibi_bias_max,
+ device=device,
+ dtype=dtype,
+ )
+ )
return attn_bias
else:
- raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
+ raise ValueError(f"attn_impl={attn_impl!r} is an invalid setting.")
+
def gen_slopes(n_heads, alibi_bias_max=8, device=None):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
@@ -265,12 +472,24 @@ def gen_slopes(n_heads, alibi_bias_max=8, device=None):
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
-def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
- alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
+
+def build_alibi_bias(
+ n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None
+):
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(
+ 1, 1, 1, seq_len
+ )
if full:
- alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
+ alibi_bias = alibi_bias - torch.arange(
+ 1 - seq_len, 1, dtype=torch.int32, device=device
+ ).view(1, 1, seq_len, 1)
alibi_bias = alibi_bias.abs().mul(-1)
slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
alibi_bias = alibi_bias * slopes
return alibi_bias.to(dtype=dtype)
-ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
\ No newline at end of file
+
+
+ATTN_CLASS_REGISTRY = {
+ "multihead_attention": MultiheadAttention,
+ "multiquery_attention": MultiQueryAttention,
+}
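A minimal sketch of how the reformatted helpers above compose on the attn_impl="torch" path: attn_bias_shape sizes an ALiBi bias, build_attn_bias fills it, and scaled_multihead_dot_product_attention consumes flat (batch, seq, n_heads * head_dim) tensors. The import path and all sizes are illustrative assumptions.

import torch

from model.llava.model.mpt.attention import (  # import path assumed
    attn_bias_shape,
    build_attn_bias,
    scaled_multihead_dot_product_attention,
)

batch, seq_len, n_heads, d_model = 2, 128, 8, 512

# ALiBi with plain causal attention needs a (1, n_heads, 1, seq_len) bias.
shape = attn_bias_shape(
    "torch", n_heads, seq_len, alibi=True, prefix_lm=False, causal=True, use_sequence_id=False
)
attn_bias = build_attn_bias(
    "torch", torch.zeros(shape), n_heads, seq_len, causal=True, alibi=True
)

# Query/key/value stay flattened, exactly as MultiheadAttention.forward passes them.
q = torch.randn(batch, seq_len, d_model)
k = torch.randn(batch, seq_len, d_model)
v = torch.randn(batch, seq_len, d_model)
out, _ = scaled_multihead_dot_product_attention(
    q, k, v, n_heads, attn_bias=attn_bias, is_causal=True
)
print(out.shape)  # torch.Size([2, 128, 512])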
diff --git a/model/llava/model/mpt/blocks.py b/model/llava/model/mpt/blocks.py
index 04493aa4c03ef1b14ec539c9af8e9c38e8befc8b..1511a225455aaf0a5134cf6d275993e7de57b0e1 100644
--- a/model/llava/model/mpt/blocks.py
+++ b/model/llava/model/mpt/blocks.py
@@ -1,41 +1,90 @@
"""GPT Blocks used for the GPT Model."""
from typing import Dict, Optional, Tuple
+
import torch
import torch.nn as nn
+
from .attention import ATTN_CLASS_REGISTRY
from .norm import NORM_CLASS_REGISTRY
-class MPTMLP(nn.Module):
- def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
+class MPTMLP(nn.Module):
+ def __init__(
+ self, d_model: int, expansion_ratio: int, device: Optional[str] = None
+ ):
super().__init__()
self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
- self.act = nn.GELU(approximate='none')
+ self.act = nn.GELU(approximate="none")
self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
self.down_proj._is_residual = True
def forward(self, x):
return self.down_proj(self.act(self.up_proj(x)))
-class MPTBlock(nn.Module):
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
+class MPTBlock(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ n_heads: int,
+ expansion_ratio: int,
+ attn_config: Dict = {
+ "attn_type": "multihead_attention",
+ "attn_pdrop": 0.0,
+ "attn_impl": "triton",
+ "qk_ln": False,
+ "clip_qkv": None,
+ "softmax_scale": None,
+ "prefix_lm": False,
+ "attn_uses_sequence_id": False,
+ "alibi": False,
+ "alibi_bias_max": 8,
+ },
+ resid_pdrop: float = 0.0,
+ norm_type: str = "low_precision_layernorm",
+ device: Optional[str] = None,
+ **kwargs
+ ):
del kwargs
super().__init__()
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
- attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
+ attn_class = ATTN_CLASS_REGISTRY[attn_config["attn_type"]]
self.norm_1 = norm_class(d_model, device=device)
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
+ self.attn = attn_class(
+ attn_impl=attn_config["attn_impl"],
+ clip_qkv=attn_config["clip_qkv"],
+ qk_ln=attn_config["qk_ln"],
+ softmax_scale=attn_config["softmax_scale"],
+ attn_pdrop=attn_config["attn_pdrop"],
+ d_model=d_model,
+ n_heads=n_heads,
+ device=device,
+ )
self.norm_2 = norm_class(d_model, device=device)
- self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
+ self.ffn = MPTMLP(
+ d_model=d_model, expansion_ratio=expansion_ratio, device=device
+ )
self.resid_attn_dropout = nn.Dropout(resid_pdrop)
self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
+ def forward(
+ self,
+ x: torch.Tensor,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attn_bias: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.ByteTensor] = None,
+ is_causal: bool = True,
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
a = self.norm_1(x)
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
+ (b, _, past_key_value) = self.attn(
+ a,
+ past_key_value=past_key_value,
+ attn_bias=attn_bias,
+ attention_mask=attention_mask,
+ is_causal=is_causal,
+ )
x = x + self.resid_attn_dropout(b)
m = self.norm_2(x)
n = self.ffn(m)
x = x + self.resid_ffn_dropout(n)
- return (x, past_key_value)
\ No newline at end of file
+ return (x, past_key_value)
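A minimal sketch of driving one MPTBlock as reformatted above. Overriding attn_impl to "torch" (so the sketch runs without flash-attn or triton) and the chosen sizes are assumptions, not part of this diff.

import torch

from model.llava.model.mpt.blocks import MPTBlock  # import path assumed

attn_config = {
    "attn_type": "multihead_attention",
    "attn_pdrop": 0.0,
    "attn_impl": "torch",  # assumed override; the block's default is "triton"
    "qk_ln": False,
    "clip_qkv": None,
    "softmax_scale": None,
    "prefix_lm": False,
    "attn_uses_sequence_id": False,
    "alibi": False,
    "alibi_bias_max": 8,
}
block = MPTBlock(d_model=256, n_heads=4, expansion_ratio=4, attn_config=attn_config)

x = torch.randn(2, 16, 256)  # (batch, seq, d_model)
# Pre-norm attention and MLP with residual dropout, as in MPTBlock.forward.
y, past_key_value = block(x, is_causal=True)
print(y.shape)  # torch.Size([2, 16, 256])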
diff --git a/model/llava/model/mpt/configuration_mpt.py b/model/llava/model/mpt/configuration_mpt.py
index 35d1269cd4b599799d6df7953a8d0c30b33d1e65..f5b96e2a41b16a372b5050769a8c897816ada529 100644
--- a/model/llava/model/mpt/configuration_mpt.py
+++ b/model/llava/model/mpt/configuration_mpt.py
@@ -1,13 +1,52 @@
"""A HuggingFace-style model configuration."""
from typing import Dict, Optional, Union
+
from transformers import PretrainedConfig
-attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
-init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
+
+attn_config_defaults: Dict = {
+ "attn_type": "multihead_attention",
+ "attn_pdrop": 0.0,
+ "attn_impl": "triton",
+ "qk_ln": False,
+ "clip_qkv": None,
+ "softmax_scale": None,
+ "prefix_lm": False,
+ "attn_uses_sequence_id": False,
+ "alibi": False,
+ "alibi_bias_max": 8,
+}
+init_config_defaults: Dict = {
+ "name": "kaiming_normal_",
+ "fan_mode": "fan_in",
+ "init_nonlinearity": "relu",
+}
+
class MPTConfig(PretrainedConfig):
- model_type = 'mpt'
+ model_type = "mpt"
- def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs):
+ def __init__(
+ self,
+ d_model: int = 2048,
+ n_heads: int = 16,
+ n_layers: int = 24,
+ expansion_ratio: int = 4,
+ max_seq_len: int = 2048,
+ vocab_size: int = 50368,
+ resid_pdrop: float = 0.0,
+ emb_pdrop: float = 0.0,
+ learned_pos_emb: bool = True,
+ attn_config: Dict = attn_config_defaults,
+ init_device: str = "cpu",
+ logit_scale: Optional[Union[float, str]] = None,
+ no_bias: bool = False,
+ verbose: int = 0,
+ embedding_fraction: float = 1.0,
+ norm_type: str = "low_precision_layernorm",
+ use_cache: bool = False,
+ init_config: Dict = init_config_defaults,
+ **kwargs,
+ ):
"""The MPT configuration class.
Args:
@@ -80,39 +119,76 @@ class MPTConfig(PretrainedConfig):
self.norm_type = norm_type
self.use_cache = use_cache
self.init_config = init_config
- if 'name' in kwargs:
- del kwargs['name']
- if 'loss_fn' in kwargs:
- del kwargs['loss_fn']
+ if "name" in kwargs:
+ del kwargs["name"]
+ if "loss_fn" in kwargs:
+ del kwargs["loss_fn"]
super().__init__(**kwargs)
self._validate_config()
def _set_config_defaults(self, config, config_defaults):
- for (k, v) in config_defaults.items():
+ for k, v in config_defaults.items():
if k not in config:
config[k] = v
return config
def _validate_config(self):
- self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults)
- self.init_config = self._set_config_defaults(self.init_config, init_config_defaults)
+ self.attn_config = self._set_config_defaults(
+ self.attn_config, attn_config_defaults
+ )
+ self.init_config = self._set_config_defaults(
+ self.init_config, init_config_defaults
+ )
if self.d_model % self.n_heads != 0:
- raise ValueError('d_model must be divisible by n_heads')
- if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])):
- raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1")
- if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']:
+ raise ValueError("d_model must be divisible by n_heads")
+ if any(
+ (
+ prob < 0 or prob > 1
+ for prob in [
+ self.attn_config["attn_pdrop"],
+ self.resid_pdrop,
+ self.emb_pdrop,
+ ]
+ )
+ ):
+ raise ValueError(
+ "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1"
+ )
+ if self.attn_config["attn_impl"] not in ["torch", "flash", "triton"]:
raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}")
- if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('prefix_lm only implemented with torch and triton attention.')
- if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('alibi only implemented with torch and triton attention.')
- if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']:
- raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.')
+ if self.attn_config["prefix_lm"] and self.attn_config["attn_impl"] not in [
+ "torch",
+ "triton",
+ ]:
+ raise NotImplementedError(
+ "prefix_lm only implemented with torch and triton attention."
+ )
+ if self.attn_config["alibi"] and self.attn_config["attn_impl"] not in [
+ "torch",
+ "triton",
+ ]:
+ raise NotImplementedError(
+ "alibi only implemented with torch and triton attention."
+ )
+ if self.attn_config["attn_uses_sequence_id"] and self.attn_config[
+ "attn_impl"
+ ] not in ["torch", "triton"]:
+ raise NotImplementedError(
+ "attn_uses_sequence_id only implemented with torch and triton attention."
+ )
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
- raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!')
- if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model':
- raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
- if self.init_config.get('name', None) is None:
- raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.")
- if not self.learned_pos_emb and (not self.attn_config['alibi']):
- raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.')
\ No newline at end of file
+ raise ValueError(
+ "model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!"
+ )
+ if isinstance(self.logit_scale, str) and self.logit_scale != "inv_sqrt_d_model":
+ raise ValueError(
+ f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
+ )
+ if self.init_config.get("name", None) is None:
+ raise ValueError(
+ f"self.init_config={self.init_config!r} 'name' needs to be set."
+ )
+ if not self.learned_pos_emb and (not self.attn_config["alibi"]):
+ raise ValueError(
+ f"Positional information must be provided to the model using either learned_pos_emb or alibi."
+ )
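A minimal sketch of the MPTConfig validation reformatted above: unspecified attn_config keys are backfilled from attn_config_defaults, then _validate_config enforces the structural constraints. The field values and import path are illustrative assumptions.

from model.llava.model.mpt.configuration_mpt import MPTConfig  # import path assumed

config = MPTConfig(
    d_model=1024,
    n_heads=16,
    n_layers=12,
    max_seq_len=2048,
    attn_config={"attn_type": "multihead_attention", "attn_impl": "torch", "alibi": True},
)
# Missing keys were filled in by _set_config_defaults.
assert config.attn_config["alibi_bias_max"] == 8

try:
    MPTConfig(d_model=1000, n_heads=16)  # 1000 is not divisible by 16
except ValueError as err:
    print(err)  # d_model must be divisible by n_heads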
diff --git a/model/llava/model/mpt/hf_prefixlm_converter.py b/model/llava/model/mpt/hf_prefixlm_converter.py
index 8c1a6487202a6400a7116a6bd68b493892ef0d14..427d3878185431f3e657d1a93c5db5a55f04300f 100644
--- a/model/llava/model/mpt/hf_prefixlm_converter.py
+++ b/model/llava/model/mpt/hf_prefixlm_converter.py
@@ -10,21 +10,37 @@ import math
import warnings
from types import MethodType
from typing import Any, Dict, List, Optional, Tuple, Union
+
import torch
-from transformers.models.bloom.modeling_bloom import BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel, CausalLMOutputWithCrossAttentions, CrossEntropyLoss
-from transformers.models.bloom.modeling_bloom import _expand_mask as _expand_mask_bloom
-from transformers.models.bloom.modeling_bloom import _make_causal_mask as _make_causal_mask_bloom
+from transformers.models.bloom.modeling_bloom import (
+ BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
+ CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
+from transformers.models.bloom.modeling_bloom import \
+ _expand_mask as _expand_mask_bloom
+from transformers.models.bloom.modeling_bloom import \
+ _make_causal_mask as _make_causal_mask_bloom
from transformers.models.bloom.modeling_bloom import logging
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
-from transformers.models.opt.modeling_opt import _expand_mask as _expand_mask_opt
-from transformers.models.opt.modeling_opt import _make_causal_mask as _make_causal_mask_opt
+from transformers.models.opt.modeling_opt import \
+ _expand_mask as _expand_mask_opt
+from transformers.models.opt.modeling_opt import \
+ _make_causal_mask as _make_causal_mask_opt
+
logger = logging.get_logger(__name__)
-_SUPPORTED_GPT_MODELS = (GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM)
-CAUSAL_GPT_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM]
+_SUPPORTED_GPT_MODELS = (
+ GPT2LMHeadModel,
+ GPTJForCausalLM,
+ GPTNeoForCausalLM,
+ GPTNeoXForCausalLM,
+)
+CAUSAL_GPT_TYPES = Union[
+ GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM
+]
+
def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_TYPES:
"""Converts a GPT-style Causal LM to a Prefix LM.
@@ -37,10 +53,12 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
See `convert_hf_causal_lm_to_prefix_lm` for more details.
"""
- if hasattr(model, '_prefix_lm_converted'):
+ if hasattr(model, "_prefix_lm_converted"):
return model
assert isinstance(model, _SUPPORTED_GPT_MODELS)
- assert model.config.add_cross_attention == False, 'Only supports GPT-style decoder-only models'
+ assert (
+ model.config.add_cross_attention == False
+ ), "Only supports GPT-style decoder-only models"
def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
"""Helper that gets a list of the model's attention modules.
@@ -56,7 +74,7 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
blocks = model.transformer.h
for block in blocks:
if isinstance(model, GPTNeoForCausalLM):
- if block.attn.attention_type != 'global':
+ if block.attn.attention_type != "global":
continue
attn_module = block.attn.attention
elif isinstance(model, GPTNeoXForCausalLM):
@@ -65,17 +83,58 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
attn_module = block.attn
attn_modules.append(attn_module)
return attn_modules
- setattr(model, '_original_forward', getattr(model, 'forward'))
- setattr(model, '_original_generate', getattr(model, 'generate'))
- def forward(self: CAUSAL_GPT_TYPES, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]=None, attention_mask: Optional[torch.FloatTensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, token_type_ids: Optional[torch.LongTensor]=None, position_ids: Optional[torch.LongTensor]=None, head_mask: Optional[torch.FloatTensor]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
+ setattr(model, "_original_forward", getattr(model, "forward"))
+ setattr(model, "_original_generate", getattr(model, "generate"))
+
+ def forward(
+ self: CAUSAL_GPT_TYPES,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ bidirectional_mask: Optional[torch.Tensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
"""Wraps original forward to enable PrefixLM attention."""
def call_og_forward():
if isinstance(self, GPTNeoXForCausalLM):
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
+ return self._original_forward(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
else:
- return self._original_forward(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
+ return self._original_forward(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
if bidirectional_mask is None:
return call_og_forward()
assert isinstance(bidirectional_mask, torch.Tensor)
@@ -83,14 +142,23 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
(b, s) = bidirectional_mask.shape
max_length = attn_modules[0].bias.shape[-1]
if s > max_length:
- raise ValueError(f'bidirectional_mask sequence length (={s}) exceeds the ' + f'max length allowed by the model ({max_length}).')
+ raise ValueError(
+ f"bidirectional_mask sequence length (={s}) exceeds the "
+ + f"max length allowed by the model ({max_length})."
+ )
assert s <= max_length
if s < max_length:
- pad = torch.zeros((int(b), int(max_length - s)), dtype=bidirectional_mask.dtype, device=bidirectional_mask.device)
+ pad = torch.zeros(
+ (int(b), int(max_length - s)),
+ dtype=bidirectional_mask.dtype,
+ device=bidirectional_mask.device,
+ )
bidirectional_mask = torch.cat([bidirectional_mask, pad], dim=1)
bidirectional = bidirectional_mask.unsqueeze(1).unsqueeze(1)
for attn_module in attn_modules:
- attn_module.bias.data = torch.logical_or(attn_module.bias.data, bidirectional)
+ attn_module.bias.data = torch.logical_or(
+ attn_module.bias.data, bidirectional
+ )
output = call_og_forward()
for attn_module in attn_modules:
attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
@@ -105,11 +173,13 @@ def _convert_gpt_causal_lm_to_prefix_lm(model: CAUSAL_GPT_TYPES) -> CAUSAL_GPT_T
for attn_module in attn_modules:
attn_module.bias.data = torch.tril(attn_module.bias.data[0, 0])[None, None]
return output
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'generate', MethodType(generate, model))
- setattr(model, '_prefix_lm_converted', True)
+
+ setattr(model, "forward", MethodType(forward, model))
+ setattr(model, "generate", MethodType(generate, model))
+ setattr(model, "_prefix_lm_converted", True)
return model
+
def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCausalLM:
"""Converts a BLOOM Causal LM to a Prefix LM.
@@ -118,62 +188,137 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
See `convert_hf_causal_lm_to_prefix_lm` for more details.
"""
- if hasattr(model, '_prefix_lm_converted'):
+ if hasattr(model, "_prefix_lm_converted"):
return model
assert isinstance(model, BloomForCausalLM)
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
-
- def _prepare_attn_mask(self: BloomModel, attention_mask: torch.Tensor, bidirectional_mask: Optional[torch.Tensor], input_shape: Tuple[int, int], past_key_values_length: int) -> torch.BoolTensor:
+ assert (
+ model.config.add_cross_attention == False
+ ), "Only supports BLOOM decoder-only models"
+
+ def _prepare_attn_mask(
+ self: BloomModel,
+ attention_mask: torch.Tensor,
+ bidirectional_mask: Optional[torch.Tensor],
+ input_shape: Tuple[int, int],
+ past_key_values_length: int,
+ ) -> torch.BoolTensor:
combined_attention_mask = None
device = attention_mask.device
(_, src_length) = input_shape
if src_length > 1:
- combined_attention_mask = _make_causal_mask_bloom(input_shape, device=device, past_key_values_length=past_key_values_length)
+ combined_attention_mask = _make_causal_mask_bloom(
+ input_shape,
+ device=device,
+ past_key_values_length=past_key_values_length,
+ )
if bidirectional_mask is not None:
assert attention_mask.shape == bidirectional_mask.shape
- expanded_bidirectional_mask = _expand_mask_bloom(bidirectional_mask, tgt_length=src_length)
- combined_attention_mask = torch.logical_and(combined_attention_mask, expanded_bidirectional_mask)
+ expanded_bidirectional_mask = _expand_mask_bloom(
+ bidirectional_mask, tgt_length=src_length
+ )
+ combined_attention_mask = torch.logical_and(
+ combined_attention_mask, expanded_bidirectional_mask
+ )
expanded_attn_mask = _expand_mask_bloom(attention_mask, tgt_length=src_length)
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask | combined_attention_mask
+ )
return combined_attention_mask
- def _build_alibi_tensor(self: BloomModel, batch_size: int, query_length: int, key_length: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
+ def _build_alibi_tensor(
+ self: BloomModel,
+ batch_size: int,
+ query_length: int,
+ key_length: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ ) -> torch.Tensor:
num_heads = self.config.n_head
closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
- base = torch.tensor(2 ** (-2 ** (-(math.log2(closest_power_of_2) - 3))), device=device, dtype=torch.float32)
- powers = torch.arange(1, 1 + closest_power_of_2, device=device, dtype=torch.int32)
+ base = torch.tensor(
+ 2 ** (-(2 ** (-(math.log2(closest_power_of_2) - 3)))),
+ device=device,
+ dtype=torch.float32,
+ )
+ powers = torch.arange(
+ 1, 1 + closest_power_of_2, device=device, dtype=torch.int32
+ )
slopes = torch.pow(base, powers)
if closest_power_of_2 != num_heads:
- extra_base = torch.tensor(2 ** (-2 ** (-(math.log2(2 * closest_power_of_2) - 3))), device=device, dtype=torch.float32)
- num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
- extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32)
+ extra_base = torch.tensor(
+ 2 ** (-(2 ** (-(math.log2(2 * closest_power_of_2) - 3)))),
+ device=device,
+ dtype=torch.float32,
+ )
+ num_remaining_heads = min(
+ closest_power_of_2, num_heads - closest_power_of_2
+ )
+ extra_powers = torch.arange(
+ 1, 1 + 2 * num_remaining_heads, 2, device=device, dtype=torch.int32
+ )
slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
qa = torch.arange(query_length, device=device, dtype=torch.int32).view(-1, 1)
ka = torch.arange(key_length, device=device, dtype=torch.int32).view(1, -1)
diffs = qa - ka + key_length - query_length
diffs = -diffs.abs()
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(1, 1, query_length, key_length)
- alibi = alibi.expand(batch_size, -1, -1, -1).reshape(-1, query_length, key_length)
+ alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(
+ 1, 1, query_length, key_length
+ )
+ alibi = alibi.expand(batch_size, -1, -1, -1).reshape(
+ -1, query_length, key_length
+ )
return alibi.to(dtype)
+
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
- def forward(self: BloomModel, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.LongTensor]=None, inputs_embeds: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
- if deprecated_arguments.pop('position_ids', False) is not False:
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' + 'You can safely ignore passing `position_ids`.', FutureWarning)
+ def forward(
+ self: BloomModel,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ bidirectional_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.LongTensor] = None,
+ inputs_embeds: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. "
+ + "You can safely ignore passing `position_ids`.",
+ FutureWarning,
+ )
if len(deprecated_arguments) > 0:
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
if input_ids is not None and inputs_embeds is not None:
- raise ValueError('You cannot specify both input_ids and inputs_embeds at the same time')
+ raise ValueError(
+ "You cannot specify both input_ids and inputs_embeds at the same time"
+ )
elif input_ids is not None:
(batch_size, seq_length) = input_ids.shape
elif inputs_embeds is not None:
(batch_size, seq_length, _) = inputs_embeds.shape
else:
- raise ValueError('You have to specify either input_ids or inputs_embeds')
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
if past_key_values is None:
past_key_values = tuple([None] * len(self.h))
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
@@ -190,28 +335,62 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
past_key_values_length = tmp.shape[2]
seq_length_with_past = seq_length_with_past + past_key_values_length
if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
+ attention_mask = torch.ones(
+ (batch_size, seq_length_with_past), device=hidden_states.device
+ )
else:
attention_mask = attention_mask.to(hidden_states.device)
- alibi = self._build_alibi_tensor(batch_size=batch_size, query_length=seq_length, key_length=seq_length_with_past, dtype=hidden_states.dtype, device=hidden_states.device)
- causal_mask = self._prepare_attn_mask(attention_mask, bidirectional_mask, input_shape=(batch_size, seq_length), past_key_values_length=past_key_values_length)
- for (i, (block, layer_past)) in enumerate(zip(self.h, past_key_values)):
+ alibi = self._build_alibi_tensor(
+ batch_size=batch_size,
+ query_length=seq_length,
+ key_length=seq_length_with_past,
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ )
+ causal_mask = self._prepare_attn_mask(
+ attention_mask,
+ bidirectional_mask,
+ input_shape=(batch_size, seq_length),
+ past_key_values_length=past_key_values_length,
+ )
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
hst = (hidden_states,)
all_hidden_states = all_hidden_states + hst
if self.gradient_checkpointing and self.training:
if use_cache:
- logger.warning('`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...')
+ logger.warning(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
use_cache = False
def create_custom_forward(module):
-
def custom_forward(*inputs):
- return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+ return module(
+ *inputs,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ )
+
return custom_forward
- outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, alibi, causal_mask, head_mask[i])
+
+ outputs = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ alibi,
+ causal_mask,
+ head_mask[i],
+ )
else:
- outputs = block(hidden_states, layer_past=layer_past, attention_mask=causal_mask, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi)
+ outputs = block(
+ hidden_states,
+ layer_past=layer_past,
+ attention_mask=causal_mask,
+ head_mask=head_mask[i],
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ alibi=alibi,
+ )
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
@@ -223,21 +402,77 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
hst = (hidden_states,)
all_hidden_states = all_hidden_states + hst
if not return_dict:
- return tuple((v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None))
- return BaseModelOutputWithPastAndCrossAttentions(last_hidden_state=hidden_states, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attentions)
- setattr(model.transformer, '_prepare_attn_mask', MethodType(_prepare_attn_mask, model.transformer))
- setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer))
- setattr(model.transformer, 'forward', MethodType(forward, model.transformer))
+ return tuple(
+ (
+ v
+ for v in [
+ hidden_states,
+ presents,
+ all_hidden_states,
+ all_self_attentions,
+ ]
+ if v is not None
+ )
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+ setattr(
+ model.transformer,
+ "_prepare_attn_mask",
+ MethodType(_prepare_attn_mask, model.transformer),
+ )
+ setattr(
+ model.transformer,
+ "_build_alibi_tensor",
+ MethodType(_build_alibi_tensor, model.transformer),
+ )
+ setattr(model.transformer, "forward", MethodType(forward, model.transformer))
KeyValueT = Tuple[torch.Tensor, torch.Tensor]
- def forward(self: BloomForCausalLM, input_ids: Optional[torch.LongTensor]=None, past_key_values: Optional[Tuple[KeyValueT, ...]]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.Tensor]=None, head_mask: Optional[torch.Tensor]=None, inputs_embeds: Optional[torch.Tensor]=None, labels: Optional[torch.Tensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None, **deprecated_arguments) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+ def forward(
+ self: BloomForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ bidirectional_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.Tensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ **deprecated_arguments,
+ ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
"""Replacement forward method for BloomCausalLM."""
- if deprecated_arguments.pop('position_ids', False) is not False:
- warnings.warn('`position_ids` have no functionality in BLOOM and will be removed ' + 'in v5.0.0. You can safely ignore passing `position_ids`.', FutureWarning)
+ if deprecated_arguments.pop("position_ids", False) is not False:
+ warnings.warn(
+ "`position_ids` have no functionality in BLOOM and will be removed "
+ + "in v5.0.0. You can safely ignore passing `position_ids`.",
+ FutureWarning,
+ )
if len(deprecated_arguments) > 0:
- raise ValueError(f'Got unexpected arguments: {deprecated_arguments}')
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- transformer_outputs = self.transformer(input_ids, past_key_values=past_key_values, attention_mask=attention_mask, bidirectional_mask=bidirectional_mask, head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
+ raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ bidirectional_mask=bidirectional_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
@@ -246,13 +481,28 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
shift_labels = labels[..., 1:].contiguous()
(batch_size, seq_length, vocab_size) = shift_logits.shape
loss_fct = CrossEntropyLoss()
- loss = loss_fct(shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length))
+ loss = loss_fct(
+ shift_logits.view(batch_size * seq_length, vocab_size),
+ shift_labels.view(batch_size * seq_length),
+ )
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return (loss,) + output if loss is not None else output
- return CausalLMOutputWithCrossAttentions(loss=loss, logits=lm_logits, past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions)
-
- def prepare_inputs_for_generation(self: BloomForCausalLM, input_ids: torch.LongTensor, past: Optional[torch.Tensor]=None, attention_mask: Optional[torch.Tensor]=None, **kwargs) -> dict:
+ return CausalLMOutputWithCrossAttentions(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def prepare_inputs_for_generation(
+ self: BloomForCausalLM,
+ input_ids: torch.LongTensor,
+ past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> dict:
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
bidirectional_mask = None
@@ -260,12 +510,24 @@ def _convert_bloom_causal_lm_to_prefix_lm(model: BloomForCausalLM) -> BloomForCa
past = self._convert_to_bloom_cache(past)
else:
bidirectional_mask = torch.ones_like(input_ids)
- return {'input_ids': input_ids, 'past_key_values': past, 'use_cache': True, 'attention_mask': attention_mask, 'bidirectional_mask': bidirectional_mask}
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'prepare_inputs_for_generation', MethodType(prepare_inputs_for_generation, model))
- setattr(model, '_prefix_lm_converted', True)
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past,
+ "use_cache": True,
+ "attention_mask": attention_mask,
+ "bidirectional_mask": bidirectional_mask,
+ }
+
+ setattr(model, "forward", MethodType(forward, model))
+ setattr(
+ model,
+ "prepare_inputs_for_generation",
+ MethodType(prepare_inputs_for_generation, model),
+ )
+ setattr(model, "_prefix_lm_converted", True)
return model
+
def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM:
"""Converts an OPT Causal LM to a Prefix LM.
@@ -274,36 +536,89 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
See `convert_hf_causal_lm_to_prefix_lm` for more details.
"""
- if hasattr(model, '_prefix_lm_converted'):
+ if hasattr(model, "_prefix_lm_converted"):
return model
assert isinstance(model, OPTForCausalLM)
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
- setattr(model, '_original_forward', getattr(model, 'forward'))
- setattr(model, '_original_generate', getattr(model, 'generate'))
+ assert (
+ model.config.add_cross_attention == False
+ ), "Only supports OPT decoder-only models"
+ setattr(model, "_original_forward", getattr(model, "forward"))
+ setattr(model, "_original_generate", getattr(model, "generate"))
model.model.decoder.bidirectional_mask = None
- def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+ def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+ ):
combined_attention_mask = None
if input_shape[-1] > 1:
- if self.bidirectional_mask == 'g':
+ if self.bidirectional_mask == "g":
(bsz, src_length) = input_shape
- combined_attention_mask = torch.zeros((bsz, 1, src_length, src_length + past_key_values_length), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
+ combined_attention_mask = torch.zeros(
+ (bsz, 1, src_length, src_length + past_key_values_length),
+ dtype=inputs_embeds.dtype,
+ device=inputs_embeds.device,
+ )
else:
- combined_attention_mask = _make_causal_mask_opt(input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length).to(inputs_embeds.device)
+ combined_attention_mask = _make_causal_mask_opt(
+ input_shape,
+ inputs_embeds.dtype,
+ past_key_values_length=past_key_values_length,
+ ).to(inputs_embeds.device)
if self.bidirectional_mask is not None:
assert attention_mask.shape == self.bidirectional_mask.shape
- expanded_bidirectional_mask = _expand_mask_opt(self.bidirectional_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
- combined_attention_mask = torch.maximum(expanded_bidirectional_mask, combined_attention_mask)
+ expanded_bidirectional_mask = _expand_mask_opt(
+ self.bidirectional_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ ).to(inputs_embeds.device)
+ combined_attention_mask = torch.maximum(
+ expanded_bidirectional_mask, combined_attention_mask
+ )
if attention_mask is not None:
- expanded_attn_mask = _expand_mask_opt(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device)
- combined_attention_mask = expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+ expanded_attn_mask = _expand_mask_opt(
+ attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ ).to(inputs_embeds.device)
+ combined_attention_mask = (
+ expanded_attn_mask
+ if combined_attention_mask is None
+ else expanded_attn_mask + combined_attention_mask
+ )
return combined_attention_mask
- setattr(model.model.decoder, '_prepare_decoder_attention_mask', MethodType(_prepare_decoder_attention_mask, model.model.decoder))
-
- def forward(self: OPTForCausalLM, input_ids: Optional[torch.LongTensor]=None, attention_mask: Optional[torch.Tensor]=None, bidirectional_mask: Optional[torch.ByteTensor]=None, head_mask: Optional[torch.Tensor]=None, past_key_values: Optional[List[torch.FloatTensor]]=None, inputs_embeds: Optional[torch.FloatTensor]=None, labels: Optional[torch.LongTensor]=None, use_cache: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, return_dict: Optional[bool]=None):
+ setattr(
+ model.model.decoder,
+ "_prepare_decoder_attention_mask",
+ MethodType(_prepare_decoder_attention_mask, model.model.decoder),
+ )
+
+ def forward(
+ self: OPTForCausalLM,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ bidirectional_mask: Optional[torch.ByteTensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ):
def call_og_forward():
- return self._original_forward(input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, past_key_values=past_key_values, inputs_embeds=inputs_embeds, labels=labels, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict)
+ return self._original_forward(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ labels=labels,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
if bidirectional_mask is None:
return call_og_forward()
self.model.decoder.bidirectional_mask = bidirectional_mask
@@ -317,7 +632,7 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Dict[str, Any]):
"""Wraps original generate to enable PrefixLM-style attention."""
- self.model.decoder.bidirectional_mask = 'g'
+ self.model.decoder.bidirectional_mask = "g"
try:
output = self._original_generate(*args, **kwargs)
except:
@@ -325,12 +640,23 @@ def _convert_opt_causal_lm_to_prefix_lm(model: OPTForCausalLM) -> OPTForCausalLM
raise
self.model.decoder.bidirectional_mask = None
return output
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'generate', MethodType(generate, model))
- setattr(model, '_prefix_lm_converted', True)
+
+ setattr(model, "forward", MethodType(forward, model))
+ setattr(model, "generate", MethodType(generate, model))
+ setattr(model, "_prefix_lm_converted", True)
return model
+
+
_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM, OPTForCausalLM)
-CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM, GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
+CAUSAL_LM_TYPES = Union[
+ GPT2LMHeadModel,
+ GPTJForCausalLM,
+ GPTNeoForCausalLM,
+ GPTNeoXForCausalLM,
+ BloomForCausalLM,
+ OPTForCausalLM,
+]
+
def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES:
"""Converts a HuggingFace Causal LM to a Prefix LM.
@@ -396,7 +722,12 @@ def convert_hf_causal_lm_to_prefix_lm(model: CAUSAL_LM_TYPES) -> CAUSAL_LM_TYPES
elif isinstance(model, OPTForCausalLM):
return _convert_opt_causal_lm_to_prefix_lm(model)
else:
- raise TypeError(f'Cannot convert model to Prefix LM. ' + f'Model does not belong to set of supported HF models:' + f'\n{_SUPPORTED_HF_MODELS}')
+ raise TypeError(
+ f"Cannot convert model to Prefix LM. "
+ + f"Model does not belong to set of supported HF models:"
+ + f"\n{_SUPPORTED_HF_MODELS}"
+ )
+
def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
"""Attempts to add bidirectional_mask to batch if missing.
@@ -404,12 +735,16 @@ def add_bidirectional_mask_if_missing(batch: Dict[str, Any]):
Raises:
KeyError if bidirectional_mask is missing and can't be inferred
"""
- if 'bidirectional_mask' not in batch:
- if batch.get('mode', None) == 'icl_task':
- batch['bidirectional_mask'] = batch['attention_mask'].clone()
- for (i, continuation_indices) in enumerate(batch['continuation_indices']):
- batch['bidirectional_mask'][i, continuation_indices] = 0
- elif 'labels' in batch and 'attention_mask' in batch:
- batch['bidirectional_mask'] = torch.logical_and(torch.eq(batch['attention_mask'], 1), torch.eq(batch['labels'], -100)).type_as(batch['attention_mask'])
+ if "bidirectional_mask" not in batch:
+ if batch.get("mode", None) == "icl_task":
+ batch["bidirectional_mask"] = batch["attention_mask"].clone()
+ for i, continuation_indices in enumerate(batch["continuation_indices"]):
+ batch["bidirectional_mask"][i, continuation_indices] = 0
+ elif "labels" in batch and "attention_mask" in batch:
+ batch["bidirectional_mask"] = torch.logical_and(
+ torch.eq(batch["attention_mask"], 1), torch.eq(batch["labels"], -100)
+ ).type_as(batch["attention_mask"])
else:
- raise KeyError('No bidirectional_mask in batch and not sure how to construct one.')
\ No newline at end of file
+ raise KeyError(
+ "No bidirectional_mask in batch and not sure how to construct one."
+ )
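For orientation, the two public helpers in this file are meant to be used together: `convert_hf_causal_lm_to_prefix_lm` monkey-patches `forward`/`generate` on a supported model, and `add_bidirectional_mask_if_missing` derives the prefix mask from `labels` and `attention_mask`. A minimal usage sketch, assuming the repository root is importable and using `gpt2` plus made-up token ids purely for illustration:

    import torch
    from transformers import GPT2LMHeadModel

    from model.llava.model.mpt.hf_prefixlm_converter import (
        add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)

    model = convert_hf_causal_lm_to_prefix_lm(GPT2LMHeadModel.from_pretrained("gpt2"))

    batch = {
        "input_ids": torch.tensor([[50256, 464, 3139, 286, 4881, 318, 6342]]),
        "attention_mask": torch.ones(1, 7, dtype=torch.long),
        # -100 marks prefix positions; only the last two tokens are loss targets.
        "labels": torch.tensor([[-100, -100, -100, -100, -100, 318, 6342]]),
    }
    add_bidirectional_mask_if_missing(batch)  # mask = attention_mask == 1 and labels == -100
    outputs = model(**batch)  # prefix tokens attend bidirectionally, targets stay causal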
diff --git a/model/llava/model/mpt/meta_init_context.py b/model/llava/model/mpt/meta_init_context.py
index 6cba6fff0fe21fe222c7ab38eae44a9784c0be9c..208ab255cedb65e5c444b1c5fa5abf72cbdb1512 100644
--- a/model/llava/model/mpt/meta_init_context.py
+++ b/model/llava/model/mpt/meta_init_context.py
@@ -1,9 +1,11 @@
from contextlib import contextmanager
+
import torch
import torch.nn as nn
+
@contextmanager
-def init_empty_weights(include_buffers: bool=False):
+def init_empty_weights(include_buffers: bool = False):
"""Meta initialization context manager.
A context manager under which models are initialized with all parameters
@@ -30,11 +32,12 @@ def init_empty_weights(include_buffers: bool=False):
"""
- with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
+ with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
yield f
+
@contextmanager
-def init_on_device(device: torch.device, include_buffers: bool=False):
+def init_on_device(device: torch.device, include_buffers: bool = False):
"""Device initialization context manager.
A context manager under which models are initialized with all parameters
@@ -62,33 +65,47 @@ def init_on_device(device: torch.device, include_buffers: bool=False):
if param is not None:
param_cls = type(module._parameters[name])
kwargs = module._parameters[name].__dict__
- module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
+ module._parameters[name] = param_cls(
+ module._parameters[name].to(device), **kwargs
+ )
def register_empty_buffer(module, name, buffer):
old_register_buffer(module, name, buffer)
if buffer is not None:
module._buffers[name] = module._buffers[name].to(device)
+
if include_buffers:
- tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
+ tensor_constructors_to_patch = {
+ torch_function_name: getattr(torch, torch_function_name)
+ for torch_function_name in ["empty", "zeros", "ones", "full"]
+ }
else:
tensor_constructors_to_patch = {}
def patch_tensor_constructor(fn):
-
def wrapper(*args, **kwargs):
- kwargs['device'] = device
+ kwargs["device"] = device
return fn(*args, **kwargs)
+
return wrapper
+
try:
nn.Module.register_parameter = register_empty_parameter
if include_buffers:
nn.Module.register_buffer = register_empty_buffer
for torch_function_name in tensor_constructors_to_patch.keys():
- setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
+ setattr(
+ torch,
+ torch_function_name,
+ patch_tensor_constructor(getattr(torch, torch_function_name)),
+ )
yield
finally:
nn.Module.register_parameter = old_register_parameter
if include_buffers:
nn.Module.register_buffer = old_register_buffer
- for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
- setattr(torch, torch_function_name, old_torch_function)
\ No newline at end of file
+ for (
+ torch_function_name,
+ old_torch_function,
+ ) in tensor_constructors_to_patch.items():
+ setattr(torch, torch_function_name, old_torch_function)
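For reference, `init_empty_weights` above is typically wrapped around model construction so that parameters are allocated on the meta device and consume no real memory; they must be materialized before the module is used. A small sketch under that assumption (the layer size is arbitrary):

    import torch.nn as nn

    from model.llava.model.mpt.meta_init_context import init_empty_weights

    with init_empty_weights():
        layer = nn.Linear(4096, 4096)  # weight and bias land on torch.device('meta')

    print(layer.weight.device)  # meta, no storage allocated
    # Before running the layer, materialize real tensors, e.g. layer.to_empty(device="cpu")
    # followed by an explicit re-initialization or by loading a checkpoint.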
diff --git a/model/llava/model/mpt/modeling_mpt.py b/model/llava/model/mpt/modeling_mpt.py
index 5c3144a9872b7cf8df3bcab58e2f12ecc292d5c0..070c151e292f0a360bc468113602fcab1f8e594a 100644
--- a/model/llava/model/mpt/modeling_mpt.py
+++ b/model/llava/model/mpt/modeling_mpt.py
@@ -5,68 +5,95 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
import math
import warnings
from typing import List, Optional, Tuple, Union
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
+from transformers import (PreTrainedModel, PreTrainedTokenizer,
+ PreTrainedTokenizerFast)
+from transformers.modeling_outputs import (BaseModelOutputWithPast,
+ CausalLMOutputWithPast)
+
+from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
from .attention import attn_bias_shape, build_attn_bias
from .blocks import MPTBlock
-from .norm import NORM_CLASS_REGISTRY
from .configuration_mpt import MPTConfig
-from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
-from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
+from .hf_prefixlm_converter import (add_bidirectional_mask_if_missing,
+ convert_hf_causal_lm_to_prefix_lm)
from .meta_init_context import init_empty_weights
+from .norm import NORM_CLASS_REGISTRY
from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+
Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
from transformers.utils import logging
+
logger = logging.get_logger(__name__)
+
class MPTPreTrainedModel(PreTrainedModel):
config_class = MPTConfig
- base_model_prefix = 'model'
+ base_model_prefix = "model"
-class MPTModel(MPTPreTrainedModel):
+class MPTModel(MPTPreTrainedModel):
def __init__(self, config: MPTConfig):
config._validate_config()
super().__init__(config)
- self.attn_impl = config.attn_config['attn_impl']
- self.prefix_lm = config.attn_config['prefix_lm']
- self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
- self.alibi = config.attn_config['alibi']
- self.alibi_bias_max = config.attn_config['alibi_bias_max']
+ self.attn_impl = config.attn_config["attn_impl"]
+ self.prefix_lm = config.attn_config["prefix_lm"]
+ self.attn_uses_sequence_id = config.attn_config["attn_uses_sequence_id"]
+ self.alibi = config.attn_config["alibi"]
+ self.alibi_bias_max = config.attn_config["alibi_bias_max"]
if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
- norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
- raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
+ norm_options = " | ".join(NORM_CLASS_REGISTRY.keys())
+ raise NotImplementedError(
+ f"Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options})."
+ )
norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
self.embedding_fraction = config.embedding_fraction
- self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
+ self.wte = nn.Embedding(
+ config.vocab_size, config.d_model, device=config.init_device
+ )
if not self.alibi:
- self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+ self.wpe = nn.Embedding(
+ config.max_seq_len, config.d_model, device=config.init_device
+ )
self.emb_drop = nn.Dropout(config.emb_pdrop)
- self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
+ self.blocks = nn.ModuleList(
+ [
+ MPTBlock(device=config.init_device, **config.to_dict())
+ for _ in range(config.n_layers)
+ ]
+ )
self.norm_f = norm_class(config.d_model, device=config.init_device)
- if config.init_device != 'meta':
+ if config.init_device != "meta":
self.apply(self.param_init_fn)
self.is_causal = not self.prefix_lm
self._attn_bias_initialized = False
self.attn_bias = None
- self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
+ self.attn_bias_shape = attn_bias_shape(
+ self.attn_impl,
+ config.n_heads,
+ config.max_seq_len,
+ self.alibi,
+ prefix_lm=self.prefix_lm,
+ causal=self.is_causal,
+ use_sequence_id=self.attn_uses_sequence_id,
+ )
if config.no_bias:
for module in self.modules():
- if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
+ if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter):
if config.verbose:
- warnings.warn(f'Removing bias ({module.bias}) from {module}.')
- module.register_parameter('bias', None)
+ warnings.warn(f"Removing bias ({module.bias}) from {module}.")
+ module.register_parameter("bias", None)
if config.verbose and config.verbose > 2:
print(self)
- if 'verbose' not in self.config.init_config:
- self.config.init_config['verbose'] = self.config.verbose
- if self.config.init_config['verbose'] > 1:
- init_fn_name = self.config.init_config['name']
- warnings.warn(f'Using {init_fn_name} initialization.')
+ if "verbose" not in self.config.init_config:
+ self.config.init_config["verbose"] = self.config.verbose
+ if self.config.init_config["verbose"] > 1:
+ init_fn_name = self.config.init_config["name"]
+ warnings.warn(f"Using {init_fn_name} initialization.")
self.gradient_checkpointing = False
def get_input_embeddings(self):
@@ -76,13 +103,30 @@ class MPTModel(MPTPreTrainedModel):
self.wte = value
@torch.no_grad()
- def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
+ def _attn_bias(
+ self,
+ device,
+ dtype,
+ attention_mask: Optional[torch.ByteTensor] = None,
+ prefix_mask: Optional[torch.ByteTensor] = None,
+ sequence_id: Optional[torch.LongTensor] = None,
+ ):
if not self._attn_bias_initialized:
if self.attn_bias_shape:
- self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
- self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
+ self.attn_bias = torch.zeros(
+ self.attn_bias_shape, device=device, dtype=dtype
+ )
+ self.attn_bias = build_attn_bias(
+ self.attn_impl,
+ self.attn_bias,
+ self.config.n_heads,
+ self.config.max_seq_len,
+ causal=self.is_causal,
+ alibi=self.alibi,
+ alibi_bias_max=self.alibi_bias_max,
+ )
self._attn_bias_initialized = True
- if self.attn_impl == 'flash':
+ if self.attn_impl == "flash":
return (self.attn_bias, attention_mask)
if self.attn_bias is not None:
self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
@@ -101,38 +145,71 @@ class MPTModel(MPTPreTrainedModel):
else:
attn_bias = attn_bias[:, :, :, -s_k:]
if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
- raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
+ raise ValueError(
+ f"attention_mask shape={attention_mask.shape} "
+ + f"and prefix_mask shape={prefix_mask.shape} are not equal."
+ )
min_val = torch.finfo(attn_bias.dtype).min
- attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
+ attn_bias = attn_bias.masked_fill(
+ ~attention_mask.view(-1, 1, 1, s_k), min_val
+ )
return (attn_bias, None)
def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
(s_k, s_q) = attn_bias.shape[-2:]
if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
- raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
+ raise ValueError(
+ "attn_bias does not match the expected shape. "
+                + f"The last two dimensions should both be {self.config.max_seq_len} "
+ + f"but are {s_k} and {s_q}."
+ )
seq_len = prefix_mask.shape[-1]
if seq_len > self.config.max_seq_len:
- raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
+ raise ValueError(
+ f"prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
+ )
attn_bias = attn_bias[..., :seq_len, :seq_len]
- causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
+ causal = torch.tril(
+ torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)
+ ).view(1, 1, seq_len, seq_len)
prefix = prefix_mask.view(-1, 1, 1, seq_len)
cannot_attend = ~torch.logical_or(causal, prefix.bool())
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
return attn_bias
- def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
+ def _apply_sequence_id(
+ self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor
+ ):
seq_len = sequence_id.shape[-1]
if seq_len > self.config.max_seq_len:
- raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
+ raise ValueError(
+ f"sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}"
+ )
attn_bias = attn_bias[..., :seq_len, :seq_len]
- cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
+ cannot_attend = torch.logical_not(
+ torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
+ ).unsqueeze(1)
min_val = torch.finfo(attn_bias.dtype).min
attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
return attn_bias
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, tok_emb: Optional[torch.FloatTensor]=None):
- return_dict = return_dict if return_dict is not None else self.config.return_dict
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.ByteTensor] = None,
+ prefix_mask: Optional[torch.ByteTensor] = None,
+ sequence_id: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ tok_emb: Optional[torch.FloatTensor] = None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None else self.config.return_dict
+ )
use_cache = use_cache if use_cache is not None else self.config.use_cache
if self.gradient_checkpointing and self.training:
@@ -146,21 +223,41 @@ class MPTModel(MPTPreTrainedModel):
if prefix_mask is not None:
prefix_mask = prefix_mask.bool()
if not return_dict:
- raise NotImplementedError('return_dict False is not implemented yet for MPT')
+ raise NotImplementedError(
+ "return_dict False is not implemented yet for MPT"
+ )
if output_attentions:
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
- if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
- raise NotImplementedError('MPT does not support training with left padding.')
+ raise NotImplementedError(
+ "output_attentions is not implemented yet for MPT"
+ )
+ if (
+ attention_mask is not None
+ and attention_mask[:, 0].sum() != attention_mask.shape[0]
+ and self.training
+ ):
+ raise NotImplementedError(
+ "MPT does not support training with left padding."
+ )
if self.prefix_lm and prefix_mask is None:
- raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
+ raise ValueError(
+ "prefix_mask is a required argument when MPT is configured with prefix_lm=True."
+ )
if self.training:
if self.attn_uses_sequence_id and sequence_id is None:
- raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
+ raise ValueError(
+ "sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True "
+ + "and the model is in train mode."
+ )
elif self.attn_uses_sequence_id is False and sequence_id is not None:
- warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
+ warnings.warn(
+ "MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. "
+ + "This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True."
+ )
if input_ids is not None:
S = input_ids.size(1)
- assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
+ assert (
+ S <= self.config.max_seq_len
+ ), f"Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}"
tok_emb = self.wte(input_ids)
else:
assert tok_emb is not None
@@ -171,45 +268,85 @@ class MPTModel(MPTPreTrainedModel):
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
- raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
+ raise ValueError(
+ f"past_key_values must provide a past_key_value for each attention "
+ + f"layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r})."
+ )
past_position = past_key_values[0][0].size(1)
if S + past_position > self.config.max_seq_len:
- raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
- pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
+ raise ValueError(
+                    f"Cannot forward input with past sequence length {past_position} and current sequence length {S}, this model only supports total sequence length <= {self.config.max_seq_len}."
+ )
+ pos = torch.arange(
+ past_position,
+ S + past_position,
+ dtype=torch.long,
+ device=input_ids.device,
+ ).unsqueeze(0)
if attention_mask is not None:
- pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
+ pos = torch.clamp(
+ pos
+ - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[
+ :, past_position:
+ ],
+ min=0,
+ )
pos_emb = self.wpe(pos)
x = tok_emb + pos_emb
if self.embedding_fraction == 1:
x = self.emb_drop(x)
else:
- x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
+ x_shrunk = x * self.embedding_fraction + x.detach() * (
+ 1 - self.embedding_fraction
+ )
assert isinstance(self.emb_drop, nn.Module)
x = self.emb_drop(x_shrunk)
- (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
+ (attn_bias, attention_mask) = self._attn_bias(
+ device=x.device,
+ dtype=x.dtype,
+ attention_mask=attention_mask,
+ prefix_mask=prefix_mask,
+ sequence_id=sequence_id,
+ )
if use_cache and past_key_values is None:
past_key_values = [() for _ in range(self.config.n_layers)]
all_hidden_states = () if output_hidden_states else None
- for (b_idx, block) in enumerate(self.blocks):
+ for b_idx, block in enumerate(self.blocks):
if output_hidden_states:
assert all_hidden_states is not None
all_hidden_states = all_hidden_states + (x,)
- past_key_value = past_key_values[b_idx] if past_key_values is not None else None
+ past_key_value = (
+ past_key_values[b_idx] if past_key_values is not None else None
+ )
if self.gradient_checkpointing and self.training:
(x, past_key_value) = torch.utils.checkpoint.checkpoint(
- block,
- x, past_key_value, attn_bias, attention_mask, self.is_causal
+ block, x, past_key_value, attn_bias, attention_mask, self.is_causal
)
else:
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
+ (x, past_key_value) = block(
+ x,
+ past_key_value=past_key_value,
+ attn_bias=attn_bias,
+ attention_mask=attention_mask,
+ is_causal=self.is_causal,
+ )
if past_key_values is not None:
past_key_values[b_idx] = past_key_value
x = self.norm_f(x)
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
+ return BaseModelOutputWithPast(
+ last_hidden_state=x,
+ past_key_values=past_key_values,
+ hidden_states=all_hidden_states,
+ )
def param_init_fn(self, module):
- init_fn_name = self.config.init_config['name']
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
+ init_fn_name = self.config.init_config["name"]
+ MODEL_INIT_REGISTRY[init_fn_name](
+ module=module,
+ n_layers=self.config.n_layers,
+ d_model=self.config.d_model,
+ **self.config.init_config,
+ )
def fsdp_wrap_fn(self, module):
return isinstance(module, MPTBlock)
@@ -217,21 +354,23 @@ class MPTModel(MPTPreTrainedModel):
def activation_checkpointing_fn(self, module):
return isinstance(module, MPTBlock)
-class MPTForCausalLM(MPTPreTrainedModel):
+class MPTForCausalLM(MPTPreTrainedModel):
def __init__(self, config: MPTConfig):
super().__init__(config)
if not config.tie_word_embeddings:
- raise ValueError('MPTForCausalLM only supports tied word embeddings')
+ raise ValueError("MPTForCausalLM only supports tied word embeddings")
self.transformer = MPTModel(config)
self.logit_scale = None
if config.logit_scale is not None:
logit_scale = config.logit_scale
if isinstance(logit_scale, str):
- if logit_scale == 'inv_sqrt_d_model':
+ if logit_scale == "inv_sqrt_d_model":
logit_scale = 1 / math.sqrt(config.d_model)
else:
- raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
+ raise ValueError(
+ f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'."
+ )
self.logit_scale = logit_scale
def get_input_embeddings(self):
@@ -252,25 +391,63 @@ class MPTForCausalLM(MPTPreTrainedModel):
def get_decoder(self):
return self.transformer
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
- return_dict = return_dict if return_dict is not None else self.config.return_dict
+ def forward(
+ self,
+ input_ids: torch.LongTensor,
+ past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
+ attention_mask: Optional[torch.ByteTensor] = None,
+ prefix_mask: Optional[torch.ByteTensor] = None,
+ sequence_id: Optional[torch.LongTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ use_cache: Optional[bool] = None,
+ ):
+ return_dict = (
+ return_dict if return_dict is not None else self.config.return_dict
+ )
use_cache = use_cache if use_cache is not None else self.config.use_cache
- outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
+ outputs = self.transformer(
+ input_ids=input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ prefix_mask=prefix_mask,
+ sequence_id=sequence_id,
+ return_dict=return_dict,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ use_cache=use_cache,
+ )
logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
if self.logit_scale is not None:
if self.logit_scale == 0:
- warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
+ warnings.warn(
+ f"Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs."
+ )
logits *= self.logit_scale
loss = None
if labels is not None:
labels = torch.roll(labels, shifts=-1)
labels[:, -1] = -100
- loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
+ loss = F.cross_entropy(
+ logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)
+ )
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ )
def param_init_fn(self, module):
- init_fn_name = self.config.init_config['name']
- MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
+ init_fn_name = self.config.init_config["name"]
+ MODEL_INIT_REGISTRY[init_fn_name](
+ module=module,
+ n_layers=self.config.n_layers,
+ d_model=self.config.d_model,
+ **self.config.init_config,
+ )
def fsdp_wrap_fn(self, module):
return isinstance(module, MPTBlock)
@@ -278,12 +455,16 @@ class MPTForCausalLM(MPTPreTrainedModel):
def activation_checkpointing_fn(self, module):
return isinstance(module, MPTBlock)
- def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+ def prepare_inputs_for_generation(
+ self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
+ ):
if inputs_embeds is not None:
- raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
- attention_mask = kwargs['attention_mask'].bool()
+ raise NotImplementedError("inputs_embeds is not implemented for MPT yet")
+ attention_mask = kwargs["attention_mask"].bool()
if attention_mask[:, -1].sum() != attention_mask.shape[0]:
- raise NotImplementedError('MPT does not support generation with right padding.')
+ raise NotImplementedError(
+ "MPT does not support generation with right padding."
+ )
if self.transformer.attn_uses_sequence_id and self.training:
sequence_id = torch.zeros_like(input_ids[:1])
else:
@@ -292,11 +473,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
input_ids = input_ids[:, -1].unsqueeze(-1)
if self.transformer.prefix_lm:
prefix_mask = torch.ones_like(attention_mask)
- if kwargs.get('use_cache') == False:
- raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
+ if kwargs.get("use_cache") == False:
+ raise NotImplementedError(
+ "MPT with prefix_lm=True does not support use_cache=False."
+ )
else:
prefix_mask = None
- return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
+ return {
+ "input_ids": input_ids,
+ "attention_mask": attention_mask,
+ "prefix_mask": prefix_mask,
+ "sequence_id": sequence_id,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache", True),
+ }
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
@@ -307,5 +497,9 @@ class MPTForCausalLM(MPTPreTrainedModel):
"""
reordered_past = []
for layer_past in past_key_values:
- reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
- return reordered_past
\ No newline at end of file
+ reordered_past += [
+ tuple(
+ (past_state.index_select(0, beam_idx) for past_state in layer_past)
+ )
+ ]
+ return reordered_past
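One detail in `MPTForCausalLM.forward` above is worth spelling out: the loss rolls the labels left by one over the flattened tensor and masks the final column with -100, which is equivalent to the usual "compare logits[:, :-1] with labels[:, 1:]" causal-LM loss, because the values that wrap across rows only ever land in the masked column. An illustrative check with random tensors (batch, sequence and vocabulary sizes are arbitrary):

    import torch
    import torch.nn.functional as F

    batch, seq, vocab = 2, 5, 11
    logits = torch.randn(batch, seq, vocab)
    labels = torch.randint(0, vocab, (batch, seq))

    shifted = torch.roll(labels, shifts=-1)  # rolls the flattened tensor, as in the patch
    shifted[:, -1] = -100                    # the wrapped-in values all sit here and are ignored
    loss_mpt = F.cross_entropy(logits.view(-1, vocab), shifted.view(-1))

    loss_ref = F.cross_entropy(logits[:, :-1].reshape(-1, vocab), labels[:, 1:].reshape(-1))
    assert torch.allclose(loss_mpt, loss_ref)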
diff --git a/model/llava/model/mpt/norm.py b/model/llava/model/mpt/norm.py
index bec4a4ca3304c2188312387743a49b75015542be..42fa6d9c84a3c3cf8190a86dc5ca86b7412763b7 100644
--- a/model/llava/model/mpt/norm.py
+++ b/model/llava/model/mpt/norm.py
@@ -1,28 +1,55 @@
import torch
+
def _cast_if_autocast_enabled(tensor):
if torch.is_autocast_enabled():
- if tensor.device.type == 'cuda':
+ if tensor.device.type == "cuda":
dtype = torch.get_autocast_gpu_dtype()
- elif tensor.device.type == 'cpu':
+ elif tensor.device.type == "cpu":
dtype = torch.get_autocast_cpu_dtype()
else:
raise NotImplementedError()
return tensor.to(dtype=dtype)
return tensor
-class LPLayerNorm(torch.nn.LayerNorm):
- def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
- super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
+class LPLayerNorm(torch.nn.LayerNorm):
+ def __init__(
+ self,
+ normalized_shape,
+ eps=1e-05,
+ elementwise_affine=True,
+ device=None,
+ dtype=None,
+ ):
+ super().__init__(
+ normalized_shape=normalized_shape,
+ eps=eps,
+ elementwise_affine=elementwise_affine,
+ device=device,
+ dtype=dtype,
+ )
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
- downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
+ downcast_weight = (
+ _cast_if_autocast_enabled(self.weight)
+ if self.weight is not None
+ else self.weight
+ )
+ downcast_bias = (
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
+ )
with torch.autocast(enabled=False, device_type=module_device.type):
- return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
+ return torch.nn.functional.layer_norm(
+ downcast_x,
+ self.normalized_shape,
+ downcast_weight,
+ downcast_bias,
+ self.eps,
+ )
+
def rms_norm(x, weight=None, eps=1e-05):
     output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
@@ -30,27 +57,50 @@ def rms_norm(x, weight=None, eps=1e-05):
return output * weight
return output
-class RMSNorm(torch.nn.Module):
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
+class RMSNorm(torch.nn.Module):
+ def __init__(
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
+ ):
super().__init__()
self.eps = eps
if weight:
- self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
+ self.weight = torch.nn.Parameter(
+ torch.ones(normalized_shape, dtype=dtype, device=device)
+ )
else:
- self.register_parameter('weight', None)
+ self.register_parameter("weight", None)
def forward(self, x):
return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
-class LPRMSNorm(RMSNorm):
- def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
- super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
+class LPRMSNorm(RMSNorm):
+ def __init__(
+ self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None
+ ):
+ super().__init__(
+ normalized_shape=normalized_shape,
+ eps=eps,
+ weight=weight,
+ dtype=dtype,
+ device=device,
+ )
def forward(self, x):
downcast_x = _cast_if_autocast_enabled(x)
- downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
+ downcast_weight = (
+ _cast_if_autocast_enabled(self.weight)
+ if self.weight is not None
+ else self.weight
+ )
with torch.autocast(enabled=False, device_type=x.device.type):
return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
-NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
\ No newline at end of file
+
+
+NORM_CLASS_REGISTRY = {
+ "layernorm": torch.nn.LayerNorm,
+ "low_precision_layernorm": LPLayerNorm,
+ "rmsnorm": RMSNorm,
+ "low_precision_rmsnorm": LPRMSNorm,
+}
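The registry at the end of this file is what `MPTModel.__init__` resolves `config.norm_type` against, so every entry has to behave as a drop-in module. A small sanity sketch for the `rmsnorm` entry against the functional `rms_norm` (the feature size is arbitrary):

    import torch

    from model.llava.model.mpt.norm import NORM_CLASS_REGISTRY, rms_norm

    norm = NORM_CLASS_REGISTRY["rmsnorm"](normalized_shape=8)
    x = torch.randn(2, 3, 8)
    y = norm(x)  # RMSNorm.forward upcasts to float and casts back to x.dtype
    assert y.shape == x.shape
    assert torch.allclose(y, rms_norm(x.float(), norm.weight).to(dtype=x.dtype))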
diff --git a/model/llava/model/mpt/param_init_fns.py b/model/llava/model/mpt/param_init_fns.py
index 418b83ca2363288046f4b48b1d706c5607341fb5..5c1d17a22a62e4411a537e2d7c0c96422e4a4174 100644
--- a/model/llava/model/mpt/param_init_fns.py
+++ b/model/llava/model/mpt/param_init_fns.py
@@ -3,101 +3,139 @@ import warnings
from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union
+
import torch
from torch import nn
+
from .norm import NORM_CLASS_REGISTRY
-def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs):
+
+def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
del kwargs
if verbose > 1:
warnings.warn(f"Initializing network using module's reset_parameters attribute")
- if hasattr(module, 'reset_parameters'):
+ if hasattr(module, "reset_parameters"):
module.reset_parameters()
+
def fused_init_helper_(module: nn.Module, init_fn_):
- _fused = getattr(module, '_fused', None)
+ _fused = getattr(module, "_fused", None)
if _fused is None:
- raise RuntimeError(f'Internal logic error')
+ raise RuntimeError(f"Internal logic error")
(dim, splits) = _fused
splits = (0, *splits, module.weight.size(dim))
- for (s, e) in zip(splits[:-1], splits[1:]):
+ for s, e in zip(splits[:-1], splits[1:]):
slice_indices = [slice(None)] * module.weight.ndim
slice_indices[dim] = slice(s, e)
init_fn_(module.weight[slice_indices])
-def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+
+def generic_param_init_fn_(
+ module: nn.Module,
+ init_fn_,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
if verbose > 1:
- warnings.warn(f'If model has bias parameters they are initialized to 0.')
+ warnings.warn(f"If model has bias parameters they are initialized to 0.")
init_div_is_residual = init_div_is_residual
if init_div_is_residual is False:
div_is_residual = 1.0
elif init_div_is_residual is True:
div_is_residual = math.sqrt(2 * n_layers)
- elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int):
+ elif isinstance(init_div_is_residual, float) or isinstance(
+ init_div_is_residual, int
+ ):
div_is_residual = init_div_is_residual
elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
div_is_residual = float(init_div_is_residual)
else:
div_is_residual = 1.0
- raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
+ raise ValueError(
+ f"Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}"
+ )
if init_div_is_residual is not False:
if verbose > 1:
- warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' + f'Set `init_div_is_residual: false` in init config to disable this.')
+ warnings.warn(
+ f"Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. "
+ + f"Set `init_div_is_residual: false` in init config to disable this."
+ )
if isinstance(module, nn.Linear):
- if hasattr(module, '_fused'):
+ if hasattr(module, "_fused"):
fused_init_helper_(module, init_fn_)
else:
init_fn_(module.weight)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
- if init_div_is_residual is not False and getattr(module, '_is_residual', False):
+ if init_div_is_residual is not False and getattr(module, "_is_residual", False):
with torch.no_grad():
module.weight.div_(div_is_residual)
elif isinstance(module, nn.Embedding):
if emb_init_std is not None:
std = emb_init_std
if std == 0:
- warnings.warn(f'Embedding layer initialized to 0.')
+ warnings.warn(f"Embedding layer initialized to 0.")
emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
if verbose > 1:
- warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
+ warnings.warn(
+ f"Embedding layer initialized using normal distribution with mean=0 and std={std!r}."
+ )
elif emb_init_uniform_lim is not None:
lim = emb_init_uniform_lim
if isinstance(lim, Sequence):
if len(lim) > 2:
- raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
+ raise ValueError(
+ f"Uniform init requires a min and a max limit. User input: {lim}."
+ )
if lim[0] == lim[1]:
- warnings.warn(f'Embedding layer initialized to {lim[0]}.')
+ warnings.warn(f"Embedding layer initialized to {lim[0]}.")
else:
if lim == 0:
- warnings.warn(f'Embedding layer initialized to 0.')
+ warnings.warn(f"Embedding layer initialized to 0.")
lim = [-lim, lim]
(a, b) = lim
emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
if verbose > 1:
- warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
+ warnings.warn(
+ f"Embedding layer initialized using uniform distribution in range {lim}."
+ )
else:
emb_init_fn_ = init_fn_
emb_init_fn_(module.weight)
elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
if verbose > 1:
- warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
- if hasattr(module, 'weight') and module.weight is not None:
+ warnings.warn(
+ f"Norm weights are set to 1. If norm layer has a bias it is initialized to 0."
+ )
+ if hasattr(module, "weight") and module.weight is not None:
torch.nn.init.ones_(module.weight)
- if hasattr(module, 'bias') and module.bias is not None:
+ if hasattr(module, "bias") and module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.MultiheadAttention):
if module._qkv_same_embed_dim:
assert module.in_proj_weight is not None
- assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None)
+ assert (
+ module.q_proj_weight is None
+ and module.k_proj_weight is None
+ and (module.v_proj_weight is None)
+ )
assert d_model is not None
_d = d_model
splits = (0, _d, 2 * _d, 3 * _d)
- for (s, e) in zip(splits[:-1], splits[1:]):
+ for s, e in zip(splits[:-1], splits[1:]):
init_fn_(module.in_proj_weight[s:e])
else:
- assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None)
+ assert (
+ module.q_proj_weight is not None
+ and module.k_proj_weight is not None
+ and (module.v_proj_weight is not None)
+ )
assert module.in_proj_weight is None
init_fn_(module.q_proj_weight)
init_fn_(module.k_proj_weight)
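For context on `fused_init_helper_` and the `hasattr(module, "_fused")` branch above: a module can expose `_fused = (dim, (split, ...))` so that each logical slice of a fused weight is initialized with its own call to the init function instead of treating the packed matrix as one layer. A hedged sketch of how a module might opt in; the packed-QKV layout and the choice of `xavier_uniform_` are illustrative assumptions, not part of the patch:

    import torch.nn as nn
    from torch.nn.init import xavier_uniform_

    from model.llava.model.mpt.param_init_fns import fused_init_helper_

    d_model = 16
    qkv = nn.Linear(d_model, 3 * d_model, bias=False)  # packed Q, K, V projection
    qkv._fused = (0, (d_model, 2 * d_model))           # split the output dim at d and 2d

    fused_init_helper_(qkv, xavier_uniform_)           # each 16-row block initialized separately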
@@ -109,37 +147,112 @@ def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model:
if module.bias_v is not None:
torch.nn.init.zeros_(module.bias_v)
init_fn_(module.out_proj.weight)
- if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
+ if init_div_is_residual is not False and getattr(
+ module.out_proj, "_is_residual", False
+ ):
with torch.no_grad():
module.out_proj.weight.div_(div_is_residual)
if module.out_proj.bias is not None:
torch.nn.init.zeros_(module.out_proj.bias)
else:
for _ in module.parameters(recurse=False):
- raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
+ raise NotImplementedError(
+ f"{module.__class__.__name__} parameters are not initialized by param_init_fn."
+ )
+
def _normal_init_(std, mean=0.0):
return partial(torch.nn.init.normal_, mean=mean, std=std)
-def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+
+def _normal_param_init_fn_(
+ module: nn.Module,
+ std: float,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
init_fn_ = _normal_init_(std=std)
if verbose > 1:
- warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
- generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ warnings.warn(f"Using torch.nn.init.normal_ init fn mean=0.0, std={std}")
+ generic_param_init_fn_(
+ module=module,
+ init_fn_=init_fn_,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
+
-def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def baseline_param_init_fn_(
+ module: nn.Module,
+ init_std: float,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
if init_std is None:
- raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
- _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ raise ValueError(
+ "You must set model.init_config['init_std'] to a float value to use the default initialization scheme."
+ )
+ _normal_param_init_fn_(
+ module=module,
+ std=init_std,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
-def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+
+def small_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: int,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
std = math.sqrt(2 / (5 * d_model))
- _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ _normal_param_init_fn_(
+ module=module,
+ std=std,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
+
-def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs):
+def neox_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: int,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ verbose: int = 0,
+ **kwargs,
+):
"""From section 2.3.1 of GPT-NeoX-20B:
     An Open-Source Autoregressive Language Model — Black et al. (2022)
@@ -149,33 +262,158 @@ def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init
del kwargs
residual_div = n_layers / math.sqrt(10)
if verbose > 1:
- warnings.warn(f'setting init_div_is_residual to {residual_div}')
- small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ warnings.warn(f"setting init_div_is_residual to {residual_div}")
+ small_param_init_fn_(
+ module=module,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=residual_div,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
-def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+
+def kaiming_uniform_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ init_gain: float = 0,
+ fan_mode: str = "fan_in",
+ init_nonlinearity: str = "leaky_relu",
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
if verbose > 1:
- warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
- kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
- generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ warnings.warn(
+ f"Using nn.init.kaiming_uniform_ init fn with parameters: "
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
+ )
+ kaiming_uniform_ = partial(
+ nn.init.kaiming_uniform_,
+ a=init_gain,
+ mode=fan_mode,
+ nonlinearity=init_nonlinearity,
+ )
+ generic_param_init_fn_(
+ module=module,
+ init_fn_=kaiming_uniform_,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
+
-def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs):
+def kaiming_normal_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ init_gain: float = 0,
+ fan_mode: str = "fan_in",
+ init_nonlinearity: str = "leaky_relu",
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
if verbose > 1:
- warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}')
- kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
- generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ warnings.warn(
+ f"Using nn.init.kaiming_normal_ init fn with parameters: "
+ + f"a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}"
+ )
+ kaiming_normal_ = partial(
+ torch.nn.init.kaiming_normal_,
+ a=init_gain,
+ mode=fan_mode,
+ nonlinearity=init_nonlinearity,
+ )
+ generic_param_init_fn_(
+ module=module,
+ init_fn_=kaiming_normal_,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
-def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+
+def xavier_uniform_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ init_gain: float = 0,
+ verbose: int = 0,
+ **kwargs,
+):
del kwargs
xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
if verbose > 1:
- warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}')
- generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
+ warnings.warn(
+ f"Using torch.nn.init.xavier_uniform_ init fn with parameters: "
+ + f"gain={init_gain}"
+ )
+ generic_param_init_fn_(
+ module=module,
+ init_fn_=xavier_uniform_,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
+
-def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs):
+def xavier_normal_param_init_fn_(
+ module: nn.Module,
+ n_layers: int,
+ d_model: Optional[int] = None,
+ init_div_is_residual: Union[int, float, str, bool] = True,
+ emb_init_std: Optional[float] = None,
+ emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None,
+ init_gain: float = 0,
+ verbose: int = 0,
+ **kwargs,
+):
xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
if verbose > 1:
- warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}')
- generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
-MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_}
\ No newline at end of file
+ warnings.warn(
+ f"Using torch.nn.init.xavier_normal_ init fn with parameters: "
+ + f"gain={init_gain}"
+ )
+ generic_param_init_fn_(
+ module=module,
+ init_fn_=xavier_normal_,
+ d_model=d_model,
+ n_layers=n_layers,
+ init_div_is_residual=init_div_is_residual,
+ emb_init_std=emb_init_std,
+ emb_init_uniform_lim=emb_init_uniform_lim,
+ verbose=verbose,
+ )
+
+
+MODEL_INIT_REGISTRY = {
+ "default_": torch_default_param_init_fn_,
+ "baseline_": baseline_param_init_fn_,
+ "kaiming_uniform_": kaiming_uniform_param_init_fn_,
+ "kaiming_normal_": kaiming_normal_param_init_fn_,
+ "neox_init_": neox_param_init_fn_,
+ "small_init_": small_param_init_fn_,
+ "xavier_uniform_": xavier_uniform_param_init_fn_,
+ "xavier_normal_": xavier_normal_param_init_fn_,
+}
diff --git a/model/llava/model/utils.py b/model/llava/model/utils.py
index b732e869aa3ad9d08a5909d6ddf6c7631bb87805..976f0190bdf130627a289b0bbc765f75b0e7a669 100644
--- a/model/llava/model/utils.py
+++ b/model/llava/model/utils.py
@@ -5,16 +5,20 @@ from transformers import AutoConfig, StoppingCriteria
def auto_upgrade(config):
cfg = AutoConfig.from_pretrained(config)
- if 'llava' in config and 'llava' not in cfg.model_type:
- assert cfg.model_type == 'llama'
- print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
- print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
+ if "llava" in config and "llava" not in cfg.model_type:
+ assert cfg.model_type == "llama"
+ print(
+ "You are using newer LLaVA code base, while the checkpoint of v0 is from older code base."
+ )
+ print(
+ "You must upgrade the checkpoint to the new code base (this can be done automatically)."
+ )
confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
if confirm.lower() in ["y", "yes"]:
print("Upgrading checkpoint...")
assert len(cfg.architectures) == 1
setattr(cfg.__class__, "model_type", "llava")
- cfg.architectures[0] = 'LlavaLlamaForCausalLM'
+ cfg.architectures[0] = "LlavaLlamaForCausalLM"
cfg.save_pretrained(config)
print("Checkpoint upgraded.")
else:
@@ -22,24 +26,31 @@ def auto_upgrade(config):
exit(1)
-
class KeywordsStoppingCriteria(StoppingCriteria):
def __init__(self, keywords, tokenizer, input_ids):
self.keywords = keywords
self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
- self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
+ self.keyword_ids = [
+ keyword_id[0]
+ for keyword_id in self.keyword_ids
+ if type(keyword_id) is list and len(keyword_id) == 1
+ ]
self.tokenizer = tokenizer
self.start_len = None
self.input_ids = input_ids
- def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+ def __call__(
+ self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
+ ) -> bool:
if self.start_len is None:
self.start_len = self.input_ids.shape[1]
else:
for keyword_id in self.keyword_ids:
if output_ids[0, -1] == keyword_id:
return True
- outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
+ outputs = self.tokenizer.batch_decode(
+ output_ids[:, self.start_len :], skip_special_tokens=True
+ )[0]
for keyword in self.keywords:
if keyword in outputs:
return True
@@ -50,7 +61,7 @@ class KeywordsStoppingCriteria(StoppingCriteria):
# # if output_ids[0, -1] == keyword_id:
# # return True
-
+
# print("output_ids.shape: {}, self.start_len: {}".format(output_ids.shape, self.start_len))
# print("output_ids[:, self.start_len:]: ", output_ids[:, self.start_len:])
diff --git a/model/llava/serve/cli.py b/model/llava/serve/cli.py
index a385727b5cc7ad7c013c01d704297ec1af5d5686..044be2c8c162f1f6e8a3acbf25022778a17f47c4 100644
--- a/model/llava/serve/cli.py
+++ b/model/llava/serve/cli.py
@@ -6,14 +6,14 @@ import argparse
import time
import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-from llava.conversation import conv_templates, SeparatorStyle
+from llava.conversation import SeparatorStyle, conv_templates
+from transformers import AutoModelForCausalLM, AutoTokenizer
@torch.inference_mode()
-def generate_stream(tokenizer, model, params, device,
- context_len=2048, stream_interval=2):
+def generate_stream(
+ tokenizer, model, params, device, context_len=2048, stream_interval=2
+):
"""Adapted from fastchat/serve/model_worker.py::generate_stream"""
prompt = params["prompt"]
@@ -30,17 +30,19 @@ def generate_stream(tokenizer, model, params, device,
for i in range(max_new_tokens):
if i == 0:
- out = model(
- torch.as_tensor([input_ids], device=device), use_cache=True)
+ out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
- 1, past_key_values[0][0].shape[-2] + 1, device=device)
- out = model(input_ids=torch.as_tensor([[token]], device=device),
- use_cache=True,
- attention_mask=attention_mask,
- past_key_values=past_key_values)
+ 1, past_key_values[0][0].shape[-2] + 1, device=device
+ )
+ out = model(
+ input_ids=torch.as_tensor([[token]], device=device),
+ use_cache=True,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ )
logits = out.logits
past_key_values = out.past_key_values
@@ -84,18 +86,21 @@ def main(args):
else:
num_gpus = int(num_gpus)
if num_gpus != 1:
- kwargs.update({
- "device_map": "auto",
- "max_memory": {i: "13GiB" for i in range(num_gpus)},
- })
+ kwargs.update(
+ {
+ "device_map": "auto",
+ "max_memory": {i: "13GiB" for i in range(num_gpus)},
+ }
+ )
elif args.device == "cpu":
kwargs = {}
else:
raise ValueError(f"Invalid device: {args.device}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
- model = AutoModelForCausalLM.from_pretrained(model_name,
- low_cpu_mem_usage=True, **kwargs)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name, low_cpu_mem_usage=True, **kwargs
+ )
if args.device == "cuda" and num_gpus == 1:
model.cuda()
@@ -126,11 +131,11 @@ def main(args):
print(f"{conv.roles[1]}: ", end="", flush=True)
pre = 0
for outputs in generate_stream(tokenizer, model, params, args.device):
- outputs = outputs[len(prompt) + 1:].strip()
+ outputs = outputs[len(prompt) + 1 :].strip()
outputs = outputs.split(" ")
now = len(outputs)
if now - 1 > pre:
- print(" ".join(outputs[pre:now-1]), end=" ", flush=True)
+ print(" ".join(outputs[pre : now - 1]), end=" ", flush=True)
pre = now - 1
print(" ".join(outputs[pre:]), flush=True)
diff --git a/model/llava/serve/controller.py b/model/llava/serve/controller.py
index b61fca6ea9fe8aa37acd143784a3d76e90a58b9f..1cb67c27a657d71420f6dddb4e40e40e7645488e 100644
--- a/model/llava/serve/controller.py
+++ b/model/llava/serve/controller.py
@@ -5,23 +5,21 @@ It sends worker addresses to clients.
import argparse
import asyncio
import dataclasses
-from enum import Enum, auto
import json
import logging
+import threading
import time
+from enum import Enum, auto
from typing import List, Union
-import threading
-from fastapi import FastAPI, Request
-from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
-
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
from llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
from llava.utils import build_logger, server_error_msg
-
logger = build_logger("controller", "controller.log")
@@ -61,13 +59,15 @@ class Controller:
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
self.heart_beat_thread = threading.Thread(
- target=heart_beat_controller, args=(self,))
+ target=heart_beat_controller, args=(self,)
+ )
self.heart_beat_thread.start()
logger.info("Init controller")
- def register_worker(self, worker_name: str, check_heart_beat: bool,
- worker_status: dict):
+ def register_worker(
+ self, worker_name: str, check_heart_beat: bool, worker_status: dict
+ ):
if worker_name not in self.worker_info:
logger.info(f"Register a new worker: {worker_name}")
else:
@@ -79,8 +79,12 @@ class Controller:
return False
self.worker_info[worker_name] = WorkerInfo(
- worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
- check_heart_beat, time.time())
+ worker_status["model_names"],
+ worker_status["speed"],
+ worker_status["queue_length"],
+ check_heart_beat,
+ time.time(),
+ )
logger.info(f"Register done: {worker_name}, {worker_status}")
return True
@@ -131,15 +135,13 @@ class Controller:
return ""
worker_speeds = worker_speeds / norm
if True: # Directly return address
- pt = np.random.choice(np.arange(len(worker_names)),
- p=worker_speeds)
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
return worker_name
# Check status before returning
while True:
- pt = np.random.choice(np.arange(len(worker_names)),
- p=worker_speeds)
+ pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
worker_name = worker_names[pt]
if self.get_worker_status(worker_name):
@@ -165,7 +167,9 @@ class Controller:
min_index = np.argmin(worker_qlen)
w_name = worker_names[min_index]
self.worker_info[w_name].queue_length += 1
- logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
+ logger.info(
+ f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
+ )
return w_name
else:
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
@@ -201,8 +205,12 @@ class Controller:
yield json.dumps(ret).encode() + b"\0"
try:
- response = requests.post(worker_addr + "/worker_generate_stream",
- json=params, stream=True, timeout=5)
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ json=params,
+ stream=True,
+ timeout=5,
+ )
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
yield chunk + b"\0"
@@ -214,7 +222,6 @@ class Controller:
}
yield json.dumps(ret).encode() + b"\0"
-
# Let the controller act as a worker to achieve hierarchical
# management. This can be used to connect isolated sub networks.
def worker_api_get_status(self):
@@ -243,8 +250,8 @@ app = FastAPI()
async def register_worker(request: Request):
data = await request.json()
controller.register_worker(
- data["worker_name"], data["check_heart_beat"],
- data.get("worker_status", None))
+ data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
+ )
@app.post("/refresh_all_workers")
@@ -268,8 +275,7 @@ async def get_worker_address(request: Request):
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
data = await request.json()
- exist = controller.receive_heart_beat(
- data["worker_name"], data["queue_length"])
+ exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
return {"exist": exist}
@@ -289,8 +295,12 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21001)
- parser.add_argument("--dispatch-method", type=str, choices=[
- "lottery", "shortest_queue"], default="shortest_queue")
+ parser.add_argument(
+ "--dispatch-method",
+ type=str,
+ choices=["lottery", "shortest_queue"],
+ default="shortest_queue",
+ )
args = parser.parse_args()
logger.info(f"args: {args}")
diff --git a/model/llava/serve/gradio_css.py b/model/llava/serve/gradio_css.py
index 55454130b3becdf786c1545a2f79028068389e7c..71d79b4a4b5a7ad84b8822d99e1740e77bc1f7a8 100644
--- a/model/llava/serve/gradio_css.py
+++ b/model/llava/serve/gradio_css.py
@@ -1,5 +1,4 @@
-code_highlight_css = (
-"""
+code_highlight_css = """
#chatbot .hll { background-color: #ffffcc }
#chatbot .c { color: #408080; font-style: italic }
#chatbot .err { border: 1px solid #FF0000 }
@@ -68,6 +67,5 @@ code_highlight_css = (
#chatbot .vi { color: #19177C }
#chatbot .vm { color: #19177C }
#chatbot .il { color: #666666 }
-""")
-#.highlight { background: #f8f8f8; }
-
+"""
+# .highlight { background: #f8f8f8; }
diff --git a/model/llava/serve/gradio_patch.py b/model/llava/serve/gradio_patch.py
index 07e5909e2d6b10fc75178daa54f45c01dcbb42cb..cb3b4838fe14c9df20b7f22eef03617e1e4b088a 100644
--- a/model/llava/serve/gradio_patch.py
+++ b/model/llava/serve/gradio_patch.py
@@ -50,7 +50,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
warnings.warn(
"The 'color_map' parameter has been deprecated.",
)
- #self.md = utils.get_markdown_parser()
+ # self.md = utils.get_markdown_parser()
self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
self.select: EventListenerMethod
"""
@@ -113,7 +113,7 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
): # This happens for previously processed messages
return chat_message
elif isinstance(chat_message, str):
- #return self.md.render(chat_message)
+ # return self.md.render(chat_message)
return str(self.md.convert(chat_message))
else:
raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
@@ -142,9 +142,10 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
processed_messages.append(
(
- #self._process_chat_messages(message_pair[0]),
-                    '<pre style="font-family: var(--font)">' +
-                    message_pair[0] + "</pre>",
+ # self._process_chat_messages(message_pair[0]),
+                    '<pre style="font-family: var(--font)">'
+                    + message_pair[0]
+                    + "</pre>",
self._process_chat_messages(message_pair[1]),
)
)
@@ -164,5 +165,3 @@ class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
**kwargs,
)
return self
-
-
diff --git a/model/llava/serve/gradio_web_server.py b/model/llava/serve/gradio_web_server.py
index c6407730e2956ea0ea65dc7b11873f7b5bef126c..976972e690e317e173f57d43f4556c81dc94e4a5 100644
--- a/model/llava/serve/gradio_web_server.py
+++ b/model/llava/serve/gradio_web_server.py
@@ -1,22 +1,20 @@
import argparse
-from collections import defaultdict
import datetime
+import hashlib
import json
import os
import time
+from collections import defaultdict
import gradio as gr
import requests
-
-from llava.conversation import (default_conversation, conv_templates,
- SeparatorStyle)
from llava.constants import LOGDIR
-from llava.utils import (build_logger, server_error_msg,
- violates_moderation, moderation_msg)
-from llava.serve.gradio_patch import Chatbot as grChatbot
+from llava.conversation import (SeparatorStyle, conv_templates,
+ default_conversation)
from llava.serve.gradio_css import code_highlight_css
-import hashlib
-
+from llava.serve.gradio_patch import Chatbot as grChatbot
+from llava.utils import (build_logger, moderation_msg, server_error_msg,
+ violates_moderation)
logger = build_logger("gradio_web_server", "gradio_web_server.log")
@@ -65,31 +63,33 @@ def load_demo(url_params, request: gr.Request):
if "model" in url_params:
model = url_params["model"]
if model in models:
- dropdown_update = gr.Dropdown.update(
- value=model, visible=True)
+ dropdown_update = gr.Dropdown.update(value=model, visible=True)
state = default_conversation.copy()
- return (state,
- dropdown_update,
- gr.Chatbot.update(visible=True),
- gr.Textbox.update(visible=True),
- gr.Button.update(visible=True),
- gr.Row.update(visible=True),
- gr.Accordion.update(visible=True))
+ return (
+ state,
+ dropdown_update,
+ gr.Chatbot.update(visible=True),
+ gr.Textbox.update(visible=True),
+ gr.Button.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
def load_demo_refresh_model_list(request: gr.Request):
logger.info(f"load_demo. ip: {request.client.host}")
models = get_model_list()
state = default_conversation.copy()
- return (state, gr.Dropdown.update(
- choices=models,
- value=models[0] if len(models) > 0 else ""),
- gr.Chatbot.update(visible=True),
- gr.Textbox.update(visible=True),
- gr.Button.update(visible=True),
- gr.Row.update(visible=True),
- gr.Accordion.update(visible=True))
+ return (
+ state,
+ gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else ""),
+ gr.Chatbot.update(visible=True),
+ gr.Textbox.update(visible=True),
+ gr.Button.update(visible=True),
+ gr.Row.update(visible=True),
+ gr.Accordion.update(visible=True),
+ )
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
@@ -148,13 +148,14 @@ def add_text(state, text, image, image_process_mode, request: gr.Request):
if flagged:
state.skip_next = True
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
- no_change_btn,) * 5
+ no_change_btn,
+ ) * 5
text = text[:1536] # Hard cut-off
if image is not None:
text = text[:1200] # Hard cut-off for images
-        if '<image>' not in text:
-            text = text + '\n<image>'
+        if "<image>" not in text:
+            text = text + "\n<image>"
text = (text, image, image_process_mode)
state = default_conversation.copy()
state.append_message(state.roles[0], text)
@@ -195,9 +196,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
template_name = "multimodal"
elif "mpt" in model_name:
template_name = "mpt_text"
- elif "koala" in model_name: # Hardcode the condition
+ elif "koala" in model_name: # Hardcode the condition
template_name = "bair_v1"
- elif "v1" in model_name: # vicuna v1_1/v1_2
+ elif "v1" in model_name: # vicuna v1_1/v1_2
template_name = "vicuna_v1_1"
else:
template_name = "v1"
@@ -208,15 +209,24 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
# Query worker address
controller_url = args.controller_url
- ret = requests.post(controller_url + "/get_worker_address",
- json={"model": model_name})
+ ret = requests.post(
+ controller_url + "/get_worker_address", json={"model": model_name}
+ )
worker_addr = ret.json()["address"]
logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
# No available worker
if worker_addr == "":
state.messages[-1][-1] = server_error_msg
- yield (state, state.to_gradio_chatbot(), disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ yield (
+ state,
+ state.to_gradio_chatbot(),
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
return
# Construct prompt
@@ -226,7 +236,9 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
for image, hash in zip(all_images, all_image_hash):
t = datetime.datetime.now()
- filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg")
+ filename = os.path.join(
+ LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
+ )
if not os.path.isfile(filename):
os.makedirs(os.path.dirname(filename), exist_ok=True)
image.save(filename)
@@ -237,37 +249,56 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
"prompt": prompt,
"temperature": float(temperature),
"max_new_tokens": min(int(max_new_tokens), 1536),
- "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
- "images": f'List of {len(state.get_images())} images: {all_image_hash}',
+ "stop": state.sep
+ if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
+ else state.sep2,
+ "images": f"List of {len(state.get_images())} images: {all_image_hash}",
}
logger.info(f"==== request ====\n{pload}")
- pload['images'] = state.get_images()
+ pload["images"] = state.get_images()
state.messages[-1][-1] = "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
try:
# Stream output
- response = requests.post(worker_addr + "/worker_generate_stream",
- headers=headers, json=pload, stream=True, timeout=10)
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=pload,
+ stream=True,
+ timeout=10,
+ )
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
if chunk:
data = json.loads(chunk.decode())
if data["error_code"] == 0:
- output = data["text"][len(prompt):].strip()
+ output = data["text"][len(prompt) :].strip()
output = post_process_code(output)
state.messages[-1][-1] = output + "▌"
yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
else:
output = data["text"] + f" (error_code: {data['error_code']})"
state.messages[-1][-1] = output
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
return
time.sleep(0.03)
except requests.exceptions.RequestException as e:
state.messages[-1][-1] = server_error_msg
- yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
+ yield (state, state.to_gradio_chatbot()) + (
+ disable_btn,
+ disable_btn,
+ disable_btn,
+ enable_btn,
+ enable_btn,
+ )
return
state.messages[-1][-1] = state.messages[-1][-1][:-1]
@@ -289,27 +320,30 @@ def http_bot(state, model_selector, temperature, max_new_tokens, request: gr.Req
}
fout.write(json.dumps(data) + "\n")
-title_markdown = ("""
+
+title_markdown = """
# 🌋 LLaVA: Large Language and Vision Assistant
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
-""")
+"""
-tos_markdown = ("""
+tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
-""")
+"""
-learn_more_markdown = ("""
+learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
-""")
+"""
-css = code_highlight_css + """
+css = (
+ code_highlight_css
+ + """
pre {
white-space: pre-wrap; /* Since CSS 2.1 */
white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
@@ -318,11 +352,13 @@ pre {
word-wrap: break-word; /* Internet Explorer 5.5+ */
}
"""
+)
def build_demo(embed_mode):
- textbox = gr.Textbox(show_label=False,
- placeholder="Enter text and press ENTER", visible=False).style(container=False)
+ textbox = gr.Textbox(
+ show_label=False, placeholder="Enter text and press ENTER", visible=False
+ ).style(container=False)
with gr.Blocks(title="LLaVA", theme=gr.themes.Base(), css=css) as demo:
state = gr.State()
@@ -336,26 +372,55 @@ def build_demo(embed_mode):
choices=models,
value=models[0] if len(models) > 0 else "",
interactive=True,
- show_label=False).style(container=False)
+ show_label=False,
+ ).style(container=False)
imagebox = gr.Image(type="pil")
image_process_mode = gr.Radio(
["Crop", "Resize", "Pad"],
value="Crop",
- label="Preprocess for non-square image")
+ label="Preprocess for non-square image",
+ )
cur_dir = os.path.dirname(os.path.abspath(__file__))
- gr.Examples(examples=[
- [f"{cur_dir}/examples/extreme_ironing.jpg", "What is unusual about this image?"],
- [f"{cur_dir}/examples/waterview.jpg", "What are the things I should be cautious about when I visit here?"],
- ], inputs=[imagebox, textbox])
-
- with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
- temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
- max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)
+ gr.Examples(
+ examples=[
+ [
+ f"{cur_dir}/examples/extreme_ironing.jpg",
+ "What is unusual about this image?",
+ ],
+ [
+ f"{cur_dir}/examples/waterview.jpg",
+ "What are the things I should be cautious about when I visit here?",
+ ],
+ ],
+ inputs=[imagebox, textbox],
+ )
+
+ with gr.Accordion(
+ "Parameters", open=False, visible=False
+ ) as parameter_row:
+ temperature = gr.Slider(
+ minimum=0.0,
+ maximum=1.0,
+ value=0.2,
+ step=0.1,
+ interactive=True,
+ label="Temperature",
+ )
+ max_output_tokens = gr.Slider(
+ minimum=0,
+ maximum=1024,
+ value=512,
+ step=64,
+ interactive=True,
+ label="Max output tokens",
+ )
with gr.Column(scale=6):
- chatbot = grChatbot(elem_id="chatbot", label="LLaVA Chatbot", visible=False).style(height=550)
+ chatbot = grChatbot(
+ elem_id="chatbot", label="LLaVA Chatbot", visible=False
+ ).style(height=550)
with gr.Row():
with gr.Column(scale=8):
textbox.render()
@@ -365,7 +430,7 @@ def build_demo(embed_mode):
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
- #stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
+ # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)
@@ -376,32 +441,82 @@ def build_demo(embed_mode):
# Register listeners
btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
- upvote_btn.click(upvote_last_response,
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
- downvote_btn.click(downvote_last_response,
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
- flag_btn.click(flag_last_response,
- [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn])
- regenerate_btn.click(regenerate, [state, image_process_mode],
- [state, chatbot, textbox, imagebox] + btn_list).then(
- http_bot, [state, model_selector, temperature, max_output_tokens],
- [state, chatbot] + btn_list)
- clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
-
- textbox.submit(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
- ).then(http_bot, [state, model_selector, temperature, max_output_tokens],
- [state, chatbot] + btn_list)
- submit_btn.click(add_text, [state, textbox, imagebox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
- ).then(http_bot, [state, model_selector, temperature, max_output_tokens],
- [state, chatbot] + btn_list)
+ upvote_btn.click(
+ upvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ downvote_btn.click(
+ downvote_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ flag_btn.click(
+ flag_last_response,
+ [state, model_selector],
+ [textbox, upvote_btn, downvote_btn, flag_btn],
+ )
+ regenerate_btn.click(
+ regenerate,
+ [state, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list,
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ clear_btn.click(
+ clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
+ )
+
+ textbox.submit(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list,
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
+ submit_btn.click(
+ add_text,
+ [state, textbox, imagebox, image_process_mode],
+ [state, chatbot, textbox, imagebox] + btn_list,
+ ).then(
+ http_bot,
+ [state, model_selector, temperature, max_output_tokens],
+ [state, chatbot] + btn_list,
+ )
if args.model_list_mode == "once":
- demo.load(load_demo, [url_params], [state, model_selector,
- chatbot, textbox, submit_btn, button_row, parameter_row],
- _js=get_window_url_params)
+ demo.load(
+ load_demo,
+ [url_params],
+ [
+ state,
+ model_selector,
+ chatbot,
+ textbox,
+ submit_btn,
+ button_row,
+ parameter_row,
+ ],
+ _js=get_window_url_params,
+ )
elif args.model_list_mode == "reload":
- demo.load(load_demo_refresh_model_list, None, [state, model_selector,
- chatbot, textbox, submit_btn, button_row, parameter_row])
+ demo.load(
+ load_demo_refresh_model_list,
+ None,
+ [
+ state,
+ model_selector,
+ chatbot,
+ textbox,
+ submit_btn,
+ button_row,
+ parameter_row,
+ ],
+ )
else:
raise ValueError(f"Unknown model list mode: {args.model_list_mode}")
@@ -414,8 +529,9 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int)
parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
parser.add_argument("--concurrency-count", type=int, default=8)
- parser.add_argument("--model-list-mode", type=str, default="once",
- choices=["once", "reload"])
+ parser.add_argument(
+ "--model-list-mode", type=str, default="once", choices=["once", "reload"]
+ )
parser.add_argument("--share", action="store_true")
parser.add_argument("--moderate", action="store_true")
parser.add_argument("--embed", action="store_true")
@@ -426,6 +542,6 @@ if __name__ == "__main__":
logger.info(args)
demo = build_demo(args.embed)
- demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10,
- api_open=False).launch(
- server_name=args.host, server_port=args.port, share=args.share)
+ demo.queue(
+ concurrency_count=args.concurrency_count, status_update_rate=10, api_open=False
+ ).launch(server_name=args.host, server_port=args.port, share=args.share)
diff --git a/model/llava/serve/model_worker.py b/model/llava/serve/model_worker.py
index a4ef900d42a823b5ae7cb41becf1a73b56c6565c..0095bb61b056e69e26d243534796385707f743b0 100644
--- a/model/llava/serve/model_worker.py
+++ b/model/llava/serve/model_worker.py
@@ -4,25 +4,23 @@ A model worker executes the model.
import argparse
import asyncio
import dataclasses
-import logging
import json
-import time
-from typing import List, Union
+import logging
import threading
+import time
import uuid
+from functools import partial
+from typing import List, Union
-from fastapi import FastAPI, Request, BackgroundTasks
-from fastapi.responses import StreamingResponse
import requests
-from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import uvicorn
-from functools import partial
-
+from fastapi import BackgroundTasks, FastAPI, Request
+from fastapi.responses import StreamingResponse
from llava.constants import WORKER_HEART_BEAT_INTERVAL
-from llava.utils import (build_logger, server_error_msg,
- pretty_print_semaphore)
from llava.model import *
+from llava.utils import build_logger, pretty_print_semaphore, server_error_msg
+from transformers import AutoModelForCausalLM, AutoTokenizer
GB = 1 << 30
@@ -40,7 +38,6 @@ DEFAULT_IM_END_TOKEN = "<im_end>"
def heart_beat_worker(controller):
-
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
@@ -56,38 +53,66 @@ def load_model(model_path, model_name, num_gpus):
}
tokenizer = AutoTokenizer.from_pretrained(model_path)
- if 'llava' in model_name.lower():
- if 'mpt' in model_name.lower():
- model = LlavaMPTForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
+ if "llava" in model_name.lower():
+ if "mpt" in model_name.lower():
+ model = LlavaMPTForCausalLM.from_pretrained(
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
+ )
else:
- model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
- elif 'mpt' in model_name.lower():
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs)
+ model = LlavaLlamaForCausalLM.from_pretrained(
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
+ )
+ elif "mpt" in model_name.lower():
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ trust_remote_code=True,
+ **kwargs,
+ )
else:
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, **kwargs
+ )
image_processor = None
- if 'llava' in model_name.lower():
+ if "llava" in model_name.lower():
from transformers import CLIPImageProcessor, CLIPVisionModel
- image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
+
+ image_processor = CLIPImageProcessor.from_pretrained(
+ model.config.mm_vision_tower, torch_dtype=torch.float16
+ )
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
- tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
+ tokenizer.add_tokens(
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
+ )
vision_tower = model.get_model().vision_tower[0]
- if vision_tower.device.type == 'meta':
- vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
+ if vision_tower.device.type == "meta":
+ vision_tower = CLIPVisionModel.from_pretrained(
+ vision_tower.config._name_or_path,
+ torch_dtype=torch.float16,
+ low_cpu_mem_usage=True,
+ ).cuda()
model.get_model().vision_tower[0] = vision_tower
else:
- vision_tower.to(device='cuda', dtype=torch.float16)
+ vision_tower.to(device="cuda", dtype=torch.float16)
vision_config = vision_tower.config
- vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
+ vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
+ [DEFAULT_IMAGE_PATCH_TOKEN]
+ )[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
- vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
+ (
+ vision_config.im_start_token,
+ vision_config.im_end_token,
+ ) = tokenizer.convert_tokens_to_ids(
+ [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
+ )
if num_gpus == 1:
model.cuda()
@@ -101,11 +126,17 @@ def load_model(model_path, model_name, num_gpus):
class ModelWorker:
- def __init__(self, controller_addr, worker_addr,
- worker_id, no_register,
- model_path, model_name,
- keep_aspect_ratio,
- num_gpus):
+ def __init__(
+ self,
+ controller_addr,
+ worker_addr,
+ worker_id,
+ no_register,
+ model_path,
+ model_name,
+ keep_aspect_ratio,
+ num_gpus,
+ ):
self.controller_addr = controller_addr
self.worker_addr = worker_addr
self.worker_id = worker_id
@@ -113,7 +144,7 @@ class ModelWorker:
model_path = model_path[:-1]
if model_name is None:
model_paths = model_path.split("/")
- if model_paths[-1].startswith('checkpoint-'):
+ if model_paths[-1].startswith("checkpoint-"):
self.model_name = model_paths[-2] + "_" + model_paths[-1]
else:
self.model_name = model_paths[-1]
@@ -123,13 +154,15 @@ class ModelWorker:
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.keep_aspect_ratio = keep_aspect_ratio
self.tokenizer, self.model, self.image_processor, self.context_len = load_model(
- model_path, self.model_name, num_gpus)
- self.is_multimodal = 'llava' in model_path.lower()
+ model_path, self.model_name, num_gpus
+ )
+ self.is_multimodal = "llava" in model_path.lower()
if not no_register:
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
- target=heart_beat_worker, args=(self,))
+ target=heart_beat_worker, args=(self,)
+ )
self.heart_beat_thread.start()
def register_to_controller(self):
@@ -139,23 +172,30 @@ class ModelWorker:
data = {
"worker_name": self.worker_addr,
"check_heart_beat": True,
- "worker_status": self.get_status()
+ "worker_status": self.get_status(),
}
r = requests.post(url, json=data)
assert r.status_code == 200
def send_heart_beat(self):
- logger.info(f"Send heart beat. Models: {[self.model_name]}. "
- f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
- f"global_counter: {global_counter}")
+ logger.info(
+ f"Send heart beat. Models: {[self.model_name]}. "
+ f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
+ f"global_counter: {global_counter}"
+ )
url = self.controller_addr + "/receive_heart_beat"
while True:
try:
- ret = requests.post(url, json={
- "worker_name": self.worker_addr,
- "queue_length": self.get_queue_length()}, timeout=5)
+ ret = requests.post(
+ url,
+ json={
+ "worker_name": self.worker_addr,
+ "queue_length": self.get_queue_length(),
+ },
+ timeout=5,
+ )
exist = ret.json()["exist"]
break
except requests.exceptions.RequestException as e:
@@ -169,8 +209,15 @@ class ModelWorker:
if model_semaphore is None:
return 0
else:
- return args.limit_model_concurrency - model_semaphore._value + (len(
- model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
+ return (
+ args.limit_model_concurrency
+ - model_semaphore._value
+ + (
+ len(model_semaphore._waiters)
+ if model_semaphore._waiters is not None
+ else 0
+ )
+ )
def get_status(self):
return {
@@ -181,20 +228,30 @@ class ModelWorker:
@torch.inference_mode()
def generate_stream(self, params):
- tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
+ tokenizer, model, image_processor = (
+ self.tokenizer,
+ self.model,
+ self.image_processor,
+ )
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
if images is not None and len(images) > 0 and self.is_multimodal:
- from PIL import Image
- from io import BytesIO
import base64
+ from io import BytesIO
+
+ from PIL import Image
+
assert type(images) is list
if len(images) > 0:
# assert len(images) == 1, "Only support one image for now"
- images = [Image.open(BytesIO(base64.b64decode(image))) for image in images]
- assert len(images) == prompt.count(DEFAULT_IMAGE_TOKEN), "Number of images does not match number of tokens in prompt"
+ images = [
+ Image.open(BytesIO(base64.b64decode(image))) for image in images
+ ]
+ assert len(images) == prompt.count(
+ DEFAULT_IMAGE_TOKEN
+ ), "Number of images does not match number of tokens in prompt"
if self.keep_aspect_ratio:
new_images = []
@@ -203,21 +260,40 @@ class ModelWorker:
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
- image = image_processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
- new_images.append(image.to(self.model.device, dtype=torch.float16))
+ image = image_processor.preprocess(
+ image,
+ return_tensors="pt",
+ do_center_crop=False,
+ size={"shortest_edge": shortest_edge},
+ )["pixel_values"][0]
+ new_images.append(
+ image.to(self.model.device, dtype=torch.float16)
+ )
# replace the image token with the image patch token in the prompt (each occurrence)
- cur_token_len = (image.shape[1]//14) * (image.shape[2]//14)
+ cur_token_len = (image.shape[1] // 14) * (image.shape[2] // 14)
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len
- if getattr(self.model.config, 'mm_use_im_start_end', False):
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ if getattr(self.model.config, "mm_use_im_start_end", False):
+ replace_token = (
+ DEFAULT_IM_START_TOKEN
+ + replace_token
+ + DEFAULT_IM_END_TOKEN
+ )
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token, 1)
images = new_images
else:
- images = image_processor(images, return_tensors='pt')['pixel_values']
+ images = image_processor(images, return_tensors="pt")[
+ "pixel_values"
+ ]
images = images.to(self.model.device, dtype=torch.float16)
- replace_token = DEFAULT_IMAGE_PATCH_TOKEN * 256 # HACK: 256 is the max image token length hacked
- if getattr(self.model.config, 'mm_use_im_start_end', False):
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ replace_token = (
+ DEFAULT_IMAGE_PATCH_TOKEN * 256
+ ) # HACK: 256 is the max image token length hacked
+ if getattr(self.model.config, "mm_use_im_start_end", False):
+ replace_token = (
+ DEFAULT_IM_START_TOKEN
+ + replace_token
+ + DEFAULT_IM_END_TOKEN
+ )
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
else:
images = None
@@ -249,18 +325,20 @@ class ModelWorker:
for i in range(max_new_tokens):
if i == 0:
out = model(
- torch.as_tensor([input_ids]).cuda(),
- use_cache=True,
- **image_args)
+ torch.as_tensor([input_ids]).cuda(), use_cache=True, **image_args
+ )
logits = out.logits
past_key_values = out.past_key_values
else:
attention_mask = torch.ones(
- 1, past_key_values[0][0].shape[-2] + 1, device="cuda")
- out = model(input_ids=torch.as_tensor([[token]], device="cuda"),
- use_cache=True,
- attention_mask=attention_mask,
- past_key_values=past_key_values)
+ 1, past_key_values[0][0].shape[-2] + 1, device="cuda"
+ )
+ out = model(
+ input_ids=torch.as_tensor([[token]], device="cuda"),
+ use_cache=True,
+ attention_mask=attention_mask,
+ past_key_values=past_key_values,
+ )
logits = out.logits
past_key_values = out.past_key_values
@@ -342,7 +420,9 @@ async def generate_stream(request: Request):
worker.send_heart_beat()
generator = worker.generate_stream_gate(params)
background_tasks = BackgroundTasks()
- background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
+ background_tasks.add_task(
+ partial(release_model_semaphore, fn=worker.send_heart_beat)
+ )
return StreamingResponse(generator, background=background_tasks)
@@ -355,13 +435,17 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=21002)
- parser.add_argument("--worker-address", type=str,
- default="http://localhost:21002")
- parser.add_argument("--controller-address", type=str,
- default="http://localhost:21001")
+ parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-name", type=str)
- parser.add_argument("--multi-modal", action="store_true", help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
+ parser.add_argument(
+ "--multi-modal",
+ action="store_true",
+ help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.",
+ )
parser.add_argument("--keep-aspect-ratio", action="store_true")
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument("--limit-model-concurrency", type=int, default=5)
@@ -371,14 +455,18 @@ if __name__ == "__main__":
logger.info(f"args: {args}")
if args.multi_modal:
- logger.warning("Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.")
-
- worker = ModelWorker(args.controller_address,
- args.worker_address,
- worker_id,
- args.no_register,
- args.model_path,
- args.model_name,
- args.keep_aspect_ratio,
- args.num_gpus)
+ logger.warning(
+ "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."
+ )
+
+ worker = ModelWorker(
+ args.controller_address,
+ args.worker_address,
+ worker_id,
+ args.no_register,
+ args.model_path,
+ args.model_name,
+ args.keep_aspect_ratio,
+ args.num_gpus,
+ )
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/model/llava/serve/test_message.py b/model/llava/serve/test_message.py
index 6b090faed0e630b03b2294545050f1f4f5032cad..3d5c2576d943624db791332bda8427ef6f70778e 100644
--- a/model/llava/serve/test_message.py
+++ b/model/llava/serve/test_message.py
@@ -2,7 +2,6 @@ import argparse
import json
import requests
-
from llava.conversation import default_conversation
@@ -17,8 +16,9 @@ def main():
models.sort()
print(f"Models: {models}")
- ret = requests.post(controller_addr + "/get_worker_address",
- json={"model": args.model_name})
+ ret = requests.post(
+ controller_addr + "/get_worker_address", json={"model": args.model_name}
+ )
worker_addr = ret.json()["address"]
print(f"worker_addr: {worker_addr}")
@@ -37,11 +37,17 @@ def main():
"temperature": 0.7,
"stop": conv.sep,
}
- response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
- json=pload, stream=True)
+ response = requests.post(
+ worker_addr + "/worker_generate_stream",
+ headers=headers,
+ json=pload,
+ stream=True,
+ )
print(prompt.replace(conv.sep, "\n"), end="")
- for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
+ for chunk in response.iter_lines(
+ chunk_size=8192, decode_unicode=False, delimiter=b"\0"
+ ):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["text"].split(conv.sep)[-1]
@@ -51,12 +57,15 @@ def main():
if __name__ == "__main__":
parser = argparse.ArgumentParser()
- parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
+ parser.add_argument(
+ "--controller-address", type=str, default="http://localhost:21001"
+ )
parser.add_argument("--worker-address", type=str)
parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
parser.add_argument("--max-new-tokens", type=int, default=32)
- parser.add_argument("--message", type=str, default=
- "Tell me a story with more than 1000 words.")
+ parser.add_argument(
+ "--message", type=str, default="Tell me a story with more than 1000 words."
+ )
args = parser.parse_args()
main()
diff --git a/model/llava/train/llama_flash_attn_monkey_patch.py b/model/llava/train/llama_flash_attn_monkey_patch.py
index 89f9c3b56fce9b6c8c8be334772686a15c9454d4..66f1f7ab8a2b286f44327d7759ca6b082c4a9d9a 100644
--- a/model/llava/train/llama_flash_attn_monkey_patch.py
+++ b/model/llava/train/llama_flash_attn_monkey_patch.py
@@ -2,15 +2,13 @@
from typing import List, Optional, Tuple
import torch
-from torch import nn
-
import transformers
-from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
-
from einops import rearrange
-
+from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
-from flash_attn.bert_padding import unpad_input, pad_input
+from torch import nn
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
def forward(
self,
@@ -19,20 +17,28 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor],
- Optional[Tuple[torch.Tensor]]]:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
-
+
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
- query_states = self.q_proj(hidden_states).view(
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(
- bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states = (
+ self.q_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ key_states = (
+ self.k_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ value_states = (
+ self.v_proj(hidden_states)
+ .view(bsz, q_len, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ )
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
@@ -42,11 +48,9 @@ def forward(
offset = past_key_value[0].shape[-2]
kv_seq_len += offset
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(query_states,
- key_states,
- cos,
- sin,
- offset=offset)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, offset=offset
+ )
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
@@ -56,47 +60,55 @@ def forward(
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
- qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
- qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
+ qkv = torch.stack(
+ [query_states, key_states, value_states], dim=2
+ ) # [bsz, nh, 3, q_len, hd]
+ qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
-
if key_padding_mask is None:
- qkv = rearrange(qkv, 'b s ... -> (b s) ...')
+ qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
- cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
- device=qkv.device)
+ cu_q_lens = torch.arange(
+ 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
+ )
output = flash_attn_unpadded_qkvpacked_func(
- qkv, cu_q_lens, max_s, 0.0,
- softmax_scale=None, causal=True
+ qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
- output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
+ output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
- x = rearrange(qkv, 'b s three h d -> b s (three h d)')
+ x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
- x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
+ x_unpad = rearrange(
+ x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
+ )
output_unpad = flash_attn_unpadded_qkvpacked_func(
- x_unpad, cu_q_lens, max_s, 0.0,
- softmax_scale=None, causal=True
+ x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
+ )
+ output = rearrange(
+ pad_input(
+ rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
+ ),
+ "b s (h d) -> b s h d",
+ h=nheads,
)
- output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
- indices, bsz, q_len),
- 'b s (h d) -> b s h d', h=nheads)
- return self.o_proj(rearrange(output,
- 'b s h d -> b s (h d)')), None, None
+ return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
-def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
- inputs_embeds, past_key_values_length):
+def _prepare_decoder_attention_mask(
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
- transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
+ transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
+ _prepare_decoder_attention_mask
+ )
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
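A minimal usage sketch, assuming flash-attn is installed and using a placeholder checkpoint path: the patch has to run before the LLaMA model is instantiated so that LlamaAttention.forward already points at the flash-attention version (mirroring train_mem.py further below).

from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()  # swaps in the flash-attention forward and the mask passthrough

import transformers  # imported after patching, as train_mem.py does

model = transformers.LlamaForCausalLM.from_pretrained("path/to/llama-checkpoint")  # placeholder path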
diff --git a/model/llava/train/llava_trainer.py b/model/llava/train/llava_trainer.py
index 2824f25e92d8893103ad8f32848b749167c5630d..864ad016a65c83b22756be67a43bf6a24a5e50cd 100644
--- a/model/llava/train/llava_trainer.py
+++ b/model/llava/train/llava_trainer.py
@@ -1,9 +1,9 @@
import os
+from typing import Dict, Optional, Sequence
+
import torch
import torch.nn as nn
-
from transformers import Trainer
-from typing import Dict, Optional, Sequence
def unwrap_model(model: nn.Module) -> nn.Module:
@@ -21,9 +21,8 @@ def unwrap_model(model: nn.Module) -> nn.Module:
class LLaVATrainer(Trainer):
-
def _save(self, output_dir: Optional[str] = None, state_dict=None):
- if getattr(self.args, 'tune_mm_mlp_adapter', False):
+ if getattr(self.args, "tune_mm_mlp_adapter", False):
# Save the model
_state_dict = state_dict
if _state_dict is None:
@@ -32,18 +31,23 @@ class LLaVATrainer(Trainer):
_state_dict = model_to_save.state_dict()
weight_to_save = {}
- keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
+ keys_to_match = ["mm_projector", "embed_tokens", "embed_in"]
for k, v in _state_dict.items():
if any(key_match in k for key_match in keys_to_match):
weight_to_save[k] = v
- current_folder = output_dir.split('/')[-1]
+ current_folder = output_dir.split("/")[-1]
parent_folder = os.path.dirname(output_dir)
- if current_folder.startswith('checkpoint-'):
+ if current_folder.startswith("checkpoint-"):
mm_projector_folder = os.path.join(parent_folder, "mm_projector")
os.makedirs(mm_projector_folder, exist_ok=True)
- torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
+ torch.save(
+ weight_to_save,
+ os.path.join(mm_projector_folder, f"{current_folder}.bin"),
+ )
else:
- torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
+ torch.save(
+ weight_to_save, os.path.join(output_dir, f"mm_projector.bin")
+ )
super(LLaVATrainer, self)._save(output_dir, state_dict)
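The key filtering performed by _save when tune_mm_mlp_adapter is enabled boils down to the dictionary comprehension below; a toy sketch with illustrative key names and shapes:

import torch

_state_dict = {
    "model.mm_projector.weight": torch.zeros(2, 2),               # kept
    "model.layers.0.self_attn.q_proj.weight": torch.zeros(2, 2),  # dropped
}
keys_to_match = ["mm_projector", "embed_tokens", "embed_in"]
weight_to_save = {
    k: v for k, v in _state_dict.items() if any(m in k for m in keys_to_match)
}
print(sorted(weight_to_save))  # ['model.mm_projector.weight']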
diff --git a/model/llava/train/train.py b/model/llava/train/train.py
index 49f7a0d5e33c7c082aa7e9857344a29a203dce13..f76872b3acd023cb5e69e32c239bd0ae5503443d 100644
--- a/model/llava/train/train.py
+++ b/model/llava/train/train.py
@@ -14,25 +14,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import copy
-from dataclasses import dataclass, field
import json
import logging
+import os
import pathlib
+from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import torch
-
+import torch.nn as nn
import transformers
-from torch.utils.data import Dataset
-from llava.train.llava_trainer import LLaVATrainer
-
from llava import conversation as conversation_lib
from llava.model import *
-
+from llava.train.llava_trainer import LLaVATrainer
from PIL import Image
-import torch.nn as nn
+from torch.utils.data import Dataset
# TODO: import and use code from ../data/dataset.py
@@ -54,21 +51,24 @@ class ModelArguments:
freeze_backbone: bool = field(default=False)
tune_mm_mlp_adapter: bool = field(default=False)
vision_tower: Optional[str] = field(default=None)
- mm_vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
+ mm_vision_select_layer: Optional[int] = field(
+ default=-1
+ ) # default to the last layer
pretrain_mm_mlp_adapter: Optional[str] = field(default=None)
mm_use_im_start_end: bool = field(default=False)
@dataclass
class DataArguments:
- data_path: str = field(default=None,
- metadata={"help": "Path to the training data."})
+ data_path: str = field(
+ default=None, metadata={"help": "Path to the training data."}
+ )
lazy_preprocess: bool = False
is_multimodal: bool = False
sep_image_conv_front: bool = False
image_token_len: int = 0
image_folder: Optional[str] = field(default=None)
- image_aspect_ratio: str = 'square'
+ image_aspect_ratio: str = "square"
@dataclass
@@ -81,21 +81,16 @@ class TrainingArguments(transformers.TrainingArguments):
model_max_length: int = field(
default=512,
metadata={
- "help":
- "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
+ "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
-def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
- output_dir: str):
+def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
- cpu_state_dict = {
- key: value.cpu()
- for key, value in state_dict.items()
- }
+ cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
@@ -117,16 +112,19 @@ def smart_tokenizer_and_embedding_resize(
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
+ dim=0, keepdim=True
+ )
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
- dim=0, keepdim=True)
+ dim=0, keepdim=True
+ )
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
-def _tokenize_fn(strings: Sequence[str],
- tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+def _tokenize_fn(
+ strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
+) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
@@ -135,11 +133,10 @@ def _tokenize_fn(strings: Sequence[str],
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
- ) for text in strings
- ]
- input_ids = labels = [
- tokenized.input_ids[0] for tokenized in tokenized_list
+ )
+ for text in strings
]
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
for tokenized in tokenized_list
@@ -159,7 +156,7 @@ def _mask_targets(target, tokenized_lens, speakers):
target[:cur_idx] = IGNORE_INDEX
for tokenized_len, speaker in zip(tokenized_lens, speakers):
if speaker == "human":
- target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
cur_idx += tokenized_len
@@ -175,9 +172,10 @@ def _add_speaker_and_signal(header, source, get_conversation=True):
elif from_str.lower() == "gpt":
from_str = conversation_lib.default_conversation.roles[1]
else:
- from_str = 'unknown'
- sentence["value"] = (BEGIN_SIGNAL + from_str + ": " +
- sentence["value"] + END_SIGNAL)
+ from_str = "unknown"
+ sentence["value"] = (
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
+ )
if get_conversation:
conversation += sentence["value"]
conversation += BEGIN_SIGNAL
@@ -189,22 +187,34 @@ def preprocess_multimodal(
multimodal_cfg: dict,
cur_token_len: int,
) -> Dict:
- is_multimodal = multimodal_cfg['is_multimodal']
+ is_multimodal = multimodal_cfg["is_multimodal"]
# image_token_len = multimodal_cfg['image_token_len']
image_token_len = cur_token_len
if not is_multimodal:
return sources
for source in sources:
- if multimodal_cfg['sep_image_conv_front']:
- assert DEFAULT_IMAGE_TOKEN in source[0]['value']
- source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
- source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
+ if multimodal_cfg["sep_image_conv_front"]:
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
+ source[0]["value"] = (
+ source[0]["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
+ )
+ source[0]["value"] = (
+ DEFAULT_IMAGE_TOKEN
+ + conversation_lib.default_conversation.sep
+ + conversation_lib.default_conversation.roles[0]
+ + ": "
+ + source[0]["value"]
+ )
for sentence in source:
replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
- if multimodal_cfg['use_im_start_end']:
- replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
- sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
+ if multimodal_cfg["use_im_start_end"]:
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ sentence["value"] = sentence["value"].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
return sources
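To make the replacement concrete, a standalone sketch of the use_im_start_end branch; the token strings are the assumed values of the LLaVA constants, and cur_token_len is kept tiny for display (the real value is derived from the image size):

DEFAULT_IMAGE_TOKEN = "<image>"          # assumed constant values
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"

cur_token_len = 3
replace_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * cur_token_len + DEFAULT_IM_END_TOKEN
value = DEFAULT_IMAGE_TOKEN + "\nWhat is shown in this image?"
print(value.replace(DEFAULT_IMAGE_TOKEN, replace_token))
# <im_start><im_patch><im_patch><im_patch><im_end>
# What is shown in this image?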
@@ -279,6 +289,7 @@ def preprocess_v1(
labels=targets,
)
+
def preprocess_mpt(
sources,
tokenizer: transformers.PreTrainedTokenizer,
@@ -317,9 +328,11 @@ def preprocess_mpt(
total_len = int(target.ne(tokenizer.pad_token_id).sum())
rounds = conversation.split(conv.sep)
- re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
+ re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
for conv_idx in range(3, len(rounds), 2):
- re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
+ re_rounds.append(
+ conv.sep.join(rounds[conv_idx : conv_idx + 2])
+ ) # user + gpt
cur_len = 0
target[:cur_len] = IGNORE_INDEX
for i, rou in enumerate(re_rounds):
@@ -330,7 +343,9 @@ def preprocess_mpt(
if len(parts) != 2:
break
parts[0] += sep
- round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
+ round_len = len(tokenizer(rou).input_ids) + len(
+ tokenizer(conv.sep).input_ids
+ )
instruction_len = len(tokenizer(parts[0]).input_ids)
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
@@ -377,8 +392,9 @@ def preprocess(
input_ids = conversations_tokenized["input_ids"]
targets = copy.deepcopy(input_ids)
for target, source in zip(targets, sources):
- tokenized_lens = _tokenize_fn([header] + [s["value"] for s in source],
- tokenizer)["input_ids_lens"]
+ tokenized_lens = _tokenize_fn(
+ [header] + [s["value"] for s in source], tokenizer
+ )["input_ids_lens"]
speakers = [sentence["from"] for sentence in source]
_mask_targets(target, tokenized_lens, speakers)
@@ -388,8 +404,7 @@ def preprocess(
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self, data_path: str,
- tokenizer: transformers.PreTrainedTokenizer):
+ def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
list_data_dict = json.load(open(data_path, "r"))
@@ -411,9 +426,12 @@ class SupervisedDataset(Dataset):
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
- def __init__(self, data_path: str,
- tokenizer: transformers.PreTrainedTokenizer,
- multimodal_cfg: dict):
+ def __init__(
+ self,
+ data_path: str,
+ tokenizer: transformers.PreTrainedTokenizer,
+ multimodal_cfg: dict,
+ ):
super(LazySupervisedDataset, self).__init__()
logging.warning("Loading data...")
list_data_dict = json.load(open(data_path, "r"))
@@ -431,54 +449,74 @@ class LazySupervisedDataset(Dataset):
if isinstance(i, int):
sources = [sources]
assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
- if 'image' in sources[0]:
- image_file = self.list_data_dict[i]['image']
- image_folder = self.multimodal_cfg['image_folder']
- processor = self.multimodal_cfg['image_processor']
- image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
- if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
+ if "image" in sources[0]:
+ image_file = self.list_data_dict[i]["image"]
+ image_folder = self.multimodal_cfg["image_folder"]
+ processor = self.multimodal_cfg["image_processor"]
+ image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+ if self.multimodal_cfg["image_aspect_ratio"] == "keep":
max_hw, min_hw = max(image.size), min(image.size)
aspect_ratio = max_hw / min_hw
max_len, min_len = 448, 224
shortest_edge = int(min(max_len / aspect_ratio, min_len))
- image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
- elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
+ image = processor.preprocess(
+ image,
+ return_tensors="pt",
+ do_center_crop=False,
+ size={"shortest_edge": shortest_edge},
+ )["pixel_values"][0]
+ elif self.multimodal_cfg["image_aspect_ratio"] == "pad":
+
def expand2square(pil_img, background_color):
width, height = pil_img.size
if width == height:
return pil_img
elif width > height:
- result = Image.new(pil_img.mode, (width, width), background_color)
+ result = Image.new(
+ pil_img.mode, (width, width), background_color
+ )
result.paste(pil_img, (0, (width - height) // 2))
return result
else:
- result = Image.new(pil_img.mode, (height, height), background_color)
+ result = Image.new(
+ pil_img.mode, (height, height), background_color
+ )
result.paste(pil_img, ((height - width) // 2, 0))
return result
- image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+
+ image = expand2square(
+ image, tuple(int(x * 255) for x in processor.image_mean)
+ )
+ image = processor.preprocess(image, return_tensors="pt")[
+ "pixel_values"
+ ][0]
else:
- image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
- cur_token_len = (image.shape[1]//14) * (image.shape[2]//14) # FIXME: 14 is hardcoded patch size
+ image = processor.preprocess(image, return_tensors="pt")[
+ "pixel_values"
+ ][0]
+ cur_token_len = (image.shape[1] // 14) * (
+ image.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
sources = preprocess_multimodal(
copy.deepcopy([e["conversations"] for e in sources]),
- self.multimodal_cfg, cur_token_len)
+ self.multimodal_cfg,
+ cur_token_len,
+ )
else:
sources = copy.deepcopy([e["conversations"] for e in sources])
- data_dict = preprocess(
- sources,
- self.tokenizer)
+ data_dict = preprocess(sources, self.tokenizer)
if isinstance(i, int):
- data_dict = dict(input_ids=data_dict["input_ids"][0],
- labels=data_dict["labels"][0])
+ data_dict = dict(
+ input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
+ )
# image exist in the data
- if 'image' in self.list_data_dict[i]:
- data_dict['image'] = image
- elif self.multimodal_cfg['is_multimodal']:
+ if "image" in self.list_data_dict[i]:
+ data_dict["image"] = image
+ elif self.multimodal_cfg["is_multimodal"]:
# image does not exist in the data, but the model is multimodal
- crop_size = self.multimodal_cfg['image_processor'].crop_size
- data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
+ crop_size = self.multimodal_cfg["image_processor"].crop_size
+ data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
return data_dict
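A quick check of the patch-count arithmetic flagged by the FIXME, assuming the default 224x224 CLIP ViT-L/14 vision tower:

height, width, patch_size = 224, 224, 14  # 14 is the hardcoded patch size noted above
cur_token_len = (height // patch_size) * (width // patch_size)
print(cur_token_len)  # 256 image patch tokens per image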
@@ -489,59 +527,65 @@ class DataCollatorForSupervisedDataset(object):
tokenizer: transformers.PreTrainedTokenizer
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
- input_ids, labels = tuple([instance[key] for instance in instances]
- for key in ("input_ids", "labels"))
+ input_ids, labels = tuple(
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
+ )
input_ids = torch.nn.utils.rnn.pad_sequence(
- input_ids,
- batch_first=True,
- padding_value=self.tokenizer.pad_token_id)
- labels = torch.nn.utils.rnn.pad_sequence(labels,
- batch_first=True,
- padding_value=IGNORE_INDEX)
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(
+ labels, batch_first=True, padding_value=IGNORE_INDEX
+ )
batch = dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
- if 'image' in instances[0]:
- images = [instance['image'] for instance in instances]
+ if "image" in instances[0]:
+ images = [instance["image"] for instance in instances]
if all(x is not None and x.shape == images[0].shape for x in images):
- batch['images'] = torch.stack(images)
+ batch["images"] = torch.stack(images)
else:
- batch['images'] = images
+ batch["images"] = images
return batch
-def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
- data_args) -> Dict:
+def make_supervised_data_module(
+ tokenizer: transformers.PreTrainedTokenizer, data_args
+) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
- dataset_cls = (LazySupervisedDataset
- if data_args.lazy_preprocess else SupervisedDataset)
- train_dataset = dataset_cls(tokenizer=tokenizer,
- data_path=data_args.data_path,
- multimodal_cfg=dict(
- is_multimodal=data_args.is_multimodal,
- sep_image_conv_front=data_args.sep_image_conv_front,
- image_token_len=data_args.image_token_len,
- image_folder=data_args.image_folder,
- image_aspect_ratio=data_args.image_aspect_ratio,
- use_im_start_end=getattr(data_args, 'mm_use_im_start_end', False),
- image_processor=getattr(data_args, 'image_processor', None)))
+ dataset_cls = (
+ LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset
+ )
+ train_dataset = dataset_cls(
+ tokenizer=tokenizer,
+ data_path=data_args.data_path,
+ multimodal_cfg=dict(
+ is_multimodal=data_args.is_multimodal,
+ sep_image_conv_front=data_args.sep_image_conv_front,
+ image_token_len=data_args.image_token_len,
+ image_folder=data_args.image_folder,
+ image_aspect_ratio=data_args.image_aspect_ratio,
+ use_im_start_end=getattr(data_args, "mm_use_im_start_end", False),
+ image_processor=getattr(data_args, "image_processor", None),
+ ),
+ )
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
- return dict(train_dataset=train_dataset,
- eval_dataset=None,
- data_collator=data_collator)
+ return dict(
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
+ )
def train():
parser = transformers.HfArgumentParser(
- (ModelArguments, DataArguments, TrainingArguments))
+ (ModelArguments, DataArguments, TrainingArguments)
+ )
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if model_args.vision_tower is not None:
- if 'mpt' in model_args.model_name_or_path:
+ if "mpt" in model_args.model_name_or_path:
model = LlavaMPTForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
@@ -561,12 +605,12 @@ def train():
if model_args.freeze_backbone:
model.model.requires_grad_(False)
- if 'mpt' in model_args.model_name_or_path:
+ if "mpt" in model_args.model_name_or_path:
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
- padding_side="right"
+ padding_side="right",
)
else:
tokenizer = transformers.AutoTokenizer.from_pretrained(
@@ -585,23 +629,29 @@ def train():
model=model,
)
if "llama" in model_args.model_name_or_path:
- tokenizer.add_special_tokens({
- "eos_token": DEFAULT_EOS_TOKEN,
- "bos_token": DEFAULT_BOS_TOKEN,
- "unk_token": DEFAULT_UNK_TOKEN,
- })
+ tokenizer.add_special_tokens(
+ {
+ "eos_token": DEFAULT_EOS_TOKEN,
+ "bos_token": DEFAULT_BOS_TOKEN,
+ "unk_token": DEFAULT_UNK_TOKEN,
+ }
+ )
else:
tokenizer.pad_token = tokenizer.unk_token
if "mpt" in model_args.model_name_or_path:
- conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
+ "mpt"
+ ]
else:
- conversation_lib.default_conversation = conversation_lib.conv_templates["vicuna_v1_1"]
+ conversation_lib.default_conversation = conversation_lib.conv_templates[
+ "vicuna_v1_1"
+ ]
if model_args.vision_tower is not None:
model_vision_dict = model.get_model().initialize_vision_modules(
vision_tower=model_args.vision_tower,
mm_vision_select_layer=model_args.mm_vision_select_layer,
- pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter
+ pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
)
dtype = torch.float32
if training_args.fp16:
@@ -609,13 +659,15 @@ def train():
if training_args.bf16:
dtype = torch.bfloat16
model.get_model().vision_tower[0].to(dtype=dtype, device=training_args.device)
- vision_config = model_vision_dict['vision_config']
+ vision_config = model_vision_dict["vision_config"]
- data_args.image_token_len = model_vision_dict['image_token_len']
- data_args.image_processor = model_vision_dict['image_processor']
+ data_args.image_token_len = model_vision_dict["image_token_len"]
+ data_args.image_processor = model_vision_dict["image_processor"]
data_args.is_multimodal = True
- model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter
+ model.config.tune_mm_mlp_adapter = (
+ training_args.tune_mm_mlp_adapter
+ ) = model_args.tune_mm_mlp_adapter
if model_args.tune_mm_mlp_adapter:
model.requires_grad_(False)
for p in model.get_model().mm_projector.parameters():
@@ -626,45 +678,66 @@ def train():
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
- model.config.mm_use_im_start_end = data_args.mm_use_im_start_end = model_args.mm_use_im_start_end
- vision_config.use_im_start_end = training_args.use_im_start_end = model_args.mm_use_im_start_end
+ model.config.mm_use_im_start_end = (
+ data_args.mm_use_im_start_end
+ ) = model_args.mm_use_im_start_end
+ vision_config.use_im_start_end = (
+ training_args.use_im_start_end
+ ) = model_args.mm_use_im_start_end
model.config.sep_image_conv_front = data_args.sep_image_conv_front
- model.initialize_vision_tokenizer(mm_use_im_start_end=model_args.mm_use_im_start_end, tokenizer=tokenizer, device=training_args.device,
- tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter, pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter)
+ model.initialize_vision_tokenizer(
+ mm_use_im_start_end=model_args.mm_use_im_start_end,
+ tokenizer=tokenizer,
+ device=training_args.device,
+ tune_mm_mlp_adapter=model_args.tune_mm_mlp_adapter,
+ pretrain_mm_mlp_adapter=model_args.pretrain_mm_mlp_adapter,
+ )
params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
if len(params_no_grad) > 0:
if training_args.fsdp is not None and len(training_args.fsdp) > 0:
if len(params_no_grad) < 10:
- print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
+ print(
+ "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}".format(
+ len(params_no_grad), params_no_grad
+ )
+ )
else:
- print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
- print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
- print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
+ print(
+ "[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)".format(
+ len(params_no_grad), ", ".join(params_no_grad[:10])
+ )
+ )
+ print(
+ "[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental."
+ )
+ print(
+ "[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining"
+ )
+
+ from torch.distributed.fsdp.fully_sharded_data_parallel import \
+ FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
def patch_FSDP_use_orig_params(func):
def wrap_func(*args, **kwargs):
- use_orig_params = kwargs.pop('use_orig_params', True)
+ use_orig_params = kwargs.pop("use_orig_params", True)
return func(*args, **kwargs, use_orig_params=use_orig_params)
+
return wrap_func
FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
- data_module = make_supervised_data_module(tokenizer=tokenizer,
- data_args=data_args)
- trainer = LLaVATrainer(model=model,
- tokenizer=tokenizer,
- args=training_args,
- **data_module)
+ data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
+ trainer = LLaVATrainer(
+ model=model, tokenizer=tokenizer, args=training_args, **data_module
+ )
if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
trainer.save_state()
- safe_save_model_for_hf_trainer(trainer=trainer,
- output_dir=training_args.output_dir)
+ safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
if __name__ == "__main__":
diff --git a/model/llava/train/train_mem.py b/model/llava/train/train_mem.py
index 2487d317855b27d5b07a755ee0389667e4964f02..f3940cf7fea248d055a9cb333a08ebca0f782885 100644
--- a/model/llava/train/train_mem.py
+++ b/model/llava/train/train_mem.py
@@ -3,7 +3,8 @@
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
-from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
+from llava.train.llama_flash_attn_monkey_patch import \
+ replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
diff --git a/model/llava/utils.py b/model/llava/utils.py
index 8f7163c0ba1d9a81d81a950bce61e0f0db06066e..0a2d5fd533ded77352f5548a0ed027b700365ea4 100644
--- a/model/llava/utils.py
+++ b/model/llava/utils.py
@@ -5,11 +5,14 @@ import os
import sys
import requests
-
from llava.constants import LOGDIR
-server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
-moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+server_error_msg = (
+ "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
+)
+moderation_msg = (
+ "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
+)
handler = None
@@ -47,7 +50,8 @@ def build_logger(logger_name, logger_filename):
os.makedirs(LOGDIR, exist_ok=True)
filename = os.path.join(LOGDIR, logger_filename)
handler = logging.handlers.TimedRotatingFileHandler(
- filename, when='D', utc=True)
+ filename, when="D", utc=True
+ )
handler.setFormatter(formatter)
for name, item in logging.root.manager.loggerDict.items():
@@ -61,33 +65,34 @@ class StreamToLogger(object):
"""
Fake file-like stream object that redirects writes to a logger instance.
"""
+
def __init__(self, logger, log_level=logging.INFO):
self.terminal = sys.stdout
self.logger = logger
self.log_level = log_level
- self.linebuf = ''
+ self.linebuf = ""
def __getattr__(self, attr):
return getattr(self.terminal, attr)
def write(self, buf):
temp_linebuf = self.linebuf + buf
- self.linebuf = ''
+ self.linebuf = ""
for line in temp_linebuf.splitlines(True):
# From the io.TextIOWrapper docs:
# On output, if newline is None, any '\n' characters written
# are translated to the system default line separator.
# By default sys.stdout.write() expects '\n' newlines and then
# translates them so this is still cross platform.
- if line[-1] == '\n':
+ if line[-1] == "\n":
self.logger.log(self.log_level, line.rstrip())
else:
self.linebuf += line
def flush(self):
- if self.linebuf != '':
+ if self.linebuf != "":
self.logger.log(self.log_level, self.linebuf.rstrip())
- self.linebuf = ''
+ self.linebuf = ""
def disable_torch_init():
@@ -95,6 +100,7 @@ def disable_torch_init():
Disable the redundant torch default initialization to accelerate model creation.
"""
import torch
+
setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
@@ -104,8 +110,10 @@ def violates_moderation(text):
Check whether the text violates OpenAI moderation API.
"""
url = "https://api.openai.com/v1/moderations"
- headers = {"Content-Type": "application/json",
- "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
+ headers = {
+ "Content-Type": "application/json",
+ "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"],
+ }
text = text.replace("\n", "")
data = "{" + '"input": ' + f'"{text}"' + "}"
data = data.encode("utf-8")
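A small sketch of how StreamToLogger is meant to be used (hypothetical logger name; the real build_logger may already perform this redirection internally):

import logging
import sys

from llava.utils import StreamToLogger

logger = logging.getLogger("demo")
logger.addHandler(logging.StreamHandler(sys.stderr))
logger.setLevel(logging.INFO)

sys.stdout = StreamToLogger(logger, logging.INFO)
print("hello")  # buffered by write() and emitted through logger.log on the trailing newline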
diff --git a/model/segment_anything/__init__.py b/model/segment_anything/__init__.py
index 34383d83f5e76bc801f31b20e5651e383be348b6..e66218b2edd8754f1546ad1dca8b604ce891c365 100755
--- a/model/segment_anything/__init__.py
+++ b/model/segment_anything/__init__.py
@@ -4,12 +4,7 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-from .build_sam import (
- build_sam,
- build_sam_vit_h,
- build_sam_vit_l,
- build_sam_vit_b,
- sam_model_registry,
-)
-from .predictor import SamPredictor
from .automatic_mask_generator import SamAutomaticMaskGenerator
+from .build_sam import (build_sam, build_sam_vit_b, build_sam_vit_h,
+ build_sam_vit_l, sam_model_registry)
+from .predictor import SamPredictor
diff --git a/model/segment_anything/automatic_mask_generator.py b/model/segment_anything/automatic_mask_generator.py
index d5a8c969207f119feff7087f94e044403acdff00..aa4bc4f0324cf7f91ded55a0993b51deeec41537 100755
--- a/model/segment_anything/automatic_mask_generator.py
+++ b/model/segment_anything/automatic_mask_generator.py
@@ -4,32 +4,21 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Any, Dict, List, Optional, Tuple
+
import numpy as np
import torch
from torchvision.ops.boxes import batched_nms, box_area # type: ignore
-from typing import Any, Dict, List, Optional, Tuple
-
from .modeling import Sam
from .predictor import SamPredictor
-from .utils.amg import (
- MaskData,
- area_from_rle,
- batch_iterator,
- batched_mask_to_box,
- box_xyxy_to_xywh,
- build_all_layer_point_grids,
- calculate_stability_score,
- coco_encode_rle,
- generate_crop_boxes,
- is_box_near_crop_edge,
- mask_to_rle_pytorch,
- remove_small_regions,
- rle_to_mask,
- uncrop_boxes_xyxy,
- uncrop_masks,
- uncrop_points,
-)
+from .utils.amg import (MaskData, area_from_rle, batch_iterator,
+ batched_mask_to_box, box_xyxy_to_xywh,
+ build_all_layer_point_grids, calculate_stability_score,
+ coco_encode_rle, generate_crop_boxes,
+ is_box_near_crop_edge, mask_to_rle_pytorch,
+ remove_small_regions, rle_to_mask, uncrop_boxes_xyxy,
+ uncrop_masks, uncrop_points)
class SamAutomaticMaskGenerator:
@@ -115,7 +104,8 @@ class SamAutomaticMaskGenerator:
"coco_rle",
], f"Unknown output_mode {output_mode}."
if output_mode == "coco_rle":
- from pycocotools import mask as mask_utils # type: ignore # noqa: F401
+ from pycocotools import \
+ mask as mask_utils # type: ignore # noqa: F401
if min_mask_region_area > 0:
import cv2 # type: ignore # noqa: F401
@@ -172,7 +162,9 @@ class SamAutomaticMaskGenerator:
# Encode masks
if self.output_mode == "coco_rle":
- mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
+ mask_data["segmentations"] = [
+ coco_encode_rle(rle) for rle in mask_data["rles"]
+ ]
elif self.output_mode == "binary_mask":
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
else:
@@ -242,7 +234,9 @@ class SamAutomaticMaskGenerator:
# Generate masks for this crop in batches
data = MaskData()
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
- batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
+ batch_data = self._process_batch(
+ points, cropped_im_size, crop_box, orig_size
+ )
data.cat(batch_data)
del batch_data
self.predictor.reset_image()
@@ -275,7 +269,9 @@ class SamAutomaticMaskGenerator:
# Run model on this batch
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
- in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
+ in_labels = torch.ones(
+ in_points.shape[0], dtype=torch.int, device=in_points.device
+ )
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
@@ -298,7 +294,9 @@ class SamAutomaticMaskGenerator:
# Calculate stability score
data["stability_score"] = calculate_stability_score(
- data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
+ data["masks"],
+ self.predictor.model.mask_threshold,
+ self.stability_score_offset,
)
if self.stability_score_thresh > 0.0:
keep_mask = data["stability_score"] >= self.stability_score_thresh
@@ -309,7 +307,9 @@ class SamAutomaticMaskGenerator:
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
- keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
+ keep_mask = ~is_box_near_crop_edge(
+ data["boxes"], crop_box, [0, 0, orig_w, orig_h]
+ )
if not torch.all(keep_mask):
data.filter(keep_mask)
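End to end, the automatic generator is driven roughly as follows; the checkpoint filename and image path are assumptions, and OpenCV is used only to load an RGB array:

import cv2

from model.segment_anything import SamAutomaticMaskGenerator, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")  # assumed local checkpoint
mask_generator = SamAutomaticMaskGenerator(sam, output_mode="binary_mask")

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)    # HxWx3 uint8, RGB
masks = mask_generator.generate(image)
print(len(masks), sorted(masks[0]))  # list of dicts with 'area', 'bbox', 'segmentation', ...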
diff --git a/model/segment_anything/build_sam.py b/model/segment_anything/build_sam.py
index 2f85cebcc30a0c410453ee257c1f1b7091872b0a..788d25ad5a6fd32c112201301b320f5884d6e8e8 100755
--- a/model/segment_anything/build_sam.py
+++ b/model/segment_anything/build_sam.py
@@ -4,11 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import torch
-
from functools import partial
-from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
+import torch
+
+from .modeling import (ImageEncoderViT, MaskDecoder, PromptEncoder, Sam,
+ TwoWayTransformer)
def build_sam_vit_h(checkpoint=None):
diff --git a/model/segment_anything/modeling/__init__.py b/model/segment_anything/modeling/__init__.py
index 38e906243d898d7fc071c0fe218338c5cace3ea1..088af386e5b45d14e99d11dec132821ddba5df39 100755
--- a/model/segment_anything/modeling/__init__.py
+++ b/model/segment_anything/modeling/__init__.py
@@ -4,8 +4,8 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-from .sam import Sam
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
+from .sam import Sam
from .transformer import TwoWayTransformer
diff --git a/model/segment_anything/modeling/common.py b/model/segment_anything/modeling/common.py
index 2bf15236a3eb24d8526073bc4fa2b274cccb3f96..e8727816d4861a2d0c7c367879951d1d4fa791fb 100755
--- a/model/segment_anything/modeling/common.py
+++ b/model/segment_anything/modeling/common.py
@@ -4,11 +4,11 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Type
+
import torch
import torch.nn as nn
-from typing import Type
-
class MLPBlock(nn.Module):
def __init__(
diff --git a/model/segment_anything/modeling/image_encoder.py b/model/segment_anything/modeling/image_encoder.py
index b34026cc61c09550549a6f3e6d932a1e19e308c6..b472a3d6b7a609134afe18d7f8740e0c01a56842 100755
--- a/model/segment_anything/modeling/image_encoder.py
+++ b/model/segment_anything/modeling/image_encoder.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Optional, Tuple, Type
+
import torch
import torch.nn as nn
import torch.nn.functional as F
-from typing import Optional, Tuple, Type
-
from .common import LayerNorm2d, MLPBlock
@@ -68,7 +68,9 @@ class ImageEncoderViT(nn.Module):
if use_abs_pos:
# Initialize absolute positional embedding with pretrain image size.
self.pos_embed = nn.Parameter(
- torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
+ torch.zeros(
+ 1, img_size // patch_size, img_size // patch_size, embed_dim
+ )
)
self.blocks = nn.ModuleList()
@@ -106,7 +108,6 @@ class ImageEncoderViT(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
-
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.pos_embed
@@ -115,8 +116,8 @@ class ImageEncoderViT(nn.Module):
x = blk(x)
dtype = x.dtype
- if dtype == torch.float16: # prevent overflow
- with torch.autocast(device_type='cuda', dtype=torch.float32):
+ if dtype == torch.float16: # prevent overflow
+ with torch.autocast(device_type="cuda", dtype=torch.float32):
x = self.neck(x.permute(0, 3, 1, 2))
x = x.to(dtype)
else:
@@ -167,7 +168,9 @@ class Block(nn.Module):
)
self.norm2 = norm_layer(dim)
- self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
+ self.mlp = MLPBlock(
+ embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer
+ )
self.window_size = window_size
@@ -232,23 +235,34 @@ class Attention(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C)
- qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ qkv = (
+ self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ )
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
- attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
+ attn = add_decomposed_rel_pos(
+ attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)
+ )
attn = attn.softmax(dim=-1)
- x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
+ x = (
+ (attn @ v)
+ .view(B, self.num_heads, H, W, -1)
+ .permute(0, 2, 3, 1, 4)
+ .reshape(B, H, W, -1)
+ )
x = self.proj(x)
return x
-def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
+def window_partition(
+ x: torch.Tensor, window_size: int
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""
Partition into non-overlapping windows with padding if needed.
Args:
@@ -268,12 +282,17 @@ def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, T
Hp, Wp = H + pad_h, W + pad_w
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ )
return windows, (Hp, Wp)
def window_unpartition(
- windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
+ windows: torch.Tensor,
+ window_size: int,
+ pad_hw: Tuple[int, int],
+ hw: Tuple[int, int],
) -> torch.Tensor:
"""
Window unpartition into original sequences and removing padding.
@@ -289,7 +308,9 @@ def window_unpartition(
Hp, Wp = pad_hw
H, W = hw
B = windows.shape[0] // (Hp * Wp // window_size // window_size)
- x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
+ x = windows.view(
+ B, Hp // window_size, Wp // window_size, window_size, window_size, -1
+ )
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
if Hp > H or Wp > W:
@@ -363,7 +384,9 @@ def add_decomposed_rel_pos(
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
attn = (
- attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
+ attn.view(B, q_h, q_w, k_h, k_w)
+ + rel_h[:, :, :, :, None]
+ + rel_w[:, :, :, None, :]
).view(B, q_h * q_w, k_h * k_w)
return attn
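A shape-level sanity check for the window helpers above; sizes are chosen so no padding is needed, and the import path assumes this repository layout:

import torch

from model.segment_anything.modeling.image_encoder import (window_partition,
                                                            window_unpartition)

x = torch.randn(2, 28, 28, 64)                        # (B, H, W, C)
windows, pad_hw = window_partition(x, window_size=14)
print(windows.shape, pad_hw)                          # torch.Size([8, 14, 14, 64]) (28, 28)
restored = window_unpartition(windows, 14, pad_hw, hw=(28, 28))
print(torch.equal(x, restored))                       # True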
diff --git a/model/segment_anything/modeling/mask_decoder.py b/model/segment_anything/modeling/mask_decoder.py
index f7c0f1be2ce3dbee6b0f32656a51fe9c48c353e3..105bc9206e0fc2b1ceef69a31f4a16ae07e37a94 100755
--- a/model/segment_anything/modeling/mask_decoder.py
+++ b/model/segment_anything/modeling/mask_decoder.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import List, Tuple, Type
+
import torch
from torch import nn
from torch.nn import functional as F
-from typing import List, Tuple, Type
-
from .common import LayerNorm2d
@@ -51,10 +51,14 @@ class MaskDecoder(nn.Module):
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
self.output_upscaling = nn.Sequential(
- nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2),
+ nn.ConvTranspose2d(
+ transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
+ ),
LayerNorm2d(transformer_dim // 4),
activation(),
- nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2),
+ nn.ConvTranspose2d(
+ transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
+ ),
activation(),
)
self.output_hypernetworks_mlps = nn.ModuleList(
@@ -118,9 +122,13 @@ class MaskDecoder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Predicts masks. See 'forward' for more details."""
# Concatenate output tokens
- output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
- output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
-
+ output_tokens = torch.cat(
+ [self.iou_token.weight, self.mask_tokens.weight], dim=0
+ )
+ output_tokens = output_tokens.unsqueeze(0).expand(
+ sparse_prompt_embeddings.size(0), -1, -1
+ )
+
# sparse_prompt_embeddings = sparse_prompt_embeddings.half()
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
@@ -143,10 +151,14 @@ class MaskDecoder(nn.Module):
upscaled_embedding = self.output_upscaling(src)
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
- hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
+ hyper_in_list.append(
+ self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
+ )
hyper_in = torch.stack(hyper_in_list, dim=1)
b, c, h, w = upscaled_embedding.shape
- masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, self.num_mask_tokens, h, w)
+ masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(
+ b, self.num_mask_tokens, h, w
+ )
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
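The hypernetwork mask product reshaped above is just a batched matmul; a toy shape check with illustrative dimensions (e.g. c = transformer_dim // 8):

import torch

b, num_mask_tokens, c, h, w = 1, 4, 32, 64, 64
hyper_in = torch.randn(b, num_mask_tokens, c)
upscaled_embedding = torch.randn(b, c, h, w)

masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, num_mask_tokens, h, w)
print(masks.shape)  # torch.Size([1, 4, 64, 64])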
diff --git a/model/segment_anything/modeling/prompt_encoder.py b/model/segment_anything/modeling/prompt_encoder.py
index c08726b353e94e6b324759655bea4ab11238628d..16bc3a45e75f154453ed0724c70ce8daa0324c81 100755
--- a/model/segment_anything/modeling/prompt_encoder.py
+++ b/model/segment_anything/modeling/prompt_encoder.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Any, Optional, Tuple, Type
+
import numpy as np
import torch
from torch import nn
-from typing import Any, Optional, Tuple, Type
-
from .common import LayerNorm2d
@@ -43,11 +43,16 @@ class PromptEncoder(nn.Module):
self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
- point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
+ point_embeddings = [
+ nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
+ ]
self.point_embeddings = nn.ModuleList(point_embeddings)
self.not_a_point_embed = nn.Embedding(1, embed_dim)
- self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
+ self.mask_input_size = (
+ 4 * image_embedding_size[0],
+ 4 * image_embedding_size[1],
+ )
self.mask_downscaling = nn.Sequential(
nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
LayerNorm2d(mask_in_chans // 4),
@@ -83,7 +88,9 @@ class PromptEncoder(nn.Module):
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
- point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
+ point_embedding = self.pe_layer.forward_with_coords(
+ points, self.input_image_size
+ )
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
@@ -94,7 +101,9 @@ class PromptEncoder(nn.Module):
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2)
- corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
+ corner_embedding = self.pe_layer.forward_with_coords(
+ coords, self.input_image_size
+ )
corner_embedding[:, 0, :] += self.point_embeddings[2].weight
corner_embedding[:, 1, :] += self.point_embeddings[3].weight
return corner_embedding
@@ -153,7 +162,9 @@ class PromptEncoder(nn.Module):
Bx(embed_dim)x(embed_H)x(embed_W)
"""
bs = self._get_batch_size(points, boxes, masks, text_embeds)
- sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
+ sparse_embeddings = torch.empty(
+ (bs, 0, self.embed_dim), device=self._get_device()
+ )
if points is not None:
coords, labels = points
point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
@@ -206,7 +217,9 @@ class PositionEmbeddingRandom(nn.Module):
"""Generate positional encoding for a grid of the specified size."""
h, w = size
device: Any = self.positional_encoding_gaussian_matrix.device
- grid = torch.ones((h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype)
+ grid = torch.ones(
+ (h, w), device=device, dtype=self.positional_encoding_gaussian_matrix.dtype
+ )
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / h
diff --git a/model/segment_anything/modeling/sam.py b/model/segment_anything/modeling/sam.py
index c857fd5aaad4e696b56aa2fa7c1b23ddf0ca569d..f1d82cac3cc1deea45171fd9360dfd7fa25e457a 100755
--- a/model/segment_anything/modeling/sam.py
+++ b/model/segment_anything/modeling/sam.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Any, Dict, List, Tuple
+
import torch
from torch import nn
from torch.nn import functional as F
-from typing import Any, Dict, List, Tuple
-
from .image_encoder import ImageEncoderViT
from .mask_decoder import MaskDecoder
from .prompt_encoder import PromptEncoder
@@ -43,7 +43,9 @@ class Sam(nn.Module):
self.image_encoder = image_encoder
self.prompt_encoder = prompt_encoder
self.mask_decoder = mask_decoder
- self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer(
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
+ )
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
@property
@@ -94,7 +96,9 @@ class Sam(nn.Module):
shape BxCxHxW, where H=W=256. Can be passed as mask input
to subsequent iterations of prediction.
"""
- input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
+ input_images = torch.stack(
+ [self.preprocess(x["image"]) for x in batched_input], dim=0
+ )
image_embeddings = self.image_encoder(input_images)
outputs = []
@@ -162,7 +166,9 @@ class Sam(nn.Module):
)
# masks = masks.to(dtype)
masks = masks[..., : input_size[0], : input_size[1]]
- masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
+ masks = F.interpolate(
+ masks, original_size, mode="bilinear", align_corners=False
+ )
return masks
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
diff --git a/model/segment_anything/modeling/transformer.py b/model/segment_anything/modeling/transformer.py
index 28fafea52288603fea275f3a100790471825c34a..8c511e4ff35cc91132b09edd788c96f9a5768161 100755
--- a/model/segment_anything/modeling/transformer.py
+++ b/model/segment_anything/modeling/transformer.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import torch
-from torch import Tensor, nn
-
import math
from typing import Tuple, Type
+import torch
+from torch import Tensor, nn
+
from .common import MLPBlock
@@ -198,7 +198,9 @@ class Attention(nn.Module):
self.embedding_dim = embedding_dim
self.internal_dim = embedding_dim // downsample_rate
self.num_heads = num_heads
- assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim."
+ assert (
+ self.internal_dim % num_heads == 0
+ ), "num_heads must divide embedding_dim."
self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
self.k_proj = nn.Linear(embedding_dim, self.internal_dim)
diff --git a/model/segment_anything/predictor.py b/model/segment_anything/predictor.py
index a3820fb7de8647e5d6adf229debc498b33caad62..bf52d81c2ef2e81b87e574fc935e88749ae3ebf6 100755
--- a/model/segment_anything/predictor.py
+++ b/model/segment_anything/predictor.py
@@ -4,13 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Optional, Tuple
+
import numpy as np
import torch
from .modeling import Sam
-
-from typing import Optional, Tuple
-
from .utils.transforms import ResizeLongestSide
@@ -55,7 +54,9 @@ class SamPredictor:
# Transform the image to the form expected by the model
input_image = self.transform.apply_image(image)
input_image_torch = torch.as_tensor(input_image, device=self.device)
- input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[
+ None, :, :, :
+ ]
self.set_torch_image(input_image_torch, image.shape[:2])
@@ -131,7 +132,9 @@ class SamPredictor:
a subsequent iteration as mask input.
"""
if not self.is_image_set:
- raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
# Transform input prompts
coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None
@@ -140,15 +143,21 @@ class SamPredictor:
point_labels is not None
), "point_labels must be supplied if point_coords is supplied."
point_coords = self.transform.apply_coords(point_coords, self.original_size)
- coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
- labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device)
+ coords_torch = torch.as_tensor(
+ point_coords, dtype=torch.float, device=self.device
+ )
+ labels_torch = torch.as_tensor(
+ point_labels, dtype=torch.int, device=self.device
+ )
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
if box is not None:
box = self.transform.apply_boxes(box, self.original_size)
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
box_torch = box_torch[None, :]
if mask_input is not None:
- mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device)
+ mask_input_torch = torch.as_tensor(
+ mask_input, dtype=torch.float, device=self.device
+ )
mask_input_torch = mask_input_torch[None, :, :, :]
masks, iou_predictions, low_res_masks = self.predict_torch(
@@ -211,7 +220,9 @@ class SamPredictor:
a subsequent iteration as mask input.
"""
if not self.is_image_set:
- raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+ raise RuntimeError(
+ "An image must be set with .set_image(...) before mask prediction."
+ )
if point_coords is not None:
points = (point_coords, point_labels)
@@ -235,7 +246,9 @@ class SamPredictor:
)
# Upscale the masks to the original image resolution
- masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size)
+ masks = self.model.postprocess_masks(
+ low_res_masks, self.input_size, self.original_size
+ )
if not return_logits:
masks = masks > self.model.mask_threshold
@@ -252,7 +265,9 @@ class SamPredictor:
raise RuntimeError(
"An image must be set with .set_image(...) to generate an embedding."
)
- assert self.features is not None, "Features must exist if an image has been set."
+ assert (
+ self.features is not None
+ ), "Features must exist if an image has been set."
return self.features
@property
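For reference, the prompt-driven predictor path looks like this; the checkpoint filename and the click coordinates are assumptions:

import cv2
import numpy as np

from model.segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")  # assumed local checkpoint
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)                      # runs the ViT image encoder once per image
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),        # one foreground click as (x, y)
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)                      # (3, H, W) boolean masks plus quality scores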
diff --git a/model/segment_anything/utils/amg.py b/model/segment_anything/utils/amg.py
index be064071ef399fea96c673ad173689656c23534a..5c3bc5d789049076a2404b1b2477110cebc32fb2 100755
--- a/model/segment_anything/utils/amg.py
+++ b/model/segment_anything/utils/amg.py
@@ -4,14 +4,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import numpy as np
-import torch
-
import math
from copy import deepcopy
from itertools import product
from typing import Any, Dict, Generator, ItemsView, List, Tuple
+import numpy as np
+import torch
+
class MaskData:
"""
diff --git a/model/segment_anything/utils/onnx.py b/model/segment_anything/utils/onnx.py
index 3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe..3521208f620aeef707707037d027c0156d940cdf 100755
--- a/model/segment_anything/utils/onnx.py
+++ b/model/segment_anything/utils/onnx.py
@@ -4,12 +4,12 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from typing import Tuple
+
import torch
import torch.nn as nn
from torch.nn import functional as F
-from typing import Tuple
-
from ..modeling import Sam
from .amg import calculate_stability_score
@@ -48,32 +48,43 @@ class SamOnnxModel(nn.Module):
transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
return transformed_size
- def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
+ def _embed_points(
+ self, point_coords: torch.Tensor, point_labels: torch.Tensor
+ ) -> torch.Tensor:
point_coords = point_coords + 0.5
point_coords = point_coords / self.img_size
point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
point_embedding = point_embedding * (point_labels != -1)
- point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
- point_labels == -1
+ point_embedding = (
+ point_embedding
+ + self.model.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
)
for i in range(self.model.prompt_encoder.num_point_embeddings):
- point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
- i
- ].weight * (point_labels == i)
+ point_embedding = (
+ point_embedding
+ + self.model.prompt_encoder.point_embeddings[i].weight
+ * (point_labels == i)
+ )
return point_embedding
- def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
- mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
+ def _embed_masks(
+ self, input_mask: torch.Tensor, has_mask_input: torch.Tensor
+ ) -> torch.Tensor:
+ mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(
+ input_mask
+ )
mask_embedding = mask_embedding + (
1 - has_mask_input
) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
return mask_embedding
- def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
+ def mask_postprocessing(
+ self, masks: torch.Tensor, orig_im_size: torch.Tensor
+ ) -> torch.Tensor:
masks = F.interpolate(
masks,
size=(self.img_size, self.img_size),
@@ -81,7 +92,9 @@ class SamOnnxModel(nn.Module):
align_corners=False,
)
- prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
+ prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(
+ torch.int64
+ )
masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore
orig_im_size = orig_im_size.to(torch.int64)
diff --git a/model/segment_anything/utils/transforms.py b/model/segment_anything/utils/transforms.py
index 97a682a28ed0fb1481a27a6134d44a98d41d78f3..4232d84252ea4983b194b2ebe8796741d252ef87 100755
--- a/model/segment_anything/utils/transforms.py
+++ b/model/segment_anything/utils/transforms.py
@@ -4,13 +4,14 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
+from copy import deepcopy
+from typing import Tuple
+
import numpy as np
import torch
from torch.nn import functional as F
-from torchvision.transforms.functional import resize, to_pil_image # type: ignore
-
-from copy import deepcopy
-from typing import Tuple
+from torchvision.transforms.functional import resize # type: ignore
+from torchvision.transforms.functional import to_pil_image
class ResizeLongestSide:
@@ -27,10 +28,14 @@ class ResizeLongestSide:
"""
Expects a numpy array with shape HxWxC in uint8 format.
"""
- target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ target_size = self.get_preprocess_shape(
+ image.shape[0], image.shape[1], self.target_length
+ )
return np.array(resize(to_pil_image(image), target_size))
- def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ def apply_coords(
+ self, coords: np.ndarray, original_size: Tuple[int, ...]
+ ) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
original image size in (H, W) format.
@@ -44,7 +49,9 @@ class ResizeLongestSide:
coords[..., 1] = coords[..., 1] * (new_h / old_h)
return coords
- def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+ def apply_boxes(
+ self, boxes: np.ndarray, original_size: Tuple[int, ...]
+ ) -> np.ndarray:
"""
Expects a numpy array shape Bx4. Requires the original image size
in (H, W) format.
@@ -59,7 +66,9 @@ class ResizeLongestSide:
the transformation expected by the model.
"""
# Expects an image in BCHW format. May not exactly match apply_image.
- target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+ target_size = self.get_preprocess_shape(
+ image.shape[0], image.shape[1], self.target_length
+ )
return F.interpolate(
image, target_size, mode="bilinear", align_corners=False, antialias=True
)
@@ -91,7 +100,9 @@ class ResizeLongestSide:
return boxes.reshape(-1, 4)
@staticmethod
- def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+ def get_preprocess_shape(
+ oldh: int, oldw: int, long_side_length: int
+ ) -> Tuple[int, int]:
"""
Compute the output size given input size and target long side length.
"""
diff --git a/train_ds.py b/train_ds.py
new file mode 100755
index 0000000000000000000000000000000000000000..ea98d250e09667c11c2b0d1296765258c8c20df7
--- /dev/null
+++ b/train_ds.py
@@ -0,0 +1,455 @@
+import argparse
+import os
+import shutil
+import sys
+import time
+from functools import partial
+
+import deepspeed
+import numpy as np
+import torch
+import tqdm
+import transformers
+from torch.utils.tensorboard import SummaryWriter
+
+from model.LISA import LISA
+from utils.dataset import HybridDataset, ValDataset, collate_fn
+from utils.utils import (AverageMeter, ProgressMeter, Summary, dict_to_cuda,
+ intersectionAndUnionGPU)
+
+
+def parse_args(args):
+ parser = argparse.ArgumentParser(description="LISA Model Training")
+ parser.add_argument("--local_rank", default=0, type=int, help="node rank")
+ parser.add_argument(
+ "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
+ )
+ parser.add_argument("--vis_save_path", default="./vis_output", type=str)
+ parser.add_argument(
+ "--precision",
+ default="bf16",
+ type=str,
+ choices=["fp32", "bf16", "fp16"],
+ help="precision for inference",
+ )
+ parser.add_argument("--image_size", default=1024, type=int, help="image size")
+ parser.add_argument("--model_max_length", default=512, type=int)
+ parser.add_argument("--lora_r", default=8, type=int)
+ parser.add_argument(
+ "--vision-tower", default="openai/clip-vit-large-patch14", type=str
+ )
+ parser.add_argument("--load_in_8bit", action="store_true", default=False)
+ parser.add_argument("--load_in_4bit", action="store_true", default=False)
+
+ parser.add_argument(
+ "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
+ )
+ parser.add_argument(
+ "--sem_seg_data",
+ default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
+ type=str,
+ )
+ parser.add_argument(
+ "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
+ )
+ parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
+ parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
+ parser.add_argument(
+ "--val_dataset", default="ReasonSeg|val", type=str
+ )
+ parser.add_argument("--dataset_dir", default="./dataset", type=str)
+ parser.add_argument("--log_base_dir", default="./runs", type=str)
+ parser.add_argument("--exp_name", default="lisa", type=str)
+ parser.add_argument("--epochs", default=20, type=int)
+ parser.add_argument("--steps_per_epoch", default=500, type=int)
+ parser.add_argument(
+ "--batch_size", default=2, type=int, help="batch size per device per step"
+ )
+ parser.add_argument(
+ "--grad_accumulation_steps",
+ default=10,
+ type=int,
+ help="batch size per device per step",
+ )
+ parser.add_argument("--val_batch_size", default=1, type=int)
+ parser.add_argument("--workers", default=4, type=int)
+ parser.add_argument("--lr", default=0.0003, type=float)
+ parser.add_argument("--ce_loss_weight", default=1.0, type=float)
+ parser.add_argument("--dice_loss_weight", default=0.5, type=float)
+ parser.add_argument("--bce_loss_weight", default=2.0, type=float)
+ parser.add_argument("--lora_alpha", default=16, type=int)
+ parser.add_argument("--lora_dropout", default=0.05, type=float)
+ parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
+ parser.add_argument("--explanatory", default=0.1, type=float)
+ parser.add_argument("--beta1", default=0.9, type=float)
+ parser.add_argument("--beta2", default=0.95, type=float)
+ parser.add_argument("--num_classes_per_sample", default=3, type=int)
+ parser.add_argument("--exclude_val", action="store_true", default=False)
+ parser.add_argument("--no_eval", action="store_true", default=False)
+ parser.add_argument("--eval_only", action="store_true", default=False)
+ parser.add_argument("--vision_pretrained", default="PATH TO SAM ViT-H Pre-trained Wegiht", type=str)
+ parser.add_argument("--weight", default="", type=str)
+ parser.add_argument("--print_freq", default=1, type=int)
+ parser.add_argument("--start_epoch", default=0, type=int)
+ return parser.parse_args(args)
+
+
+def main(args):
+ args = parse_args(args)
+ args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
+ if args.local_rank == 0:
+ os.makedirs(args.log_dir, exist_ok=True)
+ writer = SummaryWriter(args.log_dir)
+ else:
+ writer = None
+
+ # Create model
+ tokenizer = transformers.AutoTokenizer.from_pretrained(
+ args.version,
+ cache_dir=None,
+ model_max_length=args.model_max_length,
+ padding_side="right",
+ use_fast=False,
+ )
+ tokenizer.pad_token = tokenizer.unk_token
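+    # Register the special [SEG] token; its id is passed to LISA below as seg_token_idx.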
+ num_added_tokens = tokenizer.add_tokens("[SEG]")
+ ret_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids
+ args.seg_token_idx = ret_token_idx[0]
+
+ model = LISA(
+ args.local_rank,
+ args.seg_token_idx,
+ tokenizer,
+ args.version,
+ args.lora_r,
+ args.precision,
+ vision_tower=args.vision_tower,
+ load_in_8bit=args.load_in_8bit,
+ load_in_4bit=args.load_in_4bit,
+ ce_loss_weight=args.ce_loss_weight,
+ dice_loss_weight=args.dice_loss_weight,
+ bce_loss_weight=args.bce_loss_weight,
+ vision_pretrained=args.vision_pretrained,
+ )
+
+ if args.weight:
+ state_dict = torch.load(args.weight, map_location='cpu')
+ model.load_state_dict(state_dict, strict=True)
+
+ world_size = torch.cuda.device_count()
+ args.distributed = world_size > 1
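+    # One "epoch" here is steps_per_epoch optimizer steps, so the dataset is sized to
+    # batch_size * grad_accumulation_steps * steps_per_epoch * world_size samples.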
+ train_dataset = HybridDataset(
+ args.dataset_dir,
+ tokenizer,
+ args.vision_tower,
+        samples_per_epoch=args.batch_size
+        * args.grad_accumulation_steps
+        * args.steps_per_epoch
+        * world_size,
+ precision=args.precision,
+ image_size=args.image_size,
+ num_classes_per_sample=args.num_classes_per_sample,
+ exclude_val=args.exclude_val,
+ dataset=args.dataset,
+ sem_seg_data=args.sem_seg_data,
+ refer_seg_data=args.refer_seg_data,
+ vqa_data=args.vqa_data,
+ reason_seg_data=args.reason_seg_data,
+ explanatory=args.explanatory,
+ )
+
+ if args.no_eval == False:
+ val_dataset = ValDataset(
+ args.dataset_dir,
+ tokenizer,
+ args.vision_tower,
+ args.val_dataset,
+ args.image_size,
+ )
+ print(f"Training with {len(train_dataset)} examples and validating with {len(val_dataset)} examples.")
+ else:
+ val_dataset = None
+ print(f"Training with {len(train_dataset)} examples.")
+
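+    # DeepSpeed config: ZeRO stage-2, AdamW, and a WarmupDecayLR schedule with linear
+    # warmup; fp16/bf16 is enabled according to --precision.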
+ ds_config = {
+ "train_micro_batch_size_per_gpu": args.batch_size,
+ "gradient_accumulation_steps": args.grad_accumulation_steps,
+ "optimizer": {
+ "type": "AdamW",
+ "params": {
+ "lr": args.lr,
+ "weight_decay": 0.0,
+ "betas": (args.beta1, args.beta2),
+ },
+ },
+ "scheduler": {
+ "type": "WarmupDecayLR",
+ "params": {
+ "total_num_steps": args.epochs * args.steps_per_epoch,
+ "warmup_min_lr": 0,
+ "warmup_max_lr": args.lr,
+ "warmup_num_steps": 100,
+ "warmup_type": "linear",
+ },
+ },
+ "fp16": {
+ "enabled": args.precision == "fp16",
+ },
+ "bf16": {
+ "enabled": args.precision == "bf16",
+ },
+ "gradient_clipping": 1.0,
+ "zero_optimization": {
+ "stage": 2,
+ "contiguous_gradients": True,
+ "overlap_comm": True,
+ "reduce_scatter": True,
+ "reduce_bucket_size": 5e8,
+ "allgather_bucket_size": 5e8,
+ },
+ }
+ model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
+ model=model,
+ model_parameters=model.parameters(),
+ training_data=train_dataset,
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
+ config=ds_config,
+ )
+
+ if val_dataset is not None:
+ assert args.val_batch_size == 1
+        val_sampler = torch.utils.data.distributed.DistributedSampler(
+            val_dataset, shuffle=False, drop_last=False
+        )
+ val_loader = torch.utils.data.DataLoader(
+ val_dataset,
+ batch_size=args.val_batch_size,
+ shuffle=False,
+ num_workers=args.workers,
+ pin_memory=False,
+ sampler=val_sampler,
+ collate_fn=partial(collate_fn, tokenizer=tokenizer),
+ )
+
+ train_iter = iter(train_loader)
+ best_score, cur_ciou = 0.0, 0.0
+
+ if args.eval_only:
+ giou, ciou = validate(
+ val_loader, model_engine, 0, writer, args
+ )
+ exit()
+
+ for epoch in range(args.start_epoch, args.epochs):
+
+ # train for one epoch
+ train_iter = train(
+ train_loader,
+ model_engine,
+ epoch,
+ scheduler,
+ writer,
+ train_iter,
+ args,
+ )
+
+ if args.no_eval == False:
+ giou, ciou = validate(
+ val_loader, model_engine, epoch, writer, args
+ )
+ is_best = giou > best_score
+ best_score = max(giou, best_score)
+ cur_ciou = ciou if is_best else cur_ciou
+
+ if args.no_eval or is_best:
+ save_dir = os.path.join(args.log_dir, "ckpt_model")
+ if args.local_rank == 0:
+ torch.save(
+ {"epoch": epoch},
+ os.path.join(
+ args.log_dir,
+ "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
+ best_score, cur_ciou
+ ),
+ ),
+ )
+ if os.path.exists(save_dir):
+ shutil.rmtree(save_dir)
+ torch.distributed.barrier()
+ model_engine.save_checkpoint(save_dir)
+
+
+def train(
+ train_loader,
+ model,
+ epoch,
+ scheduler,
+ writer,
+ train_iter,
+ args,
+):
+ """Main training loop."""
+ batch_time = AverageMeter("Time", ":6.3f")
+ data_time = AverageMeter("Data", ":6.3f")
+ losses = AverageMeter("Loss", ":.4f")
+ ce_losses = AverageMeter("CeLoss", ":.4f")
+ mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
+ mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
+ mask_losses = AverageMeter("MaskLoss", ":.4f")
+
+ progress = ProgressMeter(
+ args.steps_per_epoch,
+ [
+ batch_time,
+ losses,
+ ce_losses,
+ mask_losses,
+ mask_bce_losses,
+ mask_dice_losses,
+ ],
+ prefix="Epoch: [{}]".format(epoch),
+ )
+
+ # switch to train mode
+ model.train()
+ end = time.time()
+ for global_step in range(args.steps_per_epoch):
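+        # Each optimizer step consumes grad_accumulation_steps micro-batches; DeepSpeed
+        # accumulates gradients across the model.backward()/model.step() calls below and
+        # applies the update only at accumulation boundaries.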
+ for i in range(args.grad_accumulation_steps):
+ try:
+ input_dict = next(train_iter)
+            except StopIteration:
+ train_iter = iter(train_loader)
+ input_dict = next(train_iter)
+
+ data_time.update(time.time() - end)
+ input_dict = dict_to_cuda(input_dict)
+
+ if args.precision == "fp16":
+ input_dict["images"] = input_dict["images"].half()
+ input_dict["images_clip"] = input_dict["images_clip"].half()
+ elif args.precision == "bf16":
+ input_dict["images"] = input_dict["images"].bfloat16()
+ input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
+ else:
+ input_dict["images"] = input_dict["images"].float()
+ input_dict["images_clip"] = input_dict["images_clip"].float()
+
+ output_dict = model(**input_dict)
+
+ loss = output_dict["loss"]
+ ce_loss = output_dict["ce_loss"]
+ mask_bce_loss = output_dict["mask_bce_loss"]
+ mask_dice_loss = output_dict["mask_dice_loss"]
+ mask_loss = output_dict["mask_loss"]
+
+ losses.update(loss.item(), input_dict["images"].size(0))
+ ce_losses.update(ce_loss.item(), input_dict["images"].size(0))
+ mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0))
+ mask_dice_losses.update(mask_dice_loss.item(), input_dict["images"].size(0))
+ mask_losses.update(mask_loss.item(), input_dict["images"].size(0))
+ model.backward(loss)
+ model.step()
+
+ # measure elapsed time
+ batch_time.update(time.time() - end)
+ end = time.time()
+
+ if global_step % args.print_freq == 0:
+ if args.distributed:
+ batch_time.all_reduce()
+ data_time.all_reduce()
+
+ losses.all_reduce()
+ ce_losses.all_reduce()
+ mask_bce_losses.all_reduce()
+ mask_dice_losses.all_reduce()
+ mask_losses.all_reduce()
+
+ if args.local_rank == 0:
+ progress.display(global_step + 1)
+ writer.add_scalar("train/loss", losses.avg, global_step)
+ writer.add_scalar("train/ce_loss", ce_losses.avg, global_step)
+ writer.add_scalar(
+ "train/mask_bce_loss", mask_bce_losses.avg, global_step
+ )
+ writer.add_scalar(
+ "train/mask_dice_loss", mask_dice_losses.avg, global_step
+ )
+ writer.add_scalar("train/mask_loss", mask_losses.avg, global_step)
+ writer.add_scalar(
+ "metrics/total_secs_per_batch", batch_time.avg, global_step
+ )
+ writer.add_scalar(
+ "metrics/data_secs_per_batch", data_time.avg, global_step
+ )
+
+ batch_time.reset()
+ data_time.reset()
+ losses.reset()
+ ce_losses.reset()
+ mask_bce_losses.reset()
+ mask_dice_losses.reset()
+ mask_losses.reset()
+
+ if global_step != 0:
+ curr_lr = scheduler.get_last_lr()
+ if args.local_rank == 0:
+ writer.add_scalar("train/lr", curr_lr[0], global_step)
+
+ return train_iter
+
+
+def validate(val_loader, model_engine, epoch, writer, args):
+ intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
+ union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
+ acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)
+
+ model_engine.eval()
+
+ for input_dict in tqdm.tqdm(val_loader):
+ input_dict = dict_to_cuda(input_dict)
+ if args.precision == "fp16":
+ input_dict["images"] = input_dict["images"].half()
+ input_dict["images_clip"] = input_dict["images_clip"].half()
+ elif args.precision == "bf16":
+ input_dict["images"] = input_dict["images"].bfloat16()
+ input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
+ else:
+ input_dict["images"] = input_dict["images"].float()
+ input_dict["images_clip"] = input_dict["images_clip"].float()
+
+ output_dict = model_engine(**input_dict)
+
+ pred_masks = output_dict["pred_masks"]
+ masks_list = output_dict["gt_masks"][0].int()
+ output_list = (pred_masks[0] > 0).int()
+ assert len(pred_masks) == 1
+
+ intersection, union, acc_iou = 0.0, 0.0, 0.0
+ for mask_i, output_i in zip(masks_list, output_list):
+ intersection_i, union_i, _ = intersectionAndUnionGPU(
+ output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
+ )
+ intersection += intersection_i
+ union += union_i
+ acc_iou += intersection_i / (union_i + 1e-5)
+ acc_iou[union_i == 0] += 1.0 # no-object target
+ intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
+ acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
+        intersection_meter.update(intersection)
+        union_meter.update(union)
+        acc_iou_meter.update(acc_iou, n=masks_list.shape[0])
+
+ intersection_meter.all_reduce()
+ union_meter.all_reduce()
+ acc_iou_meter.all_reduce()
+
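+    # ciou: cumulative IoU over the whole split for the foreground class;
+    # giou: mean of the per-image IoUs accumulated above.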
+ iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
+ ciou = iou_class[1]
+ giou = acc_iou_meter.avg[1]
+
+ if args.local_rank == 0:
+ writer.add_scalar("val/giou", giou, epoch)
+ writer.add_scalar("val/giou", ciou, epoch)
+ print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))
+
+ return giou, ciou
+
+
+if __name__ == "__main__":
+ main(sys.argv[1:])
diff --git a/utils/ade20k_classes.json b/utils/ade20k_classes.json
new file mode 100644
index 0000000000000000000000000000000000000000..1f96e616bc3fd2f8c0ec4caea975d77c680f44bb
--- /dev/null
+++ b/utils/ade20k_classes.json
@@ -0,0 +1,30 @@
+[
+ "wall", "building", "sky", "floor", "tree", "ceiling", "road",
+ "bed", "windowpane", "grass", "cabinet", "sidewalk",
+ "person", "earth", "door", "table", "mountain", "plant",
+ "curtain", "chair", "car", "water", "painting", "sofa",
+ "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
+ "seat", "fence", "desk", "rock", "wardrobe", "lamp",
+ "bathtub", "railing", "cushion", "base", "box", "column",
+ "signboard", "chest of drawers", "counter", "sand", "sink",
+ "skyscraper", "fireplace", "refrigerator", "grandstand",
+ "path", "stairs", "runway", "case", "pool table", "pillow",
+ "screen door", "stairway", "river", "bridge", "bookcase",
+ "blind", "coffee table", "toilet", "flower", "book", "hill",
+ "bench", "countertop", "stove", "palm", "kitchen island",
+ "computer", "swivel chair", "boat", "bar", "arcade machine",
+ "hovel", "bus", "towel", "light", "truck", "tower",
+ "chandelier", "awning", "streetlight", "booth",
+ "television receiver", "airplane", "dirt track", "apparel",
+ "pole", "land", "bannister", "escalator", "ottoman", "bottle",
+ "buffet", "poster", "stage", "van", "ship", "fountain",
+ "conveyer belt", "canopy", "washer", "plaything",
+ "swimming pool", "stool", "barrel", "basket", "waterfall",
+ "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
+ "step", "tank", "trade name", "microwave", "pot", "animal",
+ "bicycle", "lake", "dishwasher", "screen", "blanket",
+ "sculpture", "hood", "sconce", "vase", "traffic light",
+ "tray", "ashcan", "fan", "pier", "crt screen", "plate",
+ "monitor", "bulletin board", "shower", "radiator", "glass",
+ "clock", "flag"
+]
\ No newline at end of file
diff --git a/utils/cocostuff_classes.txt b/utils/cocostuff_classes.txt
new file mode 100755
index 0000000000000000000000000000000000000000..1d5a692b83ac8eead2bfffa805e1115cef737bae
--- /dev/null
+++ b/utils/cocostuff_classes.txt
@@ -0,0 +1,183 @@
+0: unlabeled
+1: person
+2: bicycle
+3: car
+4: motorcycle
+5: airplane
+6: bus
+7: train
+8: truck
+9: boat
+10: traffic light
+11: fire hydrant
+12: street sign
+13: stop sign
+14: parking meter
+15: bench
+16: bird
+17: cat
+18: dog
+19: horse
+20: sheep
+21: cow
+22: elephant
+23: bear
+24: zebra
+25: giraffe
+26: hat
+27: backpack
+28: umbrella
+29: shoe
+30: eye glasses
+31: handbag
+32: tie
+33: suitcase
+34: frisbee
+35: skis
+36: snowboard
+37: sports ball
+38: kite
+39: baseball bat
+40: baseball glove
+41: skateboard
+42: surfboard
+43: tennis racket
+44: bottle
+45: plate
+46: wine glass
+47: cup
+48: fork
+49: knife
+50: spoon
+51: bowl
+52: banana
+53: apple
+54: sandwich
+55: orange
+56: broccoli
+57: carrot
+58: hot dog
+59: pizza
+60: donut
+61: cake
+62: chair
+63: couch
+64: potted plant
+65: bed
+66: mirror
+67: dining table
+68: window
+69: desk
+70: toilet
+71: door
+72: tv
+73: laptop
+74: mouse
+75: remote
+76: keyboard
+77: cell phone
+78: microwave
+79: oven
+80: toaster
+81: sink
+82: refrigerator
+83: blender
+84: book
+85: clock
+86: vase
+87: scissors
+88: teddy bear
+89: hair drier
+90: toothbrush
+91: hair brush
+92: banner
+93: blanket
+94: branch
+95: bridge
+96: building-other
+97: bush
+98: cabinet
+99: cage
+100: cardboard
+101: carpet
+102: ceiling-other
+103: ceiling-tile
+104: cloth
+105: clothes
+106: clouds
+107: counter
+108: cupboard
+109: curtain
+110: desk-stuff
+111: dirt
+112: door-stuff
+113: fence
+114: floor-marble
+115: floor-other
+116: floor-stone
+117: floor-tile
+118: floor-wood
+119: flower
+120: fog
+121: food-other
+122: fruit
+123: furniture-other
+124: grass
+125: gravel
+126: ground-other
+127: hill
+128: house
+129: leaves
+130: light
+131: mat
+132: metal
+133: mirror-stuff
+134: moss
+135: mountain
+136: mud
+137: napkin
+138: net
+139: paper
+140: pavement
+141: pillow
+142: plant-other
+143: plastic
+144: platform
+145: playingfield
+146: railing
+147: railroad
+148: river
+149: road
+150: rock
+151: roof
+152: rug
+153: salad
+154: sand
+155: sea
+156: shelf
+157: sky
+158: skyscraper
+159: snow
+160: solid-other
+161: stairs
+162: stone
+163: straw
+164: structural-other
+165: table
+166: tent
+167: textile-other
+168: towel
+169: tree
+170: vegetable
+171: wall-brick
+172: wall-concrete
+173: wall-other
+174: wall-panel
+175: wall-stone
+176: wall-tile
+177: wall-wood
+178: water-other
+179: waterdrops
+180: window-blind
+181: window-other
+182: wood
diff --git a/utils/conversation.py b/utils/conversation.py
index 0cf11c4096391485b332e2006fd88aa80e6b783e..65ea31ff2e1ba6f93c5942d096162576284fff61 100644
--- a/utils/conversation.py
+++ b/utils/conversation.py
@@ -3,8 +3,8 @@ Conversation prompt templates.
"""
import dataclasses
-from enum import auto, Enum
-from typing import List, Tuple, Any
+from enum import Enum, auto
+from typing import Any, List
class SeparatorStyle(Enum):
diff --git a/utils/data_proc_demo.py b/utils/data_proc_demo.py
deleted file mode 100644
index 88eb9b6696c7e65e8384846182c1129083f239db..0000000000000000000000000000000000000000
--- a/utils/data_proc_demo.py
+++ /dev/null
@@ -1,83 +0,0 @@
-import os
-import numpy as np
-import json
-import cv2
-import glob
-
-def get_mask_from_json(json_path, img):
- try:
- with open(json_path, 'r') as r:
- anno = json.loads(r.read())
- except:
- with open(json_path, 'r', encoding="cp1252") as r:
- anno = json.loads(r.read())
-
- inform = anno['shapes']
- comments = anno['text']
- is_sentence = anno['is_sentence']
-
- height, width = img.shape[:2]
-
- ### sort polies by area
- area_list = []
- valid_poly_list = []
- for i in inform:
- label_id = i['label']
- points = i['points']
- if 'flag' == label_id.lower(): ## meaningless deprecated annotations
- continue
-
- tmp_mask = np.zeros((height, width), dtype=np.uint8)
- cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1)
- cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1)
- tmp_area = tmp_mask.sum()
-
- area_list.append(tmp_area)
- valid_poly_list.append(i)
-
- ### ground-truth mask
- sort_index = np.argsort(area_list)[::-1].astype(np.int32)
- sort_index = list(sort_index)
- sort_inform = []
- for s_idx in sort_index:
- sort_inform.append(valid_poly_list[s_idx])
-
- mask = np.zeros((height, width), dtype=np.uint8)
- for i in sort_inform:
- label_id = i['label']
- points = i['points']
-
- if 'ignore' in label_id.lower():
- label_value = 255 # ignored during evaluation
- else:
- label_value = 1 # target
-
- cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1)
- cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value)
-
- return mask, comments, is_sentence
-
-
-if __name__ == '__main__':
- data_dir = './train'
- vis_dir = './vis'
-
- if not os.path.exists(vis_dir):
- os.makedirs(vis_dir)
-
- json_path_list = sorted(glob.glob(data_dir + '/*.json'))
- for json_path in json_path_list:
- img_path = json_path.replace('.json', '.jpg')
- img = cv2.imread(img_path)[:,:,::-1]
-
- # In generated mask, value 1 denotes valid target region, and value 255 stands for region ignored during evaluaiton.
- mask, comments, is_sentence = get_mask_from_json(json_path, img)
-
- ## visualization. Green for target, and red for ignore.
- valid_mask = (mask == 1).astype(np.float32)[:,:,None]
- ignore_mask = (mask == 255).astype(np.float32)[:,:,None]
- vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + ((np.array([0,255,0]) * 0.6 + img * 0.4) * valid_mask + (np.array([255,0,0]) * 0.6 + img * 0.4) * ignore_mask)
- vis_img = np.concatenate([img, vis_img], 1)
- vis_path = os.path.join(vis_dir, json_path.split('/')[-1].replace('.json', '.jpg'))
- cv2.imwrite(vis_path, vis_img[:,:,::-1])
- print('Visualization has been saved to: ', vis_path)
\ No newline at end of file
diff --git a/utils/data_processing.py b/utils/data_processing.py
new file mode 100644
index 0000000000000000000000000000000000000000..d47a80f0111019c97ccb2ce198f37495ee037471
--- /dev/null
+++ b/utils/data_processing.py
@@ -0,0 +1,90 @@
+import glob
+import json
+import os
+
+import cv2
+import numpy as np
+
+
+def get_mask_from_json(json_path, img):
+ try:
+ with open(json_path, "r") as r:
+ anno = json.loads(r.read())
+ except:
+ with open(json_path, "r", encoding="cp1252") as r:
+ anno = json.loads(r.read())
+
+ inform = anno["shapes"]
+ comments = anno["text"]
+ is_sentence = anno["is_sentence"]
+
+ height, width = img.shape[:2]
+
+ ### sort polies by area
+ area_list = []
+ valid_poly_list = []
+ for i in inform:
+ label_id = i["label"]
+ points = i["points"]
+ if "flag" == label_id.lower(): ## meaningless deprecated annotations
+ continue
+
+ tmp_mask = np.zeros((height, width), dtype=np.uint8)
+ cv2.polylines(tmp_mask, np.array([points], dtype=np.int32), True, 1, 1)
+ cv2.fillPoly(tmp_mask, np.array([points], dtype=np.int32), 1)
+ tmp_area = tmp_mask.sum()
+
+ area_list.append(tmp_area)
+ valid_poly_list.append(i)
+
+ ### ground-truth mask
+ sort_index = np.argsort(area_list)[::-1].astype(np.int32)
+ sort_index = list(sort_index)
+ sort_inform = []
+ for s_idx in sort_index:
+ sort_inform.append(valid_poly_list[s_idx])
+
+ mask = np.zeros((height, width), dtype=np.uint8)
+ for i in sort_inform:
+ label_id = i["label"]
+ points = i["points"]
+
+ if "ignore" in label_id.lower():
+ label_value = 255 # ignored during evaluation
+ else:
+ label_value = 1 # target
+
+ cv2.polylines(mask, np.array([points], dtype=np.int32), True, label_value, 1)
+ cv2.fillPoly(mask, np.array([points], dtype=np.int32), label_value)
+
+ return mask, comments, is_sentence
+
+
+if __name__ == "__main__":
+ data_dir = "./train"
+ vis_dir = "./vis"
+
+ if not os.path.exists(vis_dir):
+ os.makedirs(vis_dir)
+
+ json_path_list = sorted(glob.glob(data_dir + "/*.json"))
+ for json_path in json_path_list:
+ img_path = json_path.replace(".json", ".jpg")
+ img = cv2.imread(img_path)[:, :, ::-1]
+
+        # In the generated mask, value 1 denotes the valid target region, and value 255 marks regions ignored during evaluation.
+ mask, comments, is_sentence = get_mask_from_json(json_path, img)
+
+ ## visualization. Green for target, and red for ignore.
+ valid_mask = (mask == 1).astype(np.float32)[:, :, None]
+ ignore_mask = (mask == 255).astype(np.float32)[:, :, None]
+ vis_img = img * (1 - valid_mask) * (1 - ignore_mask) + (
+ (np.array([0, 255, 0]) * 0.6 + img * 0.4) * valid_mask
+ + (np.array([255, 0, 0]) * 0.6 + img * 0.4) * ignore_mask
+ )
+ vis_img = np.concatenate([img, vis_img], 1)
+ vis_path = os.path.join(
+ vis_dir, json_path.split("/")[-1].replace(".json", ".jpg")
+ )
+ cv2.imwrite(vis_path, vis_img[:, :, ::-1])
+ print("Visualization has been saved to: ", vis_path)
diff --git a/utils/dataset.py b/utils/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3499046ff89c54bf1dcffa7f1f8c2b564a1d0ffe
--- /dev/null
+++ b/utils/dataset.py
@@ -0,0 +1,450 @@
+import glob
+import os
+import random
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from pycocotools import mask
+from transformers import CLIPImageProcessor
+
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+from .conversation import get_default_conv_template
+from .data_processing import get_mask_from_json
+from .reason_seg_dataset import ReasonSegDataset
+from .refer import REFER
+from .refer_seg_dataset import ReferSegDataset
+from .sem_seg_dataset import SemSegDataset
+from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN)
+from .vqa_dataset import VQADataset
+
+
+def collate_fn(batch, tokenizer=None):
+ image_path_list = []
+ images_list = []
+ images_clip_list = []
+ conversation_list = []
+ masks_list = []
+ label_list = []
+ resize_list = []
+ questions_list = []
+ sampled_classes_list = []
+ offset_list = [0]
+ cnt = 0
+ inferences = []
+ for (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ label,
+ resize,
+ questions,
+ sampled_classes,
+ inference,
+ ) in batch:
+ image_path_list.append(image_path)
+ images_list.append(images)
+ images_clip_list.append(images_clip)
+ conversation_list.extend(conversations)
+ label_list.append(label)
+ masks_list.append(masks.float())
+ resize_list.append(resize)
+ questions_list.append(questions)
+ sampled_classes_list.append(sampled_classes)
+ cnt += len(conversations)
+ offset_list.append(cnt)
+ inferences.append(inference)
+
+ tokenize_data = tokenizer(
+ conversation_list,
+ return_tensors="pt",
+ padding="longest",
+ max_length=tokenizer.model_max_length,
+ truncation=True,
+ )
+
+ input_ids = tokenize_data.input_ids
+ attention_masks = tokenize_data.attention_mask
+
+ IGNORE_TOKEN_ID = -100
+ conv = get_default_conv_template("vicuna").copy()
+ targets = input_ids.clone()
+ sep = conv.sep + conv.roles[1] + ": "
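+    # Mask out the user/instruction tokens with IGNORE_TOKEN_ID so the language-model
+    # loss is computed only on the assistant responses.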
+ for conversation, target in zip(conversation_list, targets):
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
+
+ rounds = conversation.split(conv.sep2)
+ cur_len = 1
+ target[:cur_len] = IGNORE_TOKEN_ID
+ for i, rou in enumerate(rounds):
+ if rou == "":
+ break
+
+ parts = rou.split(sep)
+ # if len(parts) != 2:
+ # break
+ assert len(parts) == 2, (len(parts), rou)
+ parts[0] += sep
+ round_len = len(tokenizer(rou).input_ids)
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
+
+ target[cur_len : cur_len + instruction_len] = IGNORE_TOKEN_ID
+
+ cur_len += round_len
+ target[cur_len:] = IGNORE_TOKEN_ID
+
+ if False:
+ # if True:
+ z = target.clone()
+ z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
+ # rank0_print(tokenizer.decode(z))
+ print(
+ "conversation: ",
+ conversation,
+ "tokenizer.decode(z): ",
+ tokenizer.decode(z),
+ )
+
+ if cur_len < tokenizer.model_max_length:
+ assert cur_len == total_len
+
+ return {
+ "image_paths": image_path_list,
+ "images": torch.stack(images_list, dim=0),
+ "images_clip": torch.stack(images_clip_list, dim=0),
+ "input_ids": input_ids,
+ "labels": targets,
+ "attention_masks": attention_masks,
+ "masks_list": masks_list,
+ "label_list": label_list,
+ "resize_list": resize_list,
+ "offset": torch.LongTensor(offset_list),
+ "questions_list": questions_list,
+ "sampled_classes_list": sampled_classes_list,
+ "inference": inferences[0],
+ "conversation_list": conversation_list,
+ }
+
+
+class HybridDataset(torch.utils.data.Dataset):
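+    # Normalization constants and padded square size used when preprocessing images
+    # for the SAM branch.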
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch=500 * 8 * 2 * 10,
+ precision: str = "fp32",
+ image_size: int = 224,
+ num_classes_per_sample: int = 3,
+ exclude_val=False,
+ dataset="sem_seg||refer_seg||vqa||reason_seg",
+ sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
+ refer_seg_data="refclef||refcoco||refcoco+||refcocog",
+ vqa_data="llava_instruct_150k",
+ reason_seg_data="ReasonSeg|train",
+ explanatory=0.1,
+ ):
+ self.exclude_val = exclude_val
+ self.dataset = dataset
+ self.samples_per_epoch = samples_per_epoch
+ self.explanatory = explanatory
+ self.num_classes_per_sample = num_classes_per_sample
+
+ self.base_image_dir = base_image_dir
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.precision = precision
+
+ self.datasets = dataset.split("||")
+
+ self.all_datasets = []
+ for dataset in self.datasets:
+ if dataset == "sem_seg":
+ self.all_datasets.append(
+ SemSegDataset(
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch,
+ precision,
+ image_size,
+ num_classes_per_sample,
+ exclude_val,
+ sem_seg_data,
+ )
+ )
+ elif dataset == "refer_seg":
+ self.all_datasets.append(
+ ReferSegDataset(
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch,
+ precision,
+ image_size,
+ num_classes_per_sample,
+ exclude_val,
+ refer_seg_data,
+ )
+ )
+ elif dataset == "vqa":
+ self.all_datasets.append(
+ VQADataset(
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch,
+ precision,
+ image_size,
+ num_classes_per_sample,
+ exclude_val,
+ vqa_data,
+ )
+ )
+ elif dataset == "reason_seg":
+ self.all_datasets.append(
+ ReasonSegDataset(
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch,
+ precision,
+ image_size,
+ num_classes_per_sample,
+ exclude_val,
+ reason_seg_data,
+ explanatory,
+ )
+ )
+
+ def __len__(self):
+ return self.samples_per_epoch
+
+ def __getitem__(self, idx):
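+        # Pick one of the constituent datasets pseudo-randomly and return one sample
+        # from it; the sub-dataset is always indexed with 0 (see data[0]), so each
+        # sub-dataset is expected to sample internally.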
+ ind = (random.randint(0, 2023) * (idx + 1)) % len(
+ self.datasets
+ ) # random.randint(0, len(self.datasets)-1)
+ data = self.all_datasets[ind]
+ inference = False
+ return *data[0], inference
+
+
+class ValDataset(torch.utils.data.Dataset):
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ val_dataset,
+ image_size=1024,
+ ):
+ self.base_image_dir = base_image_dir
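+        # val_dataset is "<ds>|<split>" for ReasonSeg, or "<ds>|<splitBy>|<split>" for
+        # the refer_seg datasets.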
+ splits = val_dataset.split("|")
+ if len(splits) == 2:
+ ds, split = splits
+ images = glob.glob(
+ os.path.join(self.base_image_dir, "reason_seg", ds, split, "*.jpg")
+ )
+ self.images = images
+ self.data_type = 'reason_seg'
+ elif len(splits) == 3:
+ ds, splitBy, split = splits
+ refer_api = REFER(self.base_image_dir, ds, splitBy)
+ ref_ids_val = refer_api.getRefIds(split=split)
+ images_ids_val = refer_api.getImgIds(ref_ids=ref_ids_val)
+ refs_val = refer_api.loadRefs(ref_ids=ref_ids_val)
+ refer_seg_ds = {}
+ refer_seg_ds["images"] = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_val)
+ for item in loaded_images:
+ item = item.copy()
+ if ds == "refclef":
+ item["file_name"] = os.path.join(
+ base_image_dir, "images/saiapr_tc-12", item["file_name"]
+ )
+ elif ds in ["refcoco", "refcoco+", "refcocog", "grefcoco"]:
+ item["file_name"] = os.path.join(
+ base_image_dir,
+ "images/mscoco/images/train2014",
+ item["file_name"],
+ )
+ refer_seg_ds["images"].append(item)
+ refer_seg_ds["annotations"] = refer_api.Anns # anns_val
+
+ img2refs = {}
+ for ref in refs_val:
+ image_id = ref["image_id"]
+ img2refs[image_id] = img2refs.get(image_id, []) + [
+ ref,
+ ]
+ refer_seg_ds["img2refs"] = img2refs
+ self.refer_seg_ds = refer_seg_ds
+ self.data_type = 'refer_seg'
+
+ self.ds = ds
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.transform = ResizeLongestSide(image_size)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+ def __len__(self):
+ if self.data_type == 'refer_seg':
+ return len(self.refer_seg_ds["images"])
+ else:
+ return len(self.images)
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.img_size - h
+ padw = self.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def __getitem__(self, idx):
+ if self.data_type == 'refer_seg':
+ refer_seg_ds = self.refer_seg_ds
+ images = refer_seg_ds["images"]
+ annotations = refer_seg_ds["annotations"]
+ img2refs = refer_seg_ds["img2refs"]
+
+ image = images[idx]
+ image_path = image["file_name"]
+ image_id = image["id"]
+
+ refs = img2refs[image_id]
+ if len(refs) == 0:
+ raise ValueError("image {} has no refs".format(image_id))
+
+ sents = []
+ ann_ids = []
+ for ref in refs:
+ for sent in ref["sentences"]:
+ sents.append(sent["sent"].strip().lower())
+ ann_ids.append(ref["ann_id"])
+
+ sampled_sents = sents
+ sampled_ann_ids = ann_ids
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ is_sentence = False
+ else:
+ image_path = self.images[idx]
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ json_path = image_path.replace(".jpg", ".json")
+ mask_json, sampled_sents, is_sentence = get_mask_from_json(json_path, img)
+ sampled_sents = [sampled_sents[0]]
+
+ conversations = []
+ conv = get_default_conv_template("vicuna").copy()
+ i = 0
+ while i < len(sampled_sents):
+ conv.messages = []
+ text = sampled_sents[i].strip()
+ if is_sentence:
+ conv.append_message(
+ conv.roles[0],
+ DEFAULT_IMAGE_TOKEN
+ + " {} Please output segmentation mask.".format(text),
+ )
+ conv.append_message(conv.roles[1], "[SEG].")
+ else:
+ conv.append_message(
+ conv.roles[0],
+ DEFAULT_IMAGE_TOKEN
+ + " What is {} in this image? Please output segmentation mask.".format(
+ text
+ ),
+ )
+ conv.append_message(conv.roles[1], "[SEG].")
+ conversations.append(conv.get_prompt())
+ i += 1
+
+ # replace token
+ image_token_len = 256
+ for i in range(len(conversations)):
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ conversations[i] = conversations[i].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
+
+ # preprocess images for clip
+ images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+ "pixel_values"
+ ][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+
+ # preprocess images for sam
+ images = self.transform.apply_image(images)
+
+ resize = images.shape[:2]
+
+ images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
+ if self.data_type == 'refer_seg':
+ masks = []
+ for i, ann_id in enumerate(sampled_ann_ids):
+ ann = annotations[ann_id]
+ if len(ann["segmentation"]) == 0 and sampled_sents[i] != "":
+ m = np.zeros((image["height"], image["width"], 1))
+ else:
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image["height"], image["width"]
+ )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(
+ m, axis=2
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ masks.append(m)
+ else:
+ masks = [mask_json]
+
+ masks = np.stack(masks, axis=0)
+ masks = torch.from_numpy(masks)
+ labels = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
+ inference = True
+
+ return (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ labels,
+ resize,
+ None,
+ None,
+ inference,
+ )
diff --git a/utils/reason_seg_dataset.py b/utils/reason_seg_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b911c7f36ed3a81fa6995cb2d9c2d5d846f50a8
--- /dev/null
+++ b/utils/reason_seg_dataset.py
@@ -0,0 +1,247 @@
+import glob
+import json
+import os
+import random
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor
+
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+from .conversation import get_default_conv_template
+from .data_processing import get_mask_from_json
+from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
+ EXPLANATORY_QUESTION_LIST, LONG_QUESTION_LIST,
+ SHORT_QUESTION_LIST)
+
+
+class ReasonSegDataset(torch.utils.data.Dataset):
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch=500 * 8 * 2 * 10,
+ precision: str = "fp32",
+ image_size: int = 224,
+ num_classes_per_sample: int = 3,
+ exclude_val=False,
+ reason_seg_data="ReasonSeg|train",
+ explanatory=0.1,
+ ):
+ self.exclude_val = exclude_val
+ self.reason_seg_data = reason_seg_data
+ self.samples_per_epoch = samples_per_epoch
+ self.explanatory = explanatory
+ self.num_classes_per_sample = num_classes_per_sample
+
+ self.base_image_dir = base_image_dir
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.precision = precision
+ self.transform = ResizeLongestSide(image_size)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+ self.short_question_list = SHORT_QUESTION_LIST
+ self.long_question_list = LONG_QUESTION_LIST
+ self.answer_list = ANSWER_LIST
+
+        if explanatory != -1:
+            self.explanatory_question_list = EXPLANATORY_QUESTION_LIST
+            self.img_to_why = {}
+ for sub_data in [
+ "20230711_2000_0_processed_masked_finished_masked.json",
+ "20230711_2000_0_processed_masked_partial_masked.json",
+ ]:
+ with open(
+ os.path.join(base_image_dir, "reason_seg", "explanatory", sub_data)
+ ) as f:
+ items = json.load(f)
+ for item in items:
+ img_name = item["image_path"].split("/")[-1]
+ self.img_to_why[img_name] = {
+ "query": item["query"],
+ "outputs": item["outputs"],
+ }
+
+ reason_seg_data, splits = reason_seg_data.split("|")
+ splits = splits.split("_")
+ images = []
+ for split in splits:
+ images_split = glob.glob(
+ os.path.join(
+ base_image_dir, "reason_seg", reason_seg_data, split, "*.jpg"
+ )
+ )
+ images.extend(images_split)
+ jsons = [path.replace(".jpg", ".json") for path in images]
+ self.reason_seg_data = (images, jsons)
+
+ def __len__(self):
+ return self.samples_per_epoch
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.img_size - h
+ padw = self.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def __getitem__(self, idx):
+ images, jsons = self.reason_seg_data
+ idx = random.randint(0, len(images) - 1)
+ image_path = images[idx]
+ json_path = jsons[idx]
+
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ ori_size = images.shape[:2]
+ # preprocess images for clip
+ images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+ "pixel_values"
+ ][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+ images = self.transform.apply_image(images) # preprocess images for sam
+ resize = images.shape[:2]
+
+ mask, sents, is_sentence = get_mask_from_json(json_path, img)
+ if len(sents) >= self.num_classes_per_sample:
+ sampled_inds = np.random.choice(
+ list(range(len(sents))), size=self.num_classes_per_sample, replace=False
+ )
+ else:
+ sampled_inds = list(range(len(sents)))
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
+ sampled_masks = [
+ (mask == 1).astype(np.float32) for _ in range(len(sampled_inds))
+ ]
+
+ image_name = image_path.split("/")[-1]
+ if (
+ self.explanatory != -1 and image_name in self.img_to_why
+ ): # ds in ['20230711_2000_0_processed_masked_partial_masked', '20230711_2000_0_processed_masked_finished_masked', 'trainval_rephrased_20230730_checked_final_masked', 'rephrased_20230730_checked_final_masked']:
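+            # choice 0: [SEG]-only answer, 1: [SEG] plus explanation, 2: text-only
+            # explanation (choice 2 provides no mask supervision; see the empty masks
+            # tensor at the end of __getitem__).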
+ if random.random() < self.explanatory:
+ choice = 2
+ else:
+ choice = random.randint(0, 1)
+
+ questions = []
+ answers = []
+ class_ids = []
+ for text in sampled_sents:
+ if is_sentence:
+ question_template = random.choice(self.long_question_list)
+ questions.append(question_template.format(sent=text))
+ else:
+ question_template = random.choice(self.short_question_list)
+ questions.append(question_template.format(class_name=text.lower()))
+
+ img_name = image_path.split("/")[-1]
+ if self.explanatory != -1 and img_name in self.img_to_why:
+ # choice = random.randint(0, 2)
+ if choice == 0: # [SEG] token
+ answers.append(random.choice(self.answer_list))
+ elif choice == 1: # [SEG] token + text answer
+ image_name = image_path.split("/")[-1]
+ answer = self.img_to_why[image_name]["outputs"]
+ answer = random.choice(self.answer_list) + " {}".format(answer)
+ questions[-1] = (
+ DEFAULT_IMAGE_TOKEN
+ + " "
+ + text
+ + " {}".format(random.choice(self.explanatory_question_list))
+ )
+ answers.append(answer)
+ elif choice == 2: # vanilla text answer
+ image_name = image_path.split("/")[-1]
+ answer = self.img_to_why[image_name]["outputs"]
+ questions[-1] = DEFAULT_IMAGE_TOKEN + " " + text
+ answers.append(answer)
+ else:
+ raise ValueError("Not implemented yet.")
+ else:
+ answers.append(random.choice(self.answer_list))
+
+ conversations = []
+ conv = get_default_conv_template("vicuna").copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ i = 0
+ while i < len(questions):
+ conv.messages = []
+ conv.append_message(conv.roles[0], questions[i])
+ conv.append_message(conv.roles[1], answers[i])
+ conversations.append(conv.get_prompt())
+ i += 1
+
+ # ==============================
+ # replace token
+ for i in range(len(conversations)):
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ conversations[i] = conversations[i].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
+ # ==============================
+
+ images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
+ image_name = image_path.split("/")[-1]
+ if self.explanatory != -1 and image_name in self.img_to_why and choice == 2:
+ # print("e1")
+
+ masks = torch.rand(0, *ori_size)
+ label = torch.ones(ori_size) * self.ignore_label
+ else:
+ # print("e2")
+
+ masks = np.stack(sampled_masks, axis=0)
+ masks = torch.from_numpy(masks)
+ label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
+
+ # print("reason_seg: {}".format(conversations))
+
+ # # debug
+ # if masks.shape[0] != 0:
+ # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
+ # os.makedirs(save_dir, exist_ok=True)
+ # print("masks.shape: ", masks.shape)
+ # for i in range(masks.shape[0]):
+ # cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i].numpy().astype(np.uint8)*100)
+ # assert len(conversations) == masks.shape[0]
+ # with open("{}/conversations.txt".format(save_dir), "w+") as f:
+ # for i in range(len(conversations)):
+ # f.write("{}. ".format(i) + conversations[i] + "\n")
+ # shutil.copy(image_path, save_dir)
+
+ return (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ label,
+ resize,
+ questions,
+ sampled_sents,
+ )
diff --git a/utils/refer.py b/utils/refer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b4cea716e40e73d0b5aa118143eb076392f5eb1
--- /dev/null
+++ b/utils/refer.py
@@ -0,0 +1,391 @@
+__author__ = "licheng"
+
+"""
+This interface provides access to four datasets:
+1) refclef
+2) refcoco
+3) refcoco+
+4) refcocog
+split by unc and google
+
+The following API functions are defined:
+REFER - REFER api class
+getRefIds - get ref ids that satisfy given filter conditions.
+getAnnIds - get ann ids that satisfy given filter conditions.
+getImgIds - get image ids that satisfy given filter conditions.
+getCatIds - get category ids that satisfy given filter conditions.
+loadRefs - load refs with the specified ref ids.
+loadAnns - load anns with the specified ann ids.
+loadImgs - load images with the specified image ids.
+loadCats - load category names with the specified category ids.
+getRefBox - get ref's bounding box [x, y, w, h] given the ref_id
+showRef - show image, segmentation or box of the referred object with the ref
+getMask - get mask and area of the referred object given ref
+showMask - show mask of the referred object given ref
+"""
+
+import itertools
+import json
+import os.path as osp
+import pickle
+import sys
+import time
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage.io as io
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon, Rectangle
+from pycocotools import mask
+
+
+class REFER:
+ def __init__(self, data_root, dataset="refcoco", splitBy="unc"):
+ # provide data_root folder which contains refclef, refcoco, refcoco+ and refcocog
+ # also provide dataset name and splitBy information
+ # e.g., dataset = 'refcoco', splitBy = 'unc'
+ print("loading dataset %s into memory..." % dataset)
+ self.ROOT_DIR = osp.abspath(osp.dirname(__file__))
+ self.DATA_DIR = osp.join(data_root, dataset)
+ if dataset in ["refcoco", "refcoco+", "refcocog"]:
+ self.IMAGE_DIR = osp.join(data_root, "images/mscoco/images/train2014")
+ elif dataset == "refclef":
+ self.IMAGE_DIR = osp.join(data_root, "images/saiapr_tc-12")
+ else:
+ print("No refer dataset is called [%s]" % dataset)
+ sys.exit()
+
+ self.dataset = dataset
+
+ # load refs from data/dataset/refs(dataset).json
+ tic = time.time()
+
+ ref_file = osp.join(self.DATA_DIR, "refs(" + splitBy + ").p")
+ print("ref_file: ", ref_file)
+ self.data = {}
+ self.data["dataset"] = dataset
+ self.data["refs"] = pickle.load(open(ref_file, "rb"))
+
+ # load annotations from data/dataset/instances.json
+ instances_file = osp.join(self.DATA_DIR, "instances.json")
+ instances = json.load(open(instances_file, "rb"))
+ self.data["images"] = instances["images"]
+ self.data["annotations"] = instances["annotations"]
+ self.data["categories"] = instances["categories"]
+
+ # create index
+ self.createIndex()
+ print("DONE (t=%.2fs)" % (time.time() - tic))
+
+ def createIndex(self):
+ # create sets of mapping
+ # 1) Refs: {ref_id: ref}
+ # 2) Anns: {ann_id: ann}
+ # 3) Imgs: {image_id: image}
+ # 4) Cats: {category_id: category_name}
+ # 5) Sents: {sent_id: sent}
+ # 6) imgToRefs: {image_id: refs}
+ # 7) imgToAnns: {image_id: anns}
+ # 8) refToAnn: {ref_id: ann}
+ # 9) annToRef: {ann_id: ref}
+ # 10) catToRefs: {category_id: refs}
+ # 11) sentToRef: {sent_id: ref}
+ # 12) sentToTokens: {sent_id: tokens}
+ print("creating index...")
+ # fetch info from instances
+ Anns, Imgs, Cats, imgToAnns = {}, {}, {}, {}
+ for ann in self.data["annotations"]:
+ Anns[ann["id"]] = ann
+ imgToAnns[ann["image_id"]] = imgToAnns.get(ann["image_id"], []) + [ann]
+ for img in self.data["images"]:
+ Imgs[img["id"]] = img
+ for cat in self.data["categories"]:
+ Cats[cat["id"]] = cat["name"]
+
+ # fetch info from refs
+ Refs, imgToRefs, refToAnn, annToRef, catToRefs = {}, {}, {}, {}, {}
+ Sents, sentToRef, sentToTokens = {}, {}, {}
+ for ref in self.data["refs"]:
+ # ids
+ ref_id = ref["ref_id"]
+ ann_id = ref["ann_id"]
+ category_id = ref["category_id"]
+ image_id = ref["image_id"]
+
+ # add mapping related to ref
+ Refs[ref_id] = ref
+ imgToRefs[image_id] = imgToRefs.get(image_id, []) + [ref]
+ catToRefs[category_id] = catToRefs.get(category_id, []) + [ref]
+ refToAnn[ref_id] = Anns[ann_id]
+ annToRef[ann_id] = ref
+
+ # add mapping of sent
+ for sent in ref["sentences"]:
+ Sents[sent["sent_id"]] = sent
+ sentToRef[sent["sent_id"]] = ref
+ sentToTokens[sent["sent_id"]] = sent["tokens"]
+
+ # create class members
+ self.Refs = Refs
+ self.Anns = Anns
+ self.Imgs = Imgs
+ self.Cats = Cats
+ self.Sents = Sents
+ self.imgToRefs = imgToRefs
+ self.imgToAnns = imgToAnns
+ self.refToAnn = refToAnn
+ self.annToRef = annToRef
+ self.catToRefs = catToRefs
+ self.sentToRef = sentToRef
+ self.sentToTokens = sentToTokens
+ print("index created.")
+
+ def getRefIds(self, image_ids=[], cat_ids=[], ref_ids=[], split=""):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == len(split) == 0:
+ refs = self.data["refs"]
+ else:
+ if not len(image_ids) == 0:
+ refs = [self.imgToRefs[image_id] for image_id in image_ids]
+ else:
+ refs = self.data["refs"]
+ if not len(cat_ids) == 0:
+ refs = [ref for ref in refs if ref["category_id"] in cat_ids]
+ if not len(ref_ids) == 0:
+ refs = [ref for ref in refs if ref["ref_id"] in ref_ids]
+ if not len(split) == 0:
+ if split in ["testA", "testB", "testC"]:
+ refs = [
+ ref for ref in refs if split[-1] in ref["split"]
+ ] # we also consider testAB, testBC, ...
+ elif split in ["testAB", "testBC", "testAC"]:
+ refs = [
+ ref for ref in refs if ref["split"] == split
+ ] # rarely used I guess...
+ elif split == "test":
+ refs = [ref for ref in refs if "test" in ref["split"]]
+ elif split == "train" or split == "val":
+ refs = [ref for ref in refs if ref["split"] == split]
+ else:
+ print("No such split [%s]" % split)
+ sys.exit()
+ ref_ids = [ref["ref_id"] for ref in refs]
+ return ref_ids
+
+ def getAnnIds(self, image_ids=[], cat_ids=[], ref_ids=[]):
+ image_ids = image_ids if type(image_ids) == list else [image_ids]
+ cat_ids = cat_ids if type(cat_ids) == list else [cat_ids]
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if len(image_ids) == len(cat_ids) == len(ref_ids) == 0:
+ ann_ids = [ann["id"] for ann in self.data["annotations"]]
+ else:
+ if not len(image_ids) == 0:
+ lists = [
+ self.imgToAnns[image_id]
+ for image_id in image_ids
+ if image_id in self.imgToAnns
+ ] # list of [anns]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.data["annotations"]
+ if not len(cat_ids) == 0:
+ anns = [ann for ann in anns if ann["category_id"] in cat_ids]
+ ann_ids = [ann["id"] for ann in anns]
+ if not len(ref_ids) == 0:
+ ids = set(ann_ids).intersection(
+ set([self.Refs[ref_id]["ann_id"] for ref_id in ref_ids])
+ )
+ return ann_ids
+
+ def getImgIds(self, ref_ids=[]):
+ ref_ids = ref_ids if type(ref_ids) == list else [ref_ids]
+
+ if not len(ref_ids) == 0:
+ image_ids = list(set([self.Refs[ref_id]["image_id"] for ref_id in ref_ids]))
+ else:
+ image_ids = self.Imgs.keys()
+ return image_ids
+
+ def getCatIds(self):
+ return self.Cats.keys()
+
+ def loadRefs(self, ref_ids=[]):
+ if type(ref_ids) == list:
+ return [self.Refs[ref_id] for ref_id in ref_ids]
+ elif type(ref_ids) == int:
+ return [self.Refs[ref_ids]]
+
+ def loadAnns(self, ann_ids=[]):
+ if type(ann_ids) == list:
+ return [self.Anns[ann_id] for ann_id in ann_ids]
+        elif type(ann_ids) == int or type(ann_ids) == str:
+ return [self.Anns[ann_ids]]
+
+ def loadImgs(self, image_ids=[]):
+ if type(image_ids) == list:
+ return [self.Imgs[image_id] for image_id in image_ids]
+ elif type(image_ids) == int:
+ return [self.Imgs[image_ids]]
+
+ def loadCats(self, cat_ids=[]):
+ if type(cat_ids) == list:
+ return [self.Cats[cat_id] for cat_id in cat_ids]
+ elif type(cat_ids) == int:
+ return [self.Cats[cat_ids]]
+
+ def getRefBox(self, ref_id):
+ ref = self.Refs[ref_id]
+ ann = self.refToAnn[ref_id]
+ return ann["bbox"] # [x, y, w, h]
+
+ def showRef(self, ref, seg_box="seg"):
+ ax = plt.gca()
+ # show image
+ image = self.Imgs[ref["image_id"]]
+ I = io.imread(osp.join(self.IMAGE_DIR, image["file_name"]))
+ ax.imshow(I)
+ # show refer expression
+ for sid, sent in enumerate(ref["sentences"]):
+ print("%s. %s" % (sid + 1, sent["sent"]))
+ # show segmentations
+ if seg_box == "seg":
+ ann_id = ref["ann_id"]
+ ann = self.Anns[ann_id]
+ polygons = []
+ color = []
+ c = "none"
+ if type(ann["segmentation"][0]) == list:
+ # polygon used for refcoco*
+ for seg in ann["segmentation"]:
+                    poly = np.array(seg).reshape((len(seg) // 2, 2))
+ polygons.append(Polygon(poly, True, alpha=0.4))
+ color.append(c)
+ p = PatchCollection(
+ polygons,
+ facecolors=color,
+ edgecolors=(1, 1, 0, 0),
+ linewidths=3,
+ alpha=1,
+ )
+ ax.add_collection(p) # thick yellow polygon
+ p = PatchCollection(
+ polygons,
+ facecolors=color,
+ edgecolors=(1, 0, 0, 0),
+ linewidths=1,
+ alpha=1,
+ )
+ ax.add_collection(p) # thin red polygon
+ else:
+ # mask used for refclef
+ rle = ann["segmentation"]
+ m = mask.decode(rle)
+ img = np.ones((m.shape[0], m.shape[1], 3))
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
+ for i in range(3):
+ img[:, :, i] = color_mask[i]
+ ax.imshow(np.dstack((img, m * 0.5)))
+ # show bounding-box
+ elif seg_box == "box":
+ ann_id = ref["ann_id"]
+ ann = self.Anns[ann_id]
+ bbox = self.getRefBox(ref["ref_id"])
+ box_plot = Rectangle(
+ (bbox[0], bbox[1]),
+ bbox[2],
+ bbox[3],
+ fill=False,
+ edgecolor="green",
+ linewidth=3,
+ )
+ ax.add_patch(box_plot)
+
+ def getMask(self, ref):
+ # return mask, area and mask-center
+ ann = self.refToAnn[ref["ref_id"]]
+ image = self.Imgs[ref["image_id"]]
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(ann["segmentation"], image["height"], image["width"])
+ else:
+ rle = ann["segmentation"]
+ m = mask.decode(rle)
+ m = np.sum(
+ m, axis=2
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ # compute area
+ area = sum(mask.area(rle)) # should be close to ann['area']
+ return {"mask": m, "area": area}
+ # # position
+ # position_x = np.mean(np.where(m==1)[1]) # [1] means columns (matlab style) -> x (c style)
+ # position_y = np.mean(np.where(m==1)[0]) # [0] means rows (matlab style) -> y (c style)
+ # # mass position (if there were multiple regions, we use the largest one.)
+ # label_m = label(m, connectivity=m.ndim)
+ # regions = regionprops(label_m)
+ # if len(regions) > 0:
+ # largest_id = np.argmax(np.array([props.filled_area for props in regions]))
+ # largest_props = regions[largest_id]
+ # mass_y, mass_x = largest_props.centroid
+ # else:
+ # mass_x, mass_y = position_x, position_y
+ # # if centroid is not in mask, we find the closest point to it from mask
+ # if m[mass_y, mass_x] != 1:
+ # print('Finding closes mask point ...')
+ # kernel = np.ones((10, 10),np.uint8)
+ # me = cv2.erode(m, kernel, iterations = 1)
+ # points = zip(np.where(me == 1)[0].tolist(), np.where(me == 1)[1].tolist()) # row, col style
+ # points = np.array(points)
+ # dist = np.sum((points - (mass_y, mass_x))**2, axis=1)
+ # id = np.argsort(dist)[0]
+ # mass_y, mass_x = points[id]
+ # # return
+ # return {'mask': m, 'area': area, 'position_x': position_x, 'position_y': position_y, 'mass_x': mass_x, 'mass_y': mass_y}
+ # # show image and mask
+ # I = io.imread(osp.join(self.IMAGE_DIR, image['file_name']))
+ # plt.figure()
+ # plt.imshow(I)
+ # ax = plt.gca()
+ # img = np.ones( (m.shape[0], m.shape[1], 3) )
+ # color_mask = np.array([2.0,166.0,101.0])/255
+ # for i in range(3):
+ # img[:,:,i] = color_mask[i]
+ # ax.imshow(np.dstack( (img, m*0.5) ))
+ # plt.show()
+
+ def showMask(self, ref):
+ M = self.getMask(ref)
+ msk = M["mask"]
+ ax = plt.gca()
+ ax.imshow(msk)
+
+
+if __name__ == "__main__":
+    # REFER expects the data root as its first argument; adjust this path to your setup.
+    refer = REFER("./refer_seg", dataset="refcocog", splitBy="google")
+ ref_ids = refer.getRefIds()
+ print(len(ref_ids))
+
+ print(len(refer.Imgs))
+ print(len(refer.imgToRefs))
+
+ ref_ids = refer.getRefIds(split="train")
+ print("There are %s training referred objects." % len(ref_ids))
+
+ for ref_id in ref_ids:
+ ref = refer.loadRefs(ref_id)[0]
+ if len(ref["sentences"]) < 2:
+ continue
+
+ pprint(ref)
+ print("The label is %s." % refer.Cats[ref["category_id"]])
+ plt.figure()
+ refer.showRef(ref, seg_box="box")
+ plt.show()
+
+ # plt.figure()
+ # refer.showMask(ref)
+ # plt.show()
diff --git a/utils/refer_seg_dataset.py b/utils/refer_seg_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c54a07b8ab670b22c168f8e13439a4d9ec4aa0b
--- /dev/null
+++ b/utils/refer_seg_dataset.py
@@ -0,0 +1,272 @@
+import os
+import random
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from pycocotools import mask
+from transformers import CLIPImageProcessor
+
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+from .conversation import get_default_conv_template
+from .refer import REFER
+from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
+ SHORT_QUESTION_LIST)
+
+
+class ReferSegDataset(torch.utils.data.Dataset):
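+    # Referring-expression segmentation dataset (refclef / refcoco / refcoco+ / refcocog):
+    # each sample pairs an image with up to num_classes_per_sample referring sentences
+    # and their binary target masks.
+
+    # SAM-style image normalization constants and padded square input size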
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch=500 * 8 * 2 * 10,
+ precision: str = "fp32",
+ image_size: int = 224,
+ num_classes_per_sample: int = 3,
+ exclude_val=False,
+ refer_seg_data="refclef||refcoco||refcoco+||refcocog",
+ ):
+ self.exclude_val = exclude_val
+ self.samples_per_epoch = samples_per_epoch
+ self.num_classes_per_sample = num_classes_per_sample
+
+ self.base_image_dir = base_image_dir
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.precision = precision
+ self.transform = ResizeLongestSide(image_size)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+ self.short_question_list = SHORT_QUESTION_LIST
+ self.answer_list = ANSWER_LIST
+
+ DATA_DIR = os.path.join(base_image_dir, "refer_seg")
+        self.refer_seg_ds_list = refer_seg_data.split(
+            "||"
+        )  # e.g. ['refclef', 'refcoco', 'refcoco+', 'refcocog']
+ self.refer_seg_data = {}
+ for ds in self.refer_seg_ds_list:
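+            # use the "umd" split for refcocog and the "unc" split for the other datasets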
+ if ds == "refcocog":
+ splitBy = "umd"
+ else:
+ splitBy = "unc"
+ refer_api = REFER(DATA_DIR, ds, splitBy)
+ ref_ids_train = refer_api.getRefIds(split="train")
+ images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
+ refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
+ ref_file = os.path.join(DATA_DIR, ds, "refs(" + splitBy + ").p")
+
+ refer_seg_ds = {}
+ refer_seg_ds["images"] = []
+ loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
+
+ for item in loaded_images:
+ item = item.copy()
+ if ds == "refclef":
+ item["file_name"] = os.path.join(
+ DATA_DIR, "images/saiapr_tc-12", item["file_name"]
+ )
+ else:
+ item["file_name"] = os.path.join(
+ DATA_DIR, "images/mscoco/images/train2014", item["file_name"]
+ )
+ refer_seg_ds["images"].append(item)
+ refer_seg_ds["annotations"] = refer_api.Anns # anns_train
+
+            print(
+                "dataset {} (refs {}) (train split) has {} images and {} annotations.".format(
+                    ds,
+                    splitBy,
+                    len(refer_seg_ds["images"]),
+                    len(refer_seg_ds["annotations"]),
+                )
+            )
+
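+            # group training refs by image id so __getitem__ can sample per image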
+ img2refs = {}
+ for ref in refs_train:
+ image_id = ref["image_id"]
+ img2refs[image_id] = img2refs.get(image_id, []) + [
+ ref,
+ ]
+ refer_seg_ds["img2refs"] = img2refs
+ self.refer_seg_data[ds] = refer_seg_ds
+
+ def __len__(self):
+ return self.samples_per_epoch
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.img_size - h
+ padw = self.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def __getitem__(self, idx):
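+        # idx is ignored: a sub-dataset and an image are drawn at random on every call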
+ ds = random.randint(0, len(self.refer_seg_ds_list) - 1)
+ ds = self.refer_seg_ds_list[ds]
+ refer_seg_ds = self.refer_seg_data[ds]
+ images = refer_seg_ds["images"]
+ annotations = refer_seg_ds["annotations"]
+ img2refs = refer_seg_ds["img2refs"]
+ idx = random.randint(0, len(images) - 1)
+ image = images[idx]
+ image_path = image["file_name"]
+ image_id = image["id"]
+ refs = img2refs[image_id]
+ if len(refs) == 0:
+ return self.__getitem__(0)
+
+ sents = []
+ ann_ids = []
+ for ref in refs:
+ for sent in ref["sentences"]:
+ text = sent["sent"]
+ sents.append(text)
+ ann_ids.append(ref["ann_id"])
+ if len(sents) >= self.num_classes_per_sample:
+ sampled_inds = np.random.choice(
+ list(range(len(sents))), size=self.num_classes_per_sample, replace=False
+ )
+ else:
+ sampled_inds = list(range(len(sents)))
+ sampled_sents = np.vectorize(sents.__getitem__)(sampled_inds).tolist()
+ sampled_ann_ids = np.vectorize(ann_ids.__getitem__)(sampled_inds).tolist()
+ sampled_classes = sampled_sents
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ ori_size = images.shape[:2]
+
+ # preprocess images for clip
+ images_clip = self.clip_image_processor.preprocess(images, return_tensors="pt")[
+ "pixel_values"
+ ][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+ images = self.transform.apply_image(images) # preprocess images for sam
+ resize = images.shape[:2]
+
+ questions = []
+ answers = []
+ class_ids = []
+ for text in sampled_classes:
+ text = text.strip()
+ assert len(text.split("||")) == 1
+ question_template = random.choice(self.short_question_list)
+ questions.append(question_template.format(class_name=text.lower()))
+ answers.append(random.choice(self.answer_list))
+
+ conversations = []
+ conv = get_default_conv_template("vicuna").copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ i = 0
+ while i < len(questions):
+ conv.messages = []
+ conv.append_message(conv.roles[0], questions[i])
+ conv.append_message(conv.roles[1], answers[i])
+ conversations.append(conv.get_prompt())
+ i += 1
+
+ # ==============================
+ # replace token
+ for i in range(len(conversations)):
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ conversations[i] = conversations[i].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
+ # ==============================
+
+ images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
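+        # decode one binary mask per sampled annotation (polygons via frPyObjects,
+        # pre-encoded RLEs decoded directly); empty segmentations yield an all-zero mask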
+ masks = []
+ for ann_id in sampled_ann_ids:
+ ann = annotations[ann_id]
+
+ if len(ann["segmentation"]) == 0:
+ m = np.zeros((image["height"], image["width"])).astype(np.uint8)
+ masks.append(m)
+ continue
+
+ if type(ann["segmentation"][0]) == list: # polygon
+ rle = mask.frPyObjects(
+ ann["segmentation"], image["height"], image["width"]
+ )
+ else:
+ rle = ann["segmentation"]
+ for i in range(len(rle)):
+ if not isinstance(rle[i]["counts"], bytes):
+ rle[i]["counts"] = rle[i]["counts"].encode()
+ m = mask.decode(rle)
+ m = np.sum(
+ m, axis=2
+ ) # sometimes there are multiple binary map (corresponding to multiple segs)
+ m = m.astype(np.uint8) # convert to np.uint8
+ masks.append(m)
+
+ masks = np.stack(masks, axis=0)
+
+ # debug
+ # print("masks.shape: ", masks.shape)
+ # for i in range(masks.shape[0]):
+ # cv2.imwrite("debug/{}_mask_{}.png".format(image_path.split("refer_seg/images")[-1].replace("/", "-").split(".")[0], sampled_sents[i]), masks[i]*100)
+
+ # debug
+ # if ds.endswith("masked"):
+ # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
+ # os.makedirs(save_dir, exist_ok=True)
+ # print("masks.shape: ", masks.shape)
+ # for i in range(masks.shape[0]):
+ # cv2.imwrite("{}/mask_{}.jpg".format(save_dir, i), masks[i]*100)
+ # assert len(conversations) == masks.shape[0]
+ # with open("{}/conversations.txt".format(save_dir), "w+") as f:
+ # for i in range(len(conversations)):
+ # f.write("{}. ".format(i) + conversations[i] + "\n")
+ # shutil.copy(image_path, save_dir)
+
+ masks = torch.from_numpy(masks)
+ label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
+
+ # print("refer_seg: {}".format(conversations))
+
+ # # debug
+ # save_dir = "./debug/{}".format(image_path.split("/")[-1].split(".")[0])
+ # os.makedirs(save_dir, exist_ok=True)
+ # print("masks.shape: ", masks.shape)
+ # for i in range(masks.shape[0]):
+ # cv2.imwrite("{}/mask_{}_{}.jpg".format(save_dir, i, sampled_classes[i]), masks[i].numpy().astype(np.uint8)*100)
+ # assert len(conversations) == masks.shape[0]
+ # with open("{}/conversations.txt".format(save_dir), "w+") as f:
+ # for i in range(len(conversations)):
+ # f.write("{}. ".format(i) + conversations[i] + "\n")
+ # shutil.copy(image_path, save_dir)
+
+ return (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ label,
+ resize,
+ questions,
+ sampled_classes,
+ )
diff --git a/utils/sem_seg_dataset.py b/utils/sem_seg_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83aa1b389a6ed3ad5e19b88aad06438a7e61164
--- /dev/null
+++ b/utils/sem_seg_dataset.py
@@ -0,0 +1,359 @@
+import glob
+import json
+import os
+import random
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from pycocotools.coco import COCO
+from transformers import CLIPImageProcessor
+
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+from .conversation import get_default_conv_template
+from .utils import (ANSWER_LIST, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN,
+ SHORT_QUESTION_LIST)
+
+
+def init_mapillary(base_image_dir):
+ mapillary_data_root = os.path.join(base_image_dir, "mapillary")
+ with open(os.path.join(mapillary_data_root, "config_v2.0.json")) as f:
+ mapillary_classes = json.load(f)["labels"]
+ mapillary_classes = [x["readable"].lower() for x in mapillary_classes]
+ mapillary_classes = np.array(mapillary_classes)
+ mapillary_labels = sorted(
+ glob.glob(
+ os.path.join(mapillary_data_root, "training", "v2.0", "labels", "*.png")
+ )
+ )
+ mapillary_images = [
+ x.replace(".png", ".jpg").replace("v2.0/labels", "images")
+ for x in mapillary_labels
+ ]
+ print("mapillary: ", len(mapillary_images))
+ return mapillary_classes, mapillary_images, mapillary_labels
+
+
+def init_ade20k(base_image_dir):
+ with open("utils/ade20k_classes.json", "r") as f:
+ ade20k_classes = json.load(f)
+ ade20k_classes = np.array(ade20k_classes)
+ image_ids = sorted(
+ os.listdir(os.path.join(base_image_dir, "ade20k/images", "training"))
+ )
+ ade20k_image_ids = []
+ for x in image_ids:
+ if x.endswith(".jpg"):
+ ade20k_image_ids.append(x[:-4])
+ ade20k_images = []
+ for image_id in ade20k_image_ids: # self.descriptions:
+ ade20k_images.append(
+ os.path.join(
+ base_image_dir,
+ "ade20k",
+ "images",
+ "training",
+ "{}.jpg".format(image_id),
+ )
+ )
+ ade20k_labels = [
+ x.replace(".jpg", ".png").replace("images", "annotations")
+ for x in ade20k_images
+ ]
+ print("ade20k: ", len(ade20k_images))
+ return ade20k_classes, ade20k_images, ade20k_labels
+
+
+def init_cocostuff(base_image_dir):
+ cocostuff_classes = []
+ with open("utils/cocostuff_classes.txt") as f:
+ for line in f.readlines()[1:]:
+ cocostuff_classes.append(line.strip().split(": ")[-1])
+ cocostuff_classes = np.array(cocostuff_classes)
+ cocostuff_images = []
+ cocostuff_image_dir = glob.glob(
+ os.path.join(base_image_dir, "cocostuff", "train2017", "*.jpg")
+ )
+ for image_id in cocostuff_image_dir:
+ cocostuff_images.append(image_id)
+ cocostuff_labels = [
+ x.replace(".jpg", ".png").replace("images", "annotations")
+ for x in cocostuff_images
+ ]
+ print("cocostuff: ", len(cocostuff_images))
+ return cocostuff_classes, cocostuff_images, cocostuff_labels
+
+
+def init_paco_lvis(base_image_dir):
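+    # build a category_id -> name map from the PACO-LVIS annotations; part categories
+    # ("object:part") are kept as (object, part) tuples and phrased later in __getitem__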
+ coco_api_paco_lvis = COCO(
+ os.path.join(
+ base_image_dir, "vlpart", "paco", "annotations", "paco_lvis_v1_train.json"
+ )
+ )
+ all_classes = coco_api_paco_lvis.loadCats(coco_api_paco_lvis.getCatIds())
+ class_map_paco_lvis = {}
+ for cat in all_classes:
+ cat_split = cat["name"].strip().split(":")
+ if len(cat_split) == 1:
+ name = cat_split[0].split("_(")[0]
+ else:
+ assert len(cat_split) == 2
+ obj, part = cat_split
+ obj = obj.split("_(")[0]
+ part = part.split("_(")[0]
+ # if random.random() < 0.5:
+ # name = obj + " " + part
+ # else:
+ # name = "the {} of the {}".format(part, obj)
+ name = (obj, part)
+ class_map_paco_lvis[cat["id"]] = name
+ img_ids = coco_api_paco_lvis.getImgIds()
+ print("paco_lvis: ", len(img_ids))
+ return class_map_paco_lvis, img_ids, coco_api_paco_lvis
+
+
+def init_pascal_part(base_image_dir):
+ coco_api_pascal_part = COCO(
+ os.path.join(base_image_dir, "vlpart", "pascal_part", "train.json")
+ )
+ all_classes = coco_api_pascal_part.loadCats(coco_api_pascal_part.getCatIds())
+ class_map_pascal_part = {}
+ for cat in all_classes:
+ cat_main, cat_part = cat["name"].strip().split(":")
+ name = (cat_main, cat_part)
+ class_map_pascal_part[cat["id"]] = name
+ img_ids = coco_api_pascal_part.getImgIds()
+ print("pascal_part: ", len(img_ids))
+ return class_map_pascal_part, img_ids, coco_api_pascal_part
+
+
+class SemSegDataset(torch.utils.data.Dataset):
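+    # Semantic / part segmentation dataset (ADE20K, COCO-Stuff, Mapillary, PACO-LVIS,
+    # PASCAL-Part): each sample pairs an image with up to num_classes_per_sample class
+    # names and their binary masks.
+
+    # SAM-style image normalization constants and padded square input size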
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch=500 * 8 * 2 * 10,
+ precision: str = "fp32",
+ image_size: int = 224,
+ num_classes_per_sample: int = 3,
+ exclude_val=False,
+ sem_seg_data="ade20k||cocostuff||partimagenet||pascal_part||paco_lvis||mapillary",
+ ):
+ self.exclude_val = exclude_val
+ self.samples_per_epoch = samples_per_epoch
+ self.num_classes_per_sample = num_classes_per_sample
+
+ self.base_image_dir = base_image_dir
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.precision = precision
+ self.transform = ResizeLongestSide(image_size)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+ self.short_question_list = SHORT_QUESTION_LIST
+ self.answer_list = ANSWER_LIST
+
+ self.data2list = {}
+ self.data2classes = {}
+
+ self.sem_seg_datas = sem_seg_data.split("||")
+ for ds in self.sem_seg_datas:
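+            # each name in sem_seg_data must have a matching init_<name>() in this module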
+ classes, images, labels = eval("init_{}".format(ds))(base_image_dir)
+ self.data2list[ds] = (images, labels)
+ self.data2classes[ds] = classes
+
+ if "cocostuff" in self.sem_seg_datas:
+ self.cocostuff_class2index = {
+ c: i for i, c in enumerate(self.data2classes["cocostuff"])
+ }
+
+ def __len__(self):
+ return self.samples_per_epoch
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.img_size - h
+ padw = self.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def __getitem__(self, idx):
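+        # idx is ignored: a sub-dataset is drawn at random; COCO-style part datasets
+        # (paco_lvis, pascal_part) are read through the COCO API, the others from
+        # (image, label-map) file pairs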
+ ds = random.randint(0, len(self.sem_seg_datas) - 1)
+ ds = self.sem_seg_datas[ds]
+
+ if ds in ["paco_lvis", "pascal_part"]:
+ class_map = self.data2classes[ds]
+ img_ids, coco_api = self.data2list[ds]
+ idx = random.randint(0, len(img_ids) - 1)
+ img_id = img_ids[idx]
+ image = coco_api.loadImgs([img_id])[0]
+ file_name = image["file_name"]
+ if ds == "pascal_part":
+ file_name = os.path.join(
+ "VOCdevkit", "VOC2010", "JPEGImages", file_name
+ )
+ image_path = os.path.join(self.base_image_dir, "vlpart", ds, file_name)
+ elif ds == "paco_lvis":
+ image_path = os.path.join(self.base_image_dir, "coco", file_name)
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # preprocess images for clip
+ images_clip = self.clip_image_processor.preprocess(
+ images, return_tensors="pt"
+ )["pixel_values"][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+
+ images = self.transform.apply_image(images) # preprocess images for sam
+ resize = images.shape[:2]
+ annIds = coco_api.getAnnIds(imgIds=image["id"])
+ anns = coco_api.loadAnns(annIds)
+ if len(anns) == 0:
+ return self.__getitem__(0)
+ if len(anns) >= self.num_classes_per_sample:
+ sampled_anns = np.random.choice(
+ anns, size=self.num_classes_per_sample, replace=False
+ ).tolist()
+ else:
+ sampled_anns = anns
+ sampled_classes = []
+ for ann in sampled_anns:
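+                # part categories are (object, part) tuples; phrase them randomly as
+                # "object part" or "the part of the object"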
+ sampled_cls = class_map[ann["category_id"]]
+ if isinstance(sampled_cls, tuple):
+ obj, part = sampled_cls
+ if random.random() < 0.5:
+ name = obj + " " + part
+ else:
+ name = "the {} of the {}".format(part, obj)
+ else:
+ name = sampled_cls
+ sampled_classes.append(name)
+
+ elif ds in ["ade20k", "cocostuff", "mapillary"]:
+ images, labels = self.data2list[ds]
+ idx = random.randint(0, len(images) - 1)
+ image_path = images[idx]
+ label_path = labels[idx]
+ label = Image.open(label_path)
+ label = np.array(label)
+ if ds == "ade20k":
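+                # ADE20K labels are 1..150 with 0 = unlabeled; shift them to 0..149
+                # and send unlabeled pixels to the ignore value 255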
+ label[label == 0] = 255
+ label -= 1
+ label[label == 254] = 255
+ elif ds == "cocostuff":
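+                # map COCO-Stuff classes whose name contains "-" to the ignore value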
+ for c, i in self.cocostuff_class2index.items():
+ if "-" in c:
+ label[label == i] = 255
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ # preprocess images for clip
+ images_clip = self.clip_image_processor.preprocess(
+ images, return_tensors="pt"
+ )["pixel_values"][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+ images = self.transform.apply_image(images) # preprocess images for sam
+ resize = images.shape[:2]
+ unique_label = np.unique(label).tolist()
+ if 255 in unique_label:
+ unique_label.remove(255)
+ if len(unique_label) == 0:
+ return self.__getitem__(0)
+
+ classes = [self.data2classes[ds][class_id] for class_id in unique_label]
+ if len(classes) >= self.num_classes_per_sample:
+ sampled_classes = np.random.choice(
+ classes, size=self.num_classes_per_sample, replace=False
+ ).tolist()
+ else:
+ sampled_classes = classes
+
+ questions = []
+ answers = []
+ class_ids = []
+ for sampled_cls in sampled_classes:
+ text = sampled_cls
+
+ assert len(text.split("||")) == 1
+ question_template = random.choice(self.short_question_list)
+ questions.append(question_template.format(class_name=text.lower()))
+
+ answers.append(random.choice(self.answer_list))
+
+ if ds in ["paco_lvis", "pascal_part"]:
+ continue
+
+ class_id = self.data2classes[ds].tolist().index(sampled_cls)
+ class_ids.append(class_id)
+
+ conversations = []
+ conv = get_default_conv_template("vicuna").copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+
+ i = 0
+ while i < len(questions):
+ conv.messages = []
+ conv.append_message(conv.roles[0], questions[i])
+ conv.append_message(conv.roles[1], answers[i])
+ conversations.append(conv.get_prompt())
+ i += 1
+
+ # replace token
+ for i in range(len(conversations)):
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ conversations[i] = conversations[i].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
+
+ images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
+ if ds in ["paco_lvis", "pascal_part"]:
+ masks = []
+ for ann in sampled_anns:
+ try:
+ masks.append(coco_api.annToMask(ann))
+ except Exception as e:
+ print(e)
+ return self.__getitem__(0)
+
+ masks = np.stack(masks, axis=0)
+ masks = torch.from_numpy(masks)
+ label = torch.ones(masks.shape[1], masks.shape[2]) * self.ignore_label
+
+ else:
+ label = torch.from_numpy(label).long()
+ masks = []
+ for class_id in class_ids:
+ masks.append(label == class_id)
+ masks = torch.stack(masks, dim=0)
+ return (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ label,
+ resize,
+ questions,
+ sampled_classes,
+ )
diff --git a/utils/utils.py b/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..275aa832daeb5136ef10d1774cdfaa9bd1ae5bae
--- /dev/null
+++ b/utils/utils.py
@@ -0,0 +1,156 @@
+from enum import Enum
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+DEFAULT_IMAGE_TOKEN = "<image>"
+DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
+DEFAULT_IM_START_TOKEN = "<im_start>"
+DEFAULT_IM_END_TOKEN = "<im_end>"
+
+SHORT_QUESTION_LIST = [
+ DEFAULT_IMAGE_TOKEN + " " + "Can you segment the {class_name} in this image?",
+ DEFAULT_IMAGE_TOKEN + " " + "Please segment the {class_name} in this image.",
+ DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? Please respond with segmentation mask.",
+ DEFAULT_IMAGE_TOKEN + " " + "What is {class_name} in this image? Please output segmentation mask.",
+]
+
+LONG_QUESTION_LIST = [
+ DEFAULT_IMAGE_TOKEN + " " + "{sent} Please respond with segmentation mask.",
+ DEFAULT_IMAGE_TOKEN + " " + "{sent} Please output segmentation mask.",
+]
+
+EXPLANATORY_QUESTION_LIST = [
+    "Please output segmentation mask and explain why.",
+    "Please output segmentation mask and explain the reason.",
+    "Please output segmentation mask and give some explanation.",
+]
+
+ANSWER_LIST = [
+ "It is [SEG].",
+ "Sure, [SEG].",
+ "Sure, it is [SEG].",
+ "Sure, the segmentation result is [SEG].",
+ "[SEG].",
+]
+
+
+class Summary(Enum):
+ NONE = 0
+ AVERAGE = 1
+ SUM = 2
+ COUNT = 3
+
+
+class AverageMeter(object):
+ """Computes and stores the average and current value"""
+
+ def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
+ self.name = name
+ self.fmt = fmt
+ self.summary_type = summary_type
+ self.reset()
+
+ def reset(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+ def all_reduce(self):
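+        # aggregate sum and count across all distributed workers so avg reflects global statistics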
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ if isinstance(self.sum, np.ndarray):
+ total = torch.tensor(
+ self.sum.tolist()
+ + [
+ self.count,
+ ],
+ dtype=torch.float32,
+ device=device,
+ )
+ else:
+ total = torch.tensor(
+ [self.sum, self.count], dtype=torch.float32, device=device
+ )
+
+ dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
+ if total.shape[0] > 2:
+ self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
+ else:
+ self.sum, self.count = total.tolist()
+ self.avg = self.sum / (self.count + 1e-5)
+
+ def __str__(self):
+ fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
+ return fmtstr.format(**self.__dict__)
+
+ def summary(self):
+ fmtstr = ""
+ if self.summary_type is Summary.NONE:
+ fmtstr = ""
+ elif self.summary_type is Summary.AVERAGE:
+ fmtstr = "{name} {avg:.3f}"
+ elif self.summary_type is Summary.SUM:
+ fmtstr = "{name} {sum:.3f}"
+ elif self.summary_type is Summary.COUNT:
+ fmtstr = "{name} {count:.3f}"
+ else:
+ raise ValueError("invalid summary type %r" % self.summary_type)
+
+ return fmtstr.format(**self.__dict__)
+
+
+def intersectionAndUnionGPU(output, target, K, ignore_index=255):
+ # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
+ assert output.dim() in [1, 2, 3]
+ assert output.shape == target.shape
+ output = output.view(-1)
+ target = target.view(-1)
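+    # pixels labeled ignore_index are set equal in both tensors and fall outside the
+    # [0, K-1] histogram range (assuming ignore_index >= K), so they count toward
+    # neither intersection nor union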
+ output[target == ignore_index] = ignore_index
+ intersection = output[output == target]
+ area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
+ area_output = torch.histc(output, bins=K, min=0, max=K - 1)
+ area_target = torch.histc(target, bins=K, min=0, max=K - 1)
+ area_union = area_output + area_target - area_intersection
+ return area_intersection, area_union, area_target
+
+
+class ProgressMeter(object):
+ def __init__(self, num_batches, meters, prefix=""):
+ self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+ self.meters = meters
+ self.prefix = prefix
+
+ def display(self, batch):
+ entries = [self.prefix + self.batch_fmtstr.format(batch)]
+ entries += [str(meter) for meter in self.meters]
+ print("\t".join(entries))
+
+ def display_summary(self):
+ entries = [" *"]
+ entries += [meter.summary() for meter in self.meters]
+ print(" ".join(entries))
+
+ def _get_batch_fmtstr(self, num_batches):
+ num_digits = len(str(num_batches // 1))
+ fmt = "{:" + str(num_digits) + "d}"
+ return "[" + fmt + "/" + fmt.format(num_batches) + "]"
+
+
+def dict_to_cuda(input_dict):
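+    # move tensors (and lists of tensors) in the batch dict to GPU; other values are left unchanged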
+ for k, v in input_dict.items():
+ if isinstance(input_dict[k], torch.Tensor):
+ input_dict[k] = v.cuda(non_blocking=True)
+ elif (
+ isinstance(input_dict[k], list)
+ and len(input_dict[k]) > 0
+ and isinstance(input_dict[k][0], torch.Tensor)
+ ):
+ input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
+ return input_dict
diff --git a/utils/vqa_dataset.py b/utils/vqa_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9aeb322a9f4fdbf5bd6659d133fcd000a737d48d
--- /dev/null
+++ b/utils/vqa_dataset.py
@@ -0,0 +1,126 @@
+import json
+import os
+import random
+
+import cv2
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor
+
+from model.segment_anything.utils.transforms import ResizeLongestSide
+
+from .conversation import get_default_conv_template
+from .utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
+ DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN)
+
+
+class VQADataset(torch.utils.data.Dataset):
+ pixel_mean = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+ pixel_std = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+ img_size = 1024
+ ignore_label = 255
+
+ def __init__(
+ self,
+ base_image_dir,
+ tokenizer,
+ vision_tower,
+ samples_per_epoch=500 * 8 * 2 * 10,
+ precision: str = "fp32",
+ image_size: int = 224,
+ num_classes_per_sample: int = 3,
+ exclude_val=False,
+ vqa_data="llava_instruct_150k",
+ ):
+ self.exclude_val = exclude_val
+ self.samples_per_epoch = samples_per_epoch
+ self.num_classes_per_sample = num_classes_per_sample
+
+ self.base_image_dir = base_image_dir
+ self.image_size = image_size
+ self.tokenizer = tokenizer
+ self.precision = precision
+ self.transform = ResizeLongestSide(image_size)
+ self.clip_image_processor = CLIPImageProcessor.from_pretrained(vision_tower)
+
+ DATA_DIR = os.path.join(base_image_dir, "llava_dataset")
+ self.vqa_image_root = os.path.join(base_image_dir, "coco/train2017")
+ with open(os.path.join(DATA_DIR, "{}.json".format(vqa_data))) as f:
+ vqa_data = json.load(f)
+ self.vqa_data = vqa_data
+
+ print("vqa_data: ", len(self.vqa_data))
+
+ def __len__(self):
+ return self.samples_per_epoch
+
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
+ """Normalize pixel values and pad to a square input."""
+ # Normalize colors
+ x = (x - self.pixel_mean) / self.pixel_std
+
+ # Pad
+ h, w = x.shape[-2:]
+ padh = self.img_size - h
+ padw = self.img_size - w
+ x = F.pad(x, (0, padw, 0, padh))
+ return x
+
+ def __getitem__(self, idx):
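+        # idx is ignored: a random LLaVA-Instruct conversation is sampled; VQA samples
+        # carry no segmentation targets, so masks/label below are empty placeholders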
+ idx = random.randint(0, len(self.vqa_data) - 1)
+ item = self.vqa_data[idx]
+ image_path = os.path.join(self.vqa_image_root, item["image"])
+ img = cv2.imread(image_path)
+ images = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ ori_size = images.shape[:2]
+        # preprocess images for clip
+        images_clip = self.clip_image_processor.preprocess(
+            images, return_tensors="pt"
+        )["pixel_values"][0]
+ image_token_len = (images_clip.shape[1] // 14) * (
+ images_clip.shape[2] // 14
+ ) # FIXME: 14 is hardcoded patch size
+
+ images = self.transform.apply_image(images) # preprocess images for sam
+ resize = images.shape[:2]
+ source = item["conversations"]
+ conv = get_default_conv_template(
+ "vicuna"
+ ).copy() # conversation_lib.default_conversation.copy()
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
+ conversations = []
+ if roles[source[0]["from"]] != conv.roles[0]:
+ # Skip the first one if it is not from human
+ source = source[1:]
+ conv.messages = []
+ for j, sentence in enumerate(source):
+ role = roles[sentence["from"]]
+            assert role == conv.roles[j % 2], f"unexpected role order at turn {j}"
+ conv.append_message(role, sentence["value"])
+ conversations.append(conv.get_prompt())
+
+ questions = conversations
+ sampled_classes = conversations
+
+ # replace token
+ for i in range(len(conversations)):
+ replace_token = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
+ replace_token = (
+ DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
+ )
+ conversations[i] = conversations[i].replace(
+ DEFAULT_IMAGE_TOKEN, replace_token
+ )
+
+ images = self.preprocess(torch.from_numpy(images).permute(2, 0, 1).contiguous())
+
+ masks = torch.rand(0, *ori_size)
+ label = torch.ones(ori_size) * self.ignore_label
+
+ return (
+ image_path,
+ images,
+ images_clip,
+ conversations,
+ masks,
+ label,
+ resize,
+ questions,
+ sampled_classes,
+ )