Jegree committed
Commit 307292d · verified · 1 Parent(s): 2d491f1

Update models.py

Files changed (1)
  1. models.py +83 -83
models.py CHANGED
@@ -1,83 +1,83 @@
-"""Model-related code and constants."""
-
-import dataclasses
-import os
-import re
-
-import PIL.Image
-
-# pylint: disable=g-bad-import-order
-import gradio_helpers
-import paligemma_bv
-
-
-ORGANIZATION = 'Jegree'
-BASE_MODELS = [
-    # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
-    # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
-    ('myPaligem', 'fine-tuned-paligemma-3b-pt-224')
-]
-MODELS = {
-    # **{
-    #     model_name: (
-    #         f'{ORGANIZATION}/{repo}',
-    #         f'{model_name}.bf16.npz',
-    #         'bfloat16', # Model repo revision.
-    #     )
-    #     for repo, model_name in BASE_MODELS
-    # },
-    'fine-tuned-paligemma-3b-pt-224':('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
-}
-
-MODELS_INFO = {
-    'fine-tuned-paligemma-3b-pt-224': (
-        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
-        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
-        'bfloat16 and float16 format for research purposes only.'
-    ),
-}
-
-MODELS_RES_SEQ = {
-    'fine-tuned-paligemma-3b-pt-224': (224, 128),
-}
-
-# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
-# Below value should be smaller than "available RAM - one model".
-# A single bf16 is about 5860 MB.
-MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
-
-config = paligemma_bv.PaligemmaConfig(
-    ckpt='',  # will be set below
-    res=224,
-    text_len=64,
-    tokenizer='gemma(tokensets=("loc", "seg"))',
-    vocab_size=256_000 + 1024 + 128,
-)
-
-
-def get_cached_model(
-    model_name: str,
-) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
-  """Returns model and params, using RAM cache."""
-  res, seq = MODELS_RES_SEQ[model_name]
-  model_path = gradio_helpers.get_paths()[model_name]
-  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
-  model, params_cpu = gradio_helpers.get_memory_cache(
-      config_,
-      lambda: paligemma_bv.load_model(config_),
-      max_cache_size_bytes=MAX_RAM_CACHE,
-  )
-  return model, params_cpu
-
-
-def generate(
-    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
-) -> str:
-  """Generates output with specified `model_name`, `sampler`."""
-  model, params_cpu = get_cached_model(model_name)
-  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
-  with gradio_helpers.timed('sharding'):
-    params = model.shard_params(params_cpu)
-  with gradio_helpers.timed('computation', start_message=True):
-    tokens = model.predict(params, batch, sampler=sampler)
-  return model.tokenizer.to_str(tokens[0])
+"""Model-related code and constants."""
+
+import dataclasses
+import os
+import re
+
+import PIL.Image
+
+# pylint: disable=g-bad-import-order
+import gradio_helpers
+import paligemma_bv
+
+
+ORGANIZATION = 'Jegree'
+BASE_MODELS = [
+    # ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
+    # ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
+    ('myPaligem', 'fine-tuned-paligemma-3b-pt-224')
+]
+MODELS = {
+    # **{
+    #     model_name: (
+    #         f'{ORGANIZATION}/{repo}',
+    #         f'{model_name}.bf16.npz',
+    #         'bfloat16', # Model repo revision.
+    #     )
+    #     for repo, model_name in BASE_MODELS
+    # },
+    'fine-tuned-paligemma-3b-pt-224':('Jegree/myPaligem', 'fine-tuned-paligemma-3b-pt-224.f16.npz', 'main'),
+}
+
+MODELS_INFO = {
+    'fine-tuned-paligemma-3b-pt-224': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
+}
+
+MODELS_RES_SEQ = {
+    'fine-tuned-paligemma-3b-pt-224': (224, 128),
+}
+
+# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
+# Below value should be smaller than "available RAM - one model".
+# A single bf16 is about 5860 MB.
+MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
+
+config = paligemma_bv.PaligemmaConfig(
+    ckpt='',  # will be set below
+    res=224,
+    text_len=128,
+    tokenizer='gemma(tokensets=("loc", "seg"))',
+    vocab_size=256_000 + 1024 + 128,
+)
+
+
+def get_cached_model(
+    model_name: str,
+) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
+  """Returns model and params, using RAM cache."""
+  res, seq = MODELS_RES_SEQ[model_name]
+  model_path = gradio_helpers.get_paths()[model_name]
+  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
+  model, params_cpu = gradio_helpers.get_memory_cache(
+      config_,
+      lambda: paligemma_bv.load_model(config_),
+      max_cache_size_bytes=MAX_RAM_CACHE,
+  )
+  return model, params_cpu
+
+
+def generate(
+    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
+) -> str:
+  """Generates output with specified `model_name`, `sampler`."""
+  model, params_cpu = get_cached_model(model_name)
+  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
+  with gradio_helpers.timed('sharding'):
+    params = model.shard_params(params_cpu)
+  with gradio_helpers.timed('computation', start_message=True):
+    tokens = model.predict(params, batch, sampler=sampler)
+  return model.tokenizer.to_str(tokens[0])
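
For context, this module is normally driven by the Space's Gradio app, which downloads the checkpoint listed in MODELS, registers its local path so gradio_helpers.get_paths() can resolve it, and then calls generate(). The sketch below shows one plausible way to call it directly; it is not part of this commit, and the sampler name 'greedy', the prompt string 'caption en', the image filename, and the RAM_CACHE_GB value are illustrative assumptions.

# Hypothetical driver sketch (not part of this commit). Assumes the checkpoint
# from MODELS has already been downloaded and registered so that
# gradio_helpers.get_paths()['fine-tuned-paligemma-3b-pt-224'] points at a
# local .npz file, and that 'greedy' is a sampler accepted by model.predict().
import os

# MAX_RAM_CACHE is computed from RAM_CACHE_GB at import time, so set the
# variable before importing the module above.
os.environ.setdefault('RAM_CACHE_GB', '12')

import PIL.Image

import models

image = PIL.Image.open('example.jpg')
caption = models.generate(
    model_name='fine-tuned-paligemma-3b-pt-224',
    sampler='greedy',
    image=image,
    prompt='caption en',
)
print(caption)

Note that get_cached_model() overrides res and text_len with the per-model entry in MODELS_RES_SEQ, here (224, 128), so the module-level config value of text_len only matters when the config object is used directly.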