Jegree committed on
Commit
97f3677
·
verified ·
1 Parent(s): 306aed7

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +16 -20
models.py CHANGED
@@ -11,38 +11,34 @@ import gradio_helpers
11
  import paligemma_bv
12
 
13
 
14
- ORGANIZATION = 'google'
15
  BASE_MODELS = [
16
- ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
17
- ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
 
18
  ]
19
  MODELS = {
20
- **{
21
- model_name: (
22
- f'{ORGANIZATION}/{repo}',
23
- f'{model_name}.bf16.npz',
24
- 'bfloat16', # Model repo revision.
25
- )
26
- for repo, model_name in BASE_MODELS
27
- },
 
28
  }
29
 
30
  MODELS_INFO = {
31
- 'paligemma-3b-mix-224': (
32
  'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
33
  'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
34
  'bfloat16 and float16 format for research purposes only.'
35
  ),
36
- 'paligemma-3b-mix-448': (
37
- 'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
38
- 'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
39
- 'bfloat16 and float16 format for research purposes only.'
40
- ),
41
  }
42
 
43
  MODELS_RES_SEQ = {
44
- 'paligemma-3b-mix-224': (224, 256),
45
- 'paligemma-3b-mix-448': (448, 512),
46
  }
47
 
48
  # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
@@ -53,7 +49,7 @@ MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
53
  config = paligemma_bv.PaligemmaConfig(
54
  ckpt='', # will be set below
55
  res=224,
56
- text_len=64,
57
  tokenizer='gemma(tokensets=("loc", "seg"))',
58
  vocab_size=256_000 + 1024 + 128,
59
  )
 
11
  import paligemma_bv
12
 
13
 
14
+ ORGANIZATION = 'Jegree'
15
  BASE_MODELS = [
16
+ # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
17
+ # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
18
+ ('myPaligem', 'fine-tuned-paligemma-3b-pt-224')
19
  ]
20
  MODELS = {
21
+ # **{
22
+ # model_name: (
23
+ # f'{ORGANIZATION}/{repo}',
24
+ # f'{model_name}.bf16.npz',
25
+ # 'bfloat16', # Model repo revision.
26
+ # )
27
+ # for repo, model_name in BASE_MODELS
28
+ # },
29
+ 'fine-tuned-paligemma-3b-pt-224':('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
30
  }
31
 
32
  MODELS_INFO = {
33
+ 'fine-tuned-paligemma-3b-pt-224': (
34
  'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
35
  'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
36
  'bfloat16 and float16 format for research purposes only.'
37
  ),
 
 
 
 
 
38
  }
39
 
40
  MODELS_RES_SEQ = {
41
+ 'fine-tuned-paligemma-3b-pt-224': (224, 128),
 
42
  }
43
 
44
  # "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
 
49
  config = paligemma_bv.PaligemmaConfig(
50
  ckpt='', # will be set below
51
  res=224,
52
+ text_len=128,
53
  tokenizer='gemma(tokensets=("loc", "seg"))',
54
  vocab_size=256_000 + 1024 + 128,
55
  )