initial commit
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete changeset.
- .gitignore +174 -0
- .vscode/launch.json +15 -0
- LICENSE +201 -0
- README.md +266 -11
- app.py +0 -7
- benchmark/__init__.py +0 -0
- benchmark/create_benchmark.py +352 -0
- benchmark/llm.py +42 -0
- benchmark/utils.py +78 -0
- data/eurorad_metadata.json +0 -0
- data/figures.py +74 -0
- data/get_cases.py +51 -0
- experiments/README.md +63 -0
- experiments/analyze_axes.py +385 -0
- experiments/benchmark_chexagent.py +316 -0
- experiments/benchmark_gpt4o.py +331 -0
- experiments/benchmark_llama.py +443 -0
- experiments/benchmark_llavamed.py +541 -0
- experiments/benchmark_medrax.ipynb +374 -0
- experiments/chexbench_gpt4.py +405 -0
- experiments/compare_runs.py +290 -0
- experiments/inspect_logs.py +210 -0
- experiments/validate_logs.py +162 -0
- interface.py +279 -0
- main.py +141 -0
- medrax/__init__.py +0 -0
- medrax/agent/__init__.py +1 -0
- medrax/agent/agent.py +193 -0
- medrax/docs/system_prompts.txt +9 -0
- medrax/llava/__init__.py +0 -0
- medrax/llava/constants.py +13 -0
- medrax/llava/conversation.py +448 -0
- medrax/llava/eval/eval_multimodal_chat_gpt_score.py +143 -0
- medrax/llava/eval/llm.py +154 -0
- medrax/llava/eval/model_vqa.py +133 -0
- medrax/llava/eval/summarize_gpt_review.py +62 -0
- medrax/llava/eval/util.py +10 -0
- medrax/llava/mm_utils.py +121 -0
- medrax/llava/model/__init__.py +1 -0
- medrax/llava/model/builder.py +134 -0
- medrax/llava/model/language_model/llava_mistral.py +144 -0
- medrax/llava/model/llava_arch.py +396 -0
- medrax/llava/model/multimodal_encoder/builder.py +15 -0
- medrax/llava/model/multimodal_encoder/clip_encoder.py +83 -0
- medrax/llava/model/multimodal_projector/builder.py +49 -0
- medrax/llava/serve/__init__.py +0 -0
- medrax/llava/serve/cli.py +152 -0
- medrax/llava/serve/controller.py +299 -0
- medrax/llava/serve/gradio_web_server.py +532 -0
- medrax/llava/serve/model_worker.py +337 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# ruff
ruff-cache/
.ruff_cache/

afallah/

logs/

temp/

.gradio/
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
{
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python Debugger: main.py",
            "type": "debugpy",
            "request": "launch",
            "program": "main.py",
            "console": "integratedTerminal"
        }
    ]
}
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,12 +1,267 @@
<h1 align="center">
🤖 MedRAX: Medical Reasoning Agent for Chest X-ray
</h1>
<p align="center"> <a href="https://arxiv.org/abs/2502.02673" target="_blank"><img src="https://img.shields.io/badge/arXiv-Paper-FF6B6B?style=for-the-badge&logo=arxiv&logoColor=white" alt="arXiv"></a> <a href="https://github.com/bowang-lab/MedRAX"><img src="https://img.shields.io/badge/GitHub-Code-4A90E2?style=for-the-badge&logo=github&logoColor=white" alt="GitHub"></a> <a href="https://huggingface.co/datasets/wanglab/chest-agent-bench"><img src="https://img.shields.io/badge/HuggingFace-Dataset-FFBF00?style=for-the-badge&logo=huggingface&logoColor=white" alt="HuggingFace Dataset"></a> </p>



<br>

## Abstract
Chest X-rays (CXRs) play an integral role in driving critical decisions in disease management and patient care. While recent innovations have led to specialized models for various CXR interpretation tasks, these solutions often operate in isolation, limiting their practical utility in clinical practice. We present MedRAX, the first versatile AI agent that seamlessly integrates state-of-the-art CXR analysis tools and multimodal large language models into a unified framework. MedRAX dynamically leverages these models to address complex medical queries without requiring additional training. To rigorously evaluate its capabilities, we introduce ChestAgentBench, a comprehensive benchmark containing 2,500 complex medical queries across 7 diverse categories. Our experiments demonstrate that MedRAX achieves state-of-the-art performance compared to both open-source and proprietary models, representing a significant step toward the practical deployment of automated CXR interpretation systems.
<br><br>


## MedRAX
MedRAX is built on a robust technical foundation:
- **Core Architecture**: Built on the LangChain and LangGraph frameworks
- **Language Model**: Uses GPT-4o with vision capabilities as the backbone LLM
- **Deployment**: Supports both local and cloud-based deployments
- **Interface**: Production-ready interface built with Gradio
- **Modular Design**: Tool-agnostic architecture allowing easy integration of new capabilities

### Integrated Tools
- **Visual QA**: Utilizes CheXagent and LLaVA-Med for complex visual understanding and medical reasoning
- **Segmentation**: Employs MedSAM and a PSPNet model trained on ChestX-Det for precise anatomical structure identification
- **Grounding**: Uses Maira-2 for localizing specific findings in medical images
- **Report Generation**: Implements a SwinV2 Transformer trained on CheXpert Plus for detailed medical reporting
- **Disease Classification**: Leverages DenseNet-121 from TorchXRayVision for detecting 18 pathology classes
- **X-ray Generation**: Utilizes RoentGen for synthetic CXR generation
- **Utilities**: Includes DICOM processing, visualization tools, and custom plotting capabilities

Note that the current version of MedRAX is an experimental release and does not yet support vision for GPT-4o or MedSAM. We will be integrating these shortly.
<br><br>


## ChestAgentBench
We introduce ChestAgentBench, a comprehensive evaluation framework with 2,500 complex medical queries across 7 categories, built from 675 expert-curated clinical cases. The benchmark evaluates complex multi-step reasoning in CXR interpretation through:

- Detection
- Classification
- Localization
- Comparison
- Relationship
- Diagnosis
- Characterization

Download the benchmark: [ChestAgentBench on Hugging Face](https://huggingface.co/datasets/wanglab/chest-agent-bench)
```
huggingface-cli download wanglab/chestagentbench --repo-type dataset --local-dir chestagentbench
```

Unzip the Eurorad figures to your local `MedRAX` directory.
```
unzip chestagentbench/figures.zip
```
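If you prefer to script these two steps, here is a minimal Python sketch of the same workflow (illustrative only; it assumes the `huggingface_hub` package is installed):
```python
# Illustrative alternative to the CLI commands above.
import zipfile
from huggingface_hub import snapshot_download

# Download the ChestAgentBench dataset into ./chestagentbench
snapshot_download(
    repo_id="wanglab/chestagentbench",
    repo_type="dataset",
    local_dir="chestagentbench",
)

# Extract the Eurorad figures next to the dataset
with zipfile.ZipFile("chestagentbench/figures.zip") as archive:
    archive.extractall(".")
```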

To evaluate with GPT-4o, set your OpenAI API key and run the quickstart script.
```
export OPENAI_API_KEY="<your-openai-api-key>"
python quickstart.py \
    --model chatgpt-4o-latest \
    --temperature 0.2 \
    --max-cases 2 \
    --log-prefix chatgpt-4o-latest \
    --use-urls
```


<br>

## Installation
### Prerequisites
- Python 3.8+
- CUDA/GPU for best performance

### Installation Steps
```bash
# Clone the repository
git clone https://github.com/bowang-lab/MedRAX.git
cd MedRAX

# Install package
pip install -e .
```

### Getting Started
```bash
# Start the Gradio interface
python main.py
```
or, if you run into permission issues:
```bash
sudo -E env "PATH=$PATH" python main.py
```
You need to set `model_dir` inside `main.py` to the directory where you want to download the weights of the above tools from Hugging Face, or where you already have them.
Comment out the tools that you do not have access to.
Make sure to set up your OpenAI API key in the `.env` file (a minimal example is shown below)!
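A minimal `.env` sketch, assuming the standard `OPENAI_API_KEY` variable is all your setup requires:
```
OPENAI_API_KEY=<your-openai-api-key>
```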
<br><br><br>


## Tool Selection and Initialization

MedRAX supports selective tool initialization, allowing you to use only the tools you need. Tools can be specified when initializing the agent (see `main.py`):

```python
selected_tools = [
    "ImageVisualizerTool",
    "ChestXRayClassifierTool",
    "ChestXRaySegmentationTool",
    # Add or remove tools as needed
]

agent, tools_dict = initialize_agent(
    "medrax/docs/system_prompts.txt",
    tools_to_use=selected_tools,
    model_dir="/model-weights"
)
```

<br><br>
## Automatically Downloaded Models

The following tools will automatically download their model weights when initialized:

### Classification Tool
```python
ChestXRayClassifierTool(device=device)
```

### Segmentation Tool
```python
ChestXRaySegmentationTool(device=device)
```

### Grounding Tool
```python
XRayPhraseGroundingTool(
    cache_dir=model_dir,
    temp_dir=temp_dir,
    load_in_8bit=True,
    device=device
)
```
- Maira-2 weights download to the specified `cache_dir`
- 8-bit and 4-bit quantization available for reduced memory usage

### LLaVA-Med Tool
```python
LlavaMedTool(
    cache_dir=model_dir,
    device=device,
    load_in_8bit=True
)
```
- Automatic weight download to `cache_dir`
- 8-bit and 4-bit quantization available for reduced memory usage

### Report Generation Tool
```python
ChestXRayReportGeneratorTool(
    cache_dir=model_dir,
    device=device
)
```

### Visual QA Tool
```python
XRayVQATool(
    cache_dir=model_dir,
    device=device
)
```
- CheXagent weights download automatically

### MedSAM Tool
```
Support for MedSAM segmentation will be added in a future update.
```

### Utility Tools
No additional model weights required:
```python
ImageVisualizerTool()
DicomProcessorTool(temp_dir=temp_dir)
```
<br>

## Manual Setup Required

### Image Generation Tool
```python
ChestXRayGeneratorTool(
    model_path=f"{model_dir}/roentgen",
    temp_dir=temp_dir,
    device=device
)
```
- RoentGen weights require manual setup:
  1. Contact the authors: https://github.com/StanfordMIMI/RoentGen
  2. Place the weights in `{model_dir}/roentgen`
  3. Optional tool; it can be excluded if not needed
<br>

## Configuration Notes

### Required Parameters
- `model_dir` or `cache_dir`: Base directory that Hugging Face uses for model weights
- `temp_dir`: Directory for temporary files
- `device`: "cuda" for GPU, "cpu" for CPU-only

### Memory Management
- Consider selective tool initialization under resource constraints
- Use 8-bit quantization where available
- Some tools (LLaVA-Med, Grounding) are more resource-intensive
<br>

### Local LLMs
If you are running a local LLM using frameworks like [Ollama](https://ollama.com/) or [LM Studio](https://lmstudio.ai/), you need to configure your environment variables accordingly. For example:
```
export OPENAI_BASE_URL="http://localhost:11434/v1"
export OPENAI_API_KEY="ollama"
```
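The same endpoint settings can also be passed to the client explicitly. This is an illustrative sketch only; the local model name is hypothetical and local servers typically ignore the API key:
```python
# Illustrative: explicit client configuration equivalent to the exported variables above.
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:11434/v1",  # local Ollama/LM Studio endpoint
    api_key="ollama",                      # placeholder key for local servers
)
response = client.chat.completions.create(
    model="llama3.2-vision",  # hypothetical local model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```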
<br>

## Star History
<div align="center">

[](https://star-history.com/#bowang-lab/MedRAX&Date)

</div>
<br>


## Authors
- **Adibvafa Fallahpour**¹²³ * ([email protected])
- **Jun Ma**²³ *
- **Alif Munim**³⁴ *
- **Hongwei Lyu**³
- **Bo Wang**¹²³⁵

¹ Department of Computer Science, University of Toronto, Toronto, Canada
² Vector Institute, Toronto, Canada
³ University Health Network, Toronto, Canada
⁴ Cohere For AI, Toronto, Canada
⁵ Department of Laboratory Medicine and Pathobiology, University of Toronto, Toronto, Canada <br>
\* Equal contribution
<br><br>


## Citation
If you find this work useful, please cite our paper:
```bibtex
@misc{fallahpour2025medraxmedicalreasoningagent,
      title={MedRAX: Medical Reasoning Agent for Chest X-ray},
      author={Adibvafa Fallahpour and Jun Ma and Alif Munim and Hongwei Lyu and Bo Wang},
      year={2025},
      eprint={2502.02673},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2502.02673},
}
```

---
<p align="center">
Made with ❤️ at University of Toronto, Vector Institute, and University Health Network
</p>
app.py
DELETED
@@ -1,7 +0,0 @@
import gradio as gr

def greet(name):
    return "Hello " + name + "!!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
benchmark/__init__.py
ADDED
File without changes
benchmark/create_benchmark.py
ADDED
@@ -0,0 +1,352 @@
#!/usr/bin/env python3
"""
Medical X-ray Question Generation Benchmark aka ChestAgentBench

This script generates clinical questions from X-ray case data of the Eurorad dataset using GPT-4o.
It structures questions across different analytical categories and saves them as JSON.
"""

import os
import re
import json
from typing import *
from pprint import pprint

import openai
import numpy as np
from scipy import stats
import plotly.graph_objects as go
from tqdm import tqdm

from benchmark.utils import load_eurorad_dataset
from benchmark.llm import get_llm_response

# Constants
DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data"
DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json")

SYSTEM_PROMPT = """
You are an expert medical benchmark creation assistant.
Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays.
""".strip()

CATEGORIES_META = {
    "detection": "Identify and locate specific findings in the chest X-ray.",
    "classification": "Determine whether specific findings are present or absent in the chest X-ray.",
    "enumeration": "Count the number of target findings in the chest X-ray.",
    "localization": "Locate a given finding in the chest X-ray.",
    "comparison": "Compare the size or position of a specific finding in the chest X-ray.",
    "relationship": "Determine the relationship between two or more findings in the chest X-ray.",
    "diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.",
    "characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.",
    "reasoning": "Explain the medical rationale and thought process behind findings and conclusions.",
}
CATEGORIES = list(CATEGORIES_META.keys())

CATEGORY_COMBINATIONS = [
    ["detection", "localization", "characterization", "reasoning"],  # Detailed Finding Analysis
    ["detection", "classification", "relationship", "reasoning"],  # Pattern Recognition & Relations
    ["localization", "comparison", "relationship", "reasoning"],  # Spatial Understanding
    ["classification", "comparison", "diagnosis", "reasoning"],  # Clinical Decision Making
    ["classification", "characterization", "diagnosis", "reasoning"],  # Diagnostic Characterization
]

DEFAULT_SECTIONS = [
    "history",
    "image_finding",
    "discussion",
    "differential_diagnosis",
    "diagnosis",
    "figures",
]


class Question:
    """A class to generate clinical questions from case data.

    This class handles creating structured clinical questions by combining case data with
    specified categories and difficulty levels.

    Attributes:
        type (str): The type of question (e.g. multiple choice)
        difficulty (str): Difficulty level of the question
        case_data (Dict[str, Any]): Dictionary containing the clinical case data
        case_content (str): Formatted case data from selected sections
        case_id (str): Unique identifier for the case
        categories (List[str]): List of analytical categories this question tests
        sections (List[str]): Case sections to include in question
        raw_content (Optional[str]): Raw LLM response to the question prompt
        content (Optional[Dict[str, str]]): Extracted content from the raw LLM response
    """

    def __init__(
        self,
        type: str,
        difficulty: str,
        case_data: Dict[str, Any],
        categories: List[str],
        sections: List[str] = [
            "history",
            "image_finding",
            "discussion",
            "differential_diagnosis",
            "diagnosis",
            "figures",
        ],
        system_prompt: str = "You are an expert medical benchmark creation assistant.",
    ) -> None:
        self.type = type
        self.difficulty = difficulty
        self.case_data = case_data
        self.case_id = case_data["case_id"]
        self.categories = categories
        self.sections = sections
        self.system_prompt = system_prompt
        self.case_content = self.select_case_sections()
        self.raw_content: Optional[str] = None
        self.content: Optional[Dict[str, str]] = None

    def create_question_prompt(self) -> str:
        """Creates a formatted prompt for generating a clinical question.

        Returns:
            str: A structured prompt containing the question parameters and clinical data
        """
        category_descriptions = "\n".join(
            f"{category}: {desc}"
            for category, desc in CATEGORIES_META.items()
            if category in self.categories
        )

        return f"""
You must follow these guidelines:
1. Questions must be answerable using only context and chest X-rays.
    - Questions must explicitly mention the referenced figures
    - Questions can only reference the chest X-ray figures

2. Questions must have unambiguous, verifiable answers, and should:
    - Challenge the agent's analytical capabilities
    - Require multi-step reasoning
    - Test ability to make precise observations
    - Evaluate capability to derive insights and findings from the chest X-ray

3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex to require the use of such tools.


Create a {self.difficulty} {self.type} clinical question that integrates the following:

{category_descriptions}

based on the following clinical case:

{self.case_content}

Do not use any information derived from the CT and MRI images. Do not provide any information or findings about the chest X-rays.
Your question should require the agent to derive insights and findings from the chest X-ray by itself.
Your answer should be verifiable directly in the context of the case.
You can only use the image findings that come from the chest X-ray figures.

Your response must follow this exact format:
THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question]
QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.]
FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]]
EXPLANATION: [short explanation of why your answer is verifiable in the case]
ANSWER: [correct answer e.g. "A"]
""".strip().replace(
            "    ", ""
        )  # remove tab-width indentation left over from the template

    def select_case_sections(self) -> str:
        """Extract and format selected sections from case data into paragraphs.

        Returns:
            str: Formatted string with case sections and content
        """
        section_mapping = {
            "history": ("history", "No history provided."),
            "image_finding": ("image_finding", "No findings provided."),
            "discussion": ("discussion", "No discussion provided."),
            "differential_diagnosis": (
                "differential_diagnosis",
                "No differential diagnosis provided.",
            ),
            "diagnosis": ("diagnosis", "No diagnosis provided."),
            "figures": ("figures", "No figures provided."),
        }

        formatted = []
        for section in self.sections:
            if section in section_mapping:
                key, default = section_mapping[section]
                content = self.case_data.get(key, default)

                if key == "figures":
                    figures_text = []
                    for figure in content:
                        for subfig in figure["subfigures"]:
                            figures_text.append(f"{subfig['number']}: {subfig['caption']}")
                    content = "\n".join(figures_text)

                formatted.append(f"{section}:\n{content}")

        return "\n\n".join(formatted)

    def create_question(
        self,
        client: openai.OpenAI,
        temperature: float = 0.7,
        top_p: float = 0.95,
        max_tokens: int = 500,
        model: str = "gpt-4o",
    ) -> str:
        """Create a clinical question using LLM.

        Args:
            client (openai.OpenAI): OpenAI client instance
            temperature (float): Controls randomness in responses. Defaults to 0.7.
            top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95.
            max_tokens (int): Max tokens in model response. Defaults to 500.
            model (str): OpenAI model to use. Defaults to "gpt-4o".

        Returns:
            str: LLM response containing formatted question components
        """
        self.raw_content = get_llm_response(
            client=client,
            prompt=self.create_question_prompt(),
            system_prompt=self.system_prompt,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            model=model,
        )
        self.content = self.extract_content()

        return self.raw_content

    def extract_content(self) -> Dict[str, str]:
        """Extract sections from raw LLM response using regex patterns.

        Returns:
            Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer
        """
        keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"]

        content = {}
        for kw in keywords:
            pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)"
            match = re.search(pattern, self.raw_content, re.DOTALL)
            content[kw.lower()] = match.group(1).strip() if match else None

        return content

    def save(self, output_path: str) -> Dict[str, Any]:
        """Save question content and metadata as a JSON file.

        Args:
            output_path (str): Directory path where the JSON file will be saved

        Returns:
            Dict[str, Any]: Question data including content (thoughts, question, figures, options,
                explanation, answer) and metadata (type, difficulty, categories, etc.)
        """
        question_metadata = self.content.copy()

        # Add metadata
        question_metadata["metadata"] = {
            "case_id": self.case_id,
            "type": self.type,
            "difficulty": self.difficulty,
            "categories": self.categories,
            "sections": self.sections,
        }

        # Create a directory for the case
        case_dir = os.path.join(output_path, str(self.case_id))
        os.makedirs(case_dir, exist_ok=True)

        # Save the question metadata to a JSON file
        output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json")
        with open(output_file, "w") as f:
            json.dump(question_metadata, f, indent=2)

        return question_metadata


def generate_questions(
    dataset: Dict[str, Any],
    client: openai.OpenAI,
    output_dir: str,
    skip_first: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 1200,
    model: str = "gpt-4o",
) -> None:
    """Generate questions for each case and category combination.

    Args:
        dataset: Dictionary of case data
        client: OpenAI client instance
        output_dir: Directory to save generated questions
        skip_first: Number of initial cases to skip
        temperature: LLM temperature parameter
        top_p: LLM top_p parameter
        max_tokens: Maximum tokens for LLM response
        model: LLM model name
    """
    target_cases = sorted(list(dataset.keys()), key=int)[-len(dataset) : -skip_first]

    for case_id in tqdm(target_cases, desc="Processing cases"):
        case_data = dataset[case_id]

        for category in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
            question = Question(
                type="multiple choice (A/B/C/D/E/F)",
                difficulty="complex",
                case_data=case_data,
                categories=category,
                sections=DEFAULT_SECTIONS,
                system_prompt=SYSTEM_PROMPT,
            )

            response = question.create_question(
                client=client,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens,
                model=model,
            )
            question.save(output_dir)


def main():
    """Main execution function."""
    client = openai.OpenAI()

    # Load and verify dataset
    dataset = load_eurorad_dataset(
        DATASET_PATH,
        section="Chest Imaging",
        as_dict=True,
        filter_by_caption=[
            "xray",
            "x-ray",
            "x ray",
            "ray",
            "xr",
            "radiograph",
        ],
    )
    print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n")

    # Optional: Print sample case for verification
    case_data = dataset["16798"]
    pprint(case_data, sort_dicts=False)

    # Generate questions
    generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions")


if __name__ == "__main__":
    main()
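For reference, each file written by `Question.save` above has roughly the following structure; this is an illustrative sketch with placeholder values, not part of the commit:
```python
# Illustrative shape of one saved question JSON (placeholder values).
example_question = {
    "thoughts": "...",
    "question": "...",
    "figures": '["Figure 1", "Figure 2a"]',
    "explanation": "...",
    "answer": "A",
    "metadata": {
        "case_id": "16798",
        "type": "multiple choice (A/B/C/D/E/F)",
        "difficulty": "complex",
        "categories": ["detection", "localization", "characterization", "reasoning"],
        "sections": ["history", "image_finding", "discussion", "differential_diagnosis", "diagnosis", "figures"],
    },
}
```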
benchmark/llm.py
ADDED
@@ -0,0 +1,42 @@
import openai
from typing import List


def get_llm_response(
    client: openai.OpenAI,
    prompt: str,
    system_prompt: str = "You are a helpful assistant.",
    model: str = "gpt-4o-mini",
    temperature: float = 0.7,
    top_p: float = 0.95,
    max_tokens: int = 500,
) -> str:
    """
    Get response from OpenAI language model.

    Args:
        client (openai.OpenAI): OpenAI client
        prompt (str): The user prompt/question to send to the model
        system_prompt (str, optional): System prompt to set model behavior.
        model (str, optional): OpenAI model to use. Defaults to "gpt-4o-mini".
        temperature (float, optional): Controls randomness in responses. Defaults to 0.7.
        top_p (float, optional): Controls diversity via nucleus sampling. Defaults to 0.95.
        max_tokens (int, optional): Max tokens in model response. Defaults to 500.

    Returns:
        str: The model's response text
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]

    response = client.chat.completions.create(
        model=model,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
    )

    return response.choices[0].message.content
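A minimal usage sketch for the helper above (illustrative, not part of this commit; it assumes `OPENAI_API_KEY` is set in the environment):
```python
import openai
from benchmark.llm import get_llm_response

client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment
reply = get_llm_response(
    client=client,
    prompt="List three common causes of a unilateral pleural effusion.",
    system_prompt="You are an expert medical benchmark creation assistant.",
    model="gpt-4o-mini",
    max_tokens=200,
)
print(reply)
```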
benchmark/utils.py
ADDED
@@ -0,0 +1,78 @@
import os
import json
from typing import Dict, List


def load_eurorad_dataset(
    dataset_path: str,
    section: str = "any",
    as_dict: bool = False,
    filter_by_caption: List[str] = [
        "xray",
        "x-ray",
        "x ray",
        "ray",
        "xr",
        "radiograph",
        "radiogram",
        "plain film",
    ],
) -> List[Dict] | Dict[str, Dict]:
    """
    Load a dataset from a JSON file.

    Args:
        dataset_path (str): Path to the JSON dataset file.
        section (str, optional): Section of the dataset to load. Defaults to "any".
        as_dict (bool, optional): Whether to return data as dict. Defaults to False.
        filter_by_caption (List[str], optional): List of strings to filter cases by caption content.

    Returns:
        List[Dict] | Dict[str, Dict]: The loaded dataset as a list of dictionaries or dict if as_dict=True.

    Raises:
        FileNotFoundError: If dataset_path does not exist
        json.JSONDecodeError: If file is not valid JSON
    """

    with open(dataset_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    if filter_by_caption:
        filtered_data = {}
        for case_id, case in data.items():
            if any(
                any(x in subfig["caption"].lower() for x in filter_by_caption)
                for figure in case["figures"]
                for subfig in figure["subfigures"]
            ) or any(x in case["image_finding"].lower() for x in filter_by_caption):
                filtered_data[case_id] = case
        data = filtered_data

    if section != "any":
        section = section.strip().lower()
        if not as_dict:
            data = [
                item for item in data.values() if item.get("section", "").strip().lower() == section
            ]
        else:
            data = {
                k: v for k, v in data.items() if v.get("section", "").strip().lower() == section
            }

    elif not as_dict:
        data = list(data.values())

    return data


def save_dataset(dataset: Dict | List[Dict], dataset_path: str):
    """
    Save a dataset to a JSON file.

    Args:
        dataset (Dict | List[Dict]): The dataset to save as a dictionary or list of dictionaries.
        dataset_path (str): Path where the JSON dataset file will be saved.
    """
    with open(dataset_path, "w", encoding="utf-8") as file:
        json.dump(dataset, file)
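A small usage sketch for `load_eurorad_dataset` above (illustrative only; the metadata path mirrors `data/eurorad_metadata.json` added in this commit):
```python
from benchmark.utils import load_eurorad_dataset

# Load only chest-imaging cases whose captions mention an X-ray, keyed by case id
dataset = load_eurorad_dataset(
    "data/eurorad_metadata.json",
    section="Chest Imaging",
    as_dict=True,
)
print(f"{len(dataset)} cases loaded")
```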
data/eurorad_metadata.json
ADDED
The diff for this file is too large to render; see the raw diff.
data/figures.py
ADDED
@@ -0,0 +1,74 @@
import json
import os
from pathlib import Path
import requests
from tqdm import tqdm


def download_eurorad_figures(metadata_path: str, output_dir: str) -> None:
    """
    Download figures from Eurorad dataset and save them organized by case_id.

    Args:
        metadata_path: Path to the eurorad_metadata.json file
        output_dir: Base directory where figures will be saved

    The figures will be saved as:
        {output_dir}/{case_id}/{figure_number}.jpg
    Example:
        figures/189/Figure_1a.jpg
    """
    # Create output directory if it doesn't exist
    output_path = Path(output_dir)
    output_path.mkdir(exist_ok=True)

    # Load metadata
    with open(metadata_path) as f:
        metadata = json.load(f)

    # Iterate through all cases with progress bar
    for case_id in tqdm(metadata, desc="Downloading cases", unit="case"):
        case = metadata[case_id]
        case_dir = output_path / str(case["case_id"])
        case_dir.mkdir(exist_ok=True)

        # Process all figures and their subfigures
        for figure in case["figures"]:
            for subfig in figure["subfigures"]:
                # Remove leading and trailing whitespace and convert to lowercase
                subfig_name = f"{subfig['number'].strip().replace(' ', '_').lower()}.jpg"
                subfig_path = Path(case_dir) / subfig_name

                save_figure(
                    url=subfig["url"],
                    output_path=subfig_path,
                )


def save_figure(url: str, output_path: Path) -> None:
    """
    Download and save a single figure.

    Args:
        url: URL of the figure to download
        output_path: Path where the figure should be saved
    """
    if output_path.exists():
        return

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        with open(output_path, "wb") as f:
            f.write(response.content)
    except Exception as e:
        print(f"Error downloading {url}: {e}")


if __name__ == "__main__":
    root = os.path.dirname(os.path.abspath(__file__))
    download_eurorad_figures(
        metadata_path=os.path.join(root, "eurorad_metadata.json"),
        output_dir=os.path.join(root, "figures"),
    )

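Note that `save_figure` does not retry: a failed request just prints an error and moves on, so transient Eurorad errors silently drop figures. If that becomes a problem, one option is to route downloads through a session with automatic retries. The sketch below is only a suggestion and is not part of the file; `save_figure_with_retries` is a hypothetical helper.

```python
# Sketch: reuse one requests.Session with retries for flaky downloads.
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from pathlib import Path

session = requests.Session()
session.mount("https://", HTTPAdapter(max_retries=Retry(total=3, backoff_factor=1)))

def save_figure_with_retries(url: str, output_path: Path) -> None:
    # raise_for_status() still surfaces hard failures after the retries run out
    response = session.get(url, timeout=10)
    response.raise_for_status()
    output_path.write_bytes(response.content)
```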
data/get_cases.py
ADDED
@@ -0,0 +1,51 @@
import requests
from bs4 import BeautifulSoup
import time
import json
from tqdm import tqdm


def get_response(url):
    headers = {
        "user-agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/108.0.0.0 Safari/537.36 Edg/108.0.1462.54"
    }
    return requests.get(url, headers=headers)


def get_case_numbers_from_page(page):
    url = f"https://www.eurorad.org/advanced-search?sort_by=published_at&sort_order=ASC&page={page}&filter%5B0%5D=section%3A40"

    # Remove proxy usage since it's likely triggering the protection
    response = get_response(url)

    soup = BeautifulSoup(response.text, "html.parser")
    spans = soup.find_all("span", class_="case__number small")

    # Remove '#' from the span text and strip extra whitespace
    numbers = [span.text.strip().replace("#", "").strip() for span in spans]
    return numbers


def main():
    total_pages = 107  # Pages 0 through 106
    all_numbers = []

    for page in tqdm(range(total_pages)):
        numbers = get_case_numbers_from_page(page)
        all_numbers.extend(numbers)

        if page != total_pages - 1 and len(numbers) != 9:
            print(f"Warning: Page {page} returned {len(numbers)} cases instead of 9")

        # Be kind to the server – avoid hitting it too fast
        time.sleep(1)

    with open("case_numbers.json", "w") as f:
        json.dump(all_numbers, f)

    print(f"Saved {len(all_numbers)} case numbers to case_numbers.json")


if __name__ == "__main__":
    main()

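For context, this is what the span-parsing step in `get_case_numbers_from_page` does on a minimal hand-written snippet; the case number and markup below are invented for the example, and the real Eurorad pages may differ in detail.

```python
# Illustration of the case-number extraction on a toy HTML fragment.
from bs4 import BeautifulSoup

html = '<span class="case__number small"># 16543</span>'
soup = BeautifulSoup(html, "html.parser")
spans = soup.find_all("span", class_="case__number small")
numbers = [span.text.strip().replace("#", "").strip() for span in spans]
print(numbers)  # ['16543']
```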
experiments/README.md
ADDED
@@ -0,0 +1,63 @@
# Experiments
Below are instructions for running experiments on our novel ChestAgentBench and on the previous SoTA benchmark, CheXbench. ChestAgentBench is a comprehensive benchmark containing over 2,500 complex medical queries across 8 diverse categories.

### ChestAgentBench

To run GPT-4o on ChestAgentBench, enter the `experiments` directory and run the following script:
```bash
python benchmark_gpt4o.py
```

To run Llama 3.2 Vision 90B on ChestAgentBench, run the following:
```bash
python benchmark_llama.py
```

To run CheXagent on ChestAgentBench, run the following:
```bash
python benchmark_chexagent.py
```

To run LLaVA-Med on ChestAgentBench, you'll need to clone their repo and, after following their setup instructions, copy the following script into it:
```bash
mv benchmark_llavamed.py ~/LLaVA-Med/llava/serve
python -m llava.serve.benchmark_llavamed --model-name llava-med-v1.5-mistral-7b --controller http://localhost:10000
```

If you want to inspect the logs, you can run the following. It selects the most recent log file by default.
```bash
python inspect_logs.py [optional: log-file] -n [num-logs]
```

Finally, to analyze results, run:
```bash
python analyze_axes.py results/[logfile].json ../benchmark/questions/ --model [gpt4|llama|chexagent|llava-med] --max-questions [optional:int]
```

### CheXbench

To run the models on CheXbench, you can use `chexbench_gpt4.py` as a reference. You'll need to download the dataset files locally and upload them with each request. Rad-ReStruct and Open-I use the same set of images, so you can download the `NLMCXR.zip` file just once and copy the images to both directories.

You can find the datasets here:
1. [SLAKE: A Semantically-Labeled Knowledge-Enhanced Dataset for Medical Visual Question Answering](https://www.med-vqa.com/slake/). Save this to `MedMAX/data/slake`.
2. [Rad-ReStruct: A Novel VQA Benchmark and Method for Structured Radiology Reporting](https://github.com/ChantalMP/Rad-ReStruct). Save the images to `MedMAX/data/rad-restruct/images`.
3. [Open-I Service of the National Library of Medicine](https://openi.nlm.nih.gov/faq). Save the images to `MedMAX/data/openi/images`.

Once you're finished, you'll want to update the paths in the `chexbench.json` file to your local paths using the `MedMax/data/fix_chexbench.py` script.


### Compare Runs
Analyze a single results file for overall accuracy and accuracy along different axes:
```
python compare_runs.py results/medmax.json
```

For a direct evaluation comparing **2** models on the exact same questions:
```
python compare_runs.py results/medmax.json results/gpt4o.json
```

For a direct evaluation comparing **ALL** models on the exact same questions (add as many model log files as you want):
```
python compare_runs.py results/medmax.json results/gpt4o.json results/llama.json results/chexagent.json results/llavamed.json
```

experiments/analyze_axes.py
ADDED
@@ -0,0 +1,385 @@
1 |
+
from typing import Dict, List, Optional, Tuple, Union, Any
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
import argparse
|
6 |
+
from collections import defaultdict
|
7 |
+
from tqdm import tqdm
|
8 |
+
|
9 |
+
QUESTION_TYPES = {
|
10 |
+
"Detailed Finding Analysis": ["detection", "localization", "characterization"],
|
11 |
+
"Pattern Recognition & Relations": ["detection", "classification", "relationship"],
|
12 |
+
"Spatial Understanding": ["localization", "comparison", "relationship"],
|
13 |
+
"Clinical Decision Making": ["classification", "comparison", "diagnosis"],
|
14 |
+
"Diagnostic Classification": ["classification", "characterization", "diagnosis"],
|
15 |
+
}
|
16 |
+
|
17 |
+
|
18 |
+
def extract_answer_letter(answer: Optional[Union[str, Any]]) -> Optional[str]:
|
19 |
+
"""
|
20 |
+
Extract just the letter from various answer formats.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
answer: The answer text to extract letter from
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
Optional[str]: The extracted letter in uppercase, or None if no letter found
|
27 |
+
"""
|
28 |
+
if not answer:
|
29 |
+
return None
|
30 |
+
|
31 |
+
# Convert to string and clean
|
32 |
+
answer = str(answer).strip()
|
33 |
+
|
34 |
+
# If it's just a single letter, return it
|
35 |
+
if len(answer) == 1 and answer.isalpha():
|
36 |
+
return answer.upper()
|
37 |
+
|
38 |
+
# Try to extract letter from format like "A)" or "A."
|
39 |
+
if len(answer) >= 2 and answer[0].isalpha() and answer[1] in ").:- ":
|
40 |
+
return answer[0].upper()
|
41 |
+
|
42 |
+
# Try to extract letter from format like "A) Some text"
|
43 |
+
if answer.startswith(("A)", "B)", "C)", "D)", "E)", "F)")):
|
44 |
+
return answer[0].upper()
|
45 |
+
|
46 |
+
return None
|
47 |
+
|
48 |
+
|
49 |
+
def analyze_gpt4_results(
|
50 |
+
results_file: str, max_questions: Optional[int] = None
|
51 |
+
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
|
52 |
+
"""
|
53 |
+
Analyze results in GPT-4 format.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
results_file: Path to results file
|
57 |
+
max_questions: Maximum number of questions to analyze
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
Tuple containing:
|
61 |
+
- overall_accuracy (float)
|
62 |
+
- category_accuracies (Dict)
|
63 |
+
- question_type_stats (Dict)
|
64 |
+
- correct_ids (List[str])
|
65 |
+
- incorrect_ids (List[str])
|
66 |
+
"""
|
67 |
+
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
|
68 |
+
all_questions = 0
|
69 |
+
all_correct = 0
|
70 |
+
correct_ids = []
|
71 |
+
incorrect_ids = []
|
72 |
+
|
73 |
+
with open(results_file, "r") as f:
|
74 |
+
lines = f.readlines()
|
75 |
+
|
76 |
+
processed_questions = 0
|
77 |
+
|
78 |
+
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
|
79 |
+
# Check if we've hit the maximum questions
|
80 |
+
if max_questions is not None and processed_questions >= max_questions:
|
81 |
+
break
|
82 |
+
if line.startswith("HTTP Request:"):
|
83 |
+
continue
|
84 |
+
|
85 |
+
try:
|
86 |
+
entry = json.loads(line)
|
87 |
+
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
|
88 |
+
question_id = entry.get("question_id")
|
89 |
+
|
90 |
+
model_letter = extract_answer_letter(entry.get("model_answer"))
|
91 |
+
correct_letter = extract_answer_letter(entry.get("correct_answer"))
|
92 |
+
|
93 |
+
if model_letter and correct_letter:
|
94 |
+
all_questions += 1
|
95 |
+
processed_questions += 1
|
96 |
+
is_correct = model_letter == correct_letter
|
97 |
+
|
98 |
+
if is_correct:
|
99 |
+
all_correct += 1
|
100 |
+
correct_ids.append(question_id)
|
101 |
+
else:
|
102 |
+
incorrect_ids.append(question_id)
|
103 |
+
|
104 |
+
for category in metadata.get("categories", []):
|
105 |
+
category_performance[category]["total"] += 1
|
106 |
+
if is_correct:
|
107 |
+
category_performance[category]["correct"] += 1
|
108 |
+
|
109 |
+
except json.JSONDecodeError:
|
110 |
+
continue
|
111 |
+
|
112 |
+
return process_results(
|
113 |
+
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
|
114 |
+
)
|
115 |
+
|
116 |
+
|
117 |
+
def analyze_llama_results(
|
118 |
+
results_file: str, max_questions: Optional[int] = None
|
119 |
+
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
|
120 |
+
"""
|
121 |
+
Analyze results in Llama format.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
results_file: Path to results file
|
125 |
+
max_questions: Maximum number of questions to analyze
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
Tuple containing:
|
129 |
+
- overall_accuracy (float)
|
130 |
+
- category_accuracies (Dict)
|
131 |
+
- question_type_stats (Dict)
|
132 |
+
- correct_ids (List[str])
|
133 |
+
- incorrect_ids (List[str])
|
134 |
+
"""
|
135 |
+
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
|
136 |
+
all_questions = 0
|
137 |
+
all_correct = 0
|
138 |
+
correct_ids = []
|
139 |
+
incorrect_ids = []
|
140 |
+
|
141 |
+
with open(results_file, "r") as f:
|
142 |
+
lines = f.readlines()
|
143 |
+
|
144 |
+
# If max_questions is set, limit the number of lines processed
|
145 |
+
if max_questions is not None:
|
146 |
+
lines = lines[:max_questions]
|
147 |
+
|
148 |
+
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
|
149 |
+
if line.startswith("HTTP Request:"):
|
150 |
+
continue
|
151 |
+
|
152 |
+
try:
|
153 |
+
entry = json.loads(line)
|
154 |
+
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
|
155 |
+
question_id = entry.get("question_id")
|
156 |
+
|
157 |
+
model_letter = extract_answer_letter(entry.get("model_answer"))
|
158 |
+
correct_letter = extract_answer_letter(entry.get("correct_answer"))
|
159 |
+
|
160 |
+
if model_letter and correct_letter:
|
161 |
+
all_questions += 1
|
162 |
+
is_correct = model_letter == correct_letter
|
163 |
+
|
164 |
+
if is_correct:
|
165 |
+
all_correct += 1
|
166 |
+
correct_ids.append(question_id)
|
167 |
+
else:
|
168 |
+
incorrect_ids.append(question_id)
|
169 |
+
|
170 |
+
for category in metadata.get("categories", []):
|
171 |
+
category_performance[category]["total"] += 1
|
172 |
+
if is_correct:
|
173 |
+
category_performance[category]["correct"] += 1
|
174 |
+
|
175 |
+
except json.JSONDecodeError:
|
176 |
+
continue
|
177 |
+
|
178 |
+
return process_results(
|
179 |
+
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
|
180 |
+
)
|
181 |
+
|
182 |
+
|
183 |
+
def analyze_chexagent_results(
|
184 |
+
results_file: str, max_questions: Optional[int] = None
|
185 |
+
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
|
186 |
+
"""
|
187 |
+
Analyze results in CheXagent format.
|
188 |
+
|
189 |
+
Args:
|
190 |
+
results_file: Path to results file
|
191 |
+
max_questions: Maximum number of questions to analyze
|
192 |
+
|
193 |
+
Returns:
|
194 |
+
Tuple containing:
|
195 |
+
- overall_accuracy (float)
|
196 |
+
- category_accuracies (Dict)
|
197 |
+
- question_type_stats (Dict)
|
198 |
+
- correct_ids (List[str])
|
199 |
+
- incorrect_ids (List[str])
|
200 |
+
"""
|
201 |
+
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
|
202 |
+
all_questions = 0
|
203 |
+
all_correct = 0
|
204 |
+
correct_ids = []
|
205 |
+
incorrect_ids = []
|
206 |
+
|
207 |
+
with open(results_file, "r") as f:
|
208 |
+
lines = f.readlines()
|
209 |
+
|
210 |
+
# If max_questions is set, limit the number of lines processed
|
211 |
+
if max_questions is not None:
|
212 |
+
lines = lines[:max_questions]
|
213 |
+
|
214 |
+
for line in tqdm(lines, desc="Analyzing Benchmark Results"):
|
215 |
+
try:
|
216 |
+
entry = json.loads(line)
|
217 |
+
metadata = entry.get("input", {}).get("question_data", {}).get("metadata", {})
|
218 |
+
question_id = entry.get("question_id")
|
219 |
+
|
220 |
+
model_letter = extract_answer_letter(entry.get("model_answer"))
|
221 |
+
correct_letter = extract_answer_letter(entry.get("correct_answer"))
|
222 |
+
|
223 |
+
if model_letter and correct_letter:
|
224 |
+
all_questions += 1
|
225 |
+
is_correct = model_letter == correct_letter
|
226 |
+
|
227 |
+
if is_correct:
|
228 |
+
all_correct += 1
|
229 |
+
correct_ids.append(question_id)
|
230 |
+
else:
|
231 |
+
incorrect_ids.append(question_id)
|
232 |
+
|
233 |
+
for category in metadata.get("categories", []):
|
234 |
+
category_performance[category]["total"] += 1
|
235 |
+
if is_correct:
|
236 |
+
category_performance[category]["correct"] += 1
|
237 |
+
|
238 |
+
except json.JSONDecodeError:
|
239 |
+
continue
|
240 |
+
|
241 |
+
return process_results(
|
242 |
+
category_performance, all_questions, all_correct, correct_ids, incorrect_ids
|
243 |
+
)
|
244 |
+
|
245 |
+
|
246 |
+
def process_results(
|
247 |
+
category_performance: Dict,
|
248 |
+
all_questions: int,
|
249 |
+
all_correct: int,
|
250 |
+
correct_ids: Optional[List[str]] = None,
|
251 |
+
incorrect_ids: Optional[List[str]] = None,
|
252 |
+
) -> Tuple[float, Dict, Dict, List[str], List[str]]:
|
253 |
+
"""
|
254 |
+
Process raw results into final statistics.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
category_performance: Dict containing performance by category
|
258 |
+
all_questions: Total number of questions
|
259 |
+
all_correct: Total number of correct answers
|
260 |
+
correct_ids: List of IDs for correctly answered questions
|
261 |
+
incorrect_ids: List of IDs for incorrectly answered questions
|
262 |
+
|
263 |
+
Returns:
|
264 |
+
Tuple containing:
|
265 |
+
- overall_accuracy (float)
|
266 |
+
- category_accuracies (Dict)
|
267 |
+
- question_type_stats (Dict)
|
268 |
+
- correct_ids (List[str])
|
269 |
+
- incorrect_ids (List[str])
|
270 |
+
"""
|
271 |
+
category_accuracies = {
|
272 |
+
category: {
|
273 |
+
"accuracy": stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0,
|
274 |
+
"total": stats["total"],
|
275 |
+
"correct": stats["correct"],
|
276 |
+
}
|
277 |
+
for category, stats in category_performance.items()
|
278 |
+
}
|
279 |
+
|
280 |
+
question_type_stats = {}
|
281 |
+
for qtype, categories in QUESTION_TYPES.items():
|
282 |
+
total = sum(
|
283 |
+
category_performance[cat]["total"] for cat in categories if cat in category_performance
|
284 |
+
)
|
285 |
+
correct = sum(
|
286 |
+
category_performance[cat]["correct"]
|
287 |
+
for cat in categories
|
288 |
+
if cat in category_performance
|
289 |
+
)
|
290 |
+
|
291 |
+
question_type_stats[qtype] = {
|
292 |
+
"accuracy": (correct / total * 100) if total > 0 else 0,
|
293 |
+
"total": total,
|
294 |
+
"correct": correct,
|
295 |
+
}
|
296 |
+
|
297 |
+
overall_accuracy = (all_correct / all_questions * 100) if all_questions > 0 else 0
|
298 |
+
|
299 |
+
return (
|
300 |
+
overall_accuracy,
|
301 |
+
category_accuracies,
|
302 |
+
question_type_stats,
|
303 |
+
correct_ids or [],
|
304 |
+
incorrect_ids or [],
|
305 |
+
)
|
306 |
+
|
307 |
+
|
308 |
+
def print_analysis(
|
309 |
+
overall_accuracy: float,
|
310 |
+
category_accuracies: Dict,
|
311 |
+
question_type_stats: Dict,
|
312 |
+
correct_ids: List[str],
|
313 |
+
incorrect_ids: List[str],
|
314 |
+
model_name: str,
|
315 |
+
) -> None:
|
316 |
+
"""
|
317 |
+
Print analysis results.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
overall_accuracy: Overall accuracy percentage
|
321 |
+
category_accuracies: Dict containing accuracy metrics by category
|
322 |
+
question_type_stats: Dict containing stats by question type
|
323 |
+
correct_ids: List of IDs for correctly answered questions
|
324 |
+
incorrect_ids: List of IDs for incorrectly answered questions
|
325 |
+
model_name: Name of the model being analyzed
|
326 |
+
"""
|
327 |
+
total_questions = len(correct_ids) + len(incorrect_ids)
|
328 |
+
print(
|
329 |
+
f"\nOverall Accuracy: {overall_accuracy:.2f}% ({len(correct_ids)} correct out of {total_questions} questions)"
|
330 |
+
)
|
331 |
+
|
332 |
+
print("\nCategory Performance:")
|
333 |
+
sorted_categories = sorted(
|
334 |
+
category_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
|
335 |
+
)
|
336 |
+
for category, metrics in sorted_categories:
|
337 |
+
print(f"{category}:")
|
338 |
+
print(f" Accuracy: {metrics['accuracy']:.2f}%")
|
339 |
+
print(f" Total Questions: {metrics['total']}")
|
340 |
+
print(f" Correct Questions: {metrics['correct']}")
|
341 |
+
|
342 |
+
print("\nQuestion Type Performance:")
|
343 |
+
sorted_types = sorted(question_type_stats.items(), key=lambda x: x[1]["accuracy"], reverse=True)
|
344 |
+
for qtype, metrics in sorted_types:
|
345 |
+
print(f"\n{qtype}:")
|
346 |
+
print(f" Accuracy: {metrics['accuracy']:.2f}%")
|
347 |
+
print(f" Total Questions: {metrics['total']}")
|
348 |
+
print(f" Correct Questions: {metrics['correct']}")
|
349 |
+
print(f" Categories: {', '.join(QUESTION_TYPES[qtype])}")
|
350 |
+
|
351 |
+
# Save question IDs to JSON
|
352 |
+
question_ids = {"correct_ids": correct_ids, "incorrect_ids": incorrect_ids}
|
353 |
+
|
354 |
+
output_filename = f"{model_name}_question_ids.json"
|
355 |
+
with open(output_filename, "w") as f:
|
356 |
+
json.dump(question_ids, f, indent=2)
|
357 |
+
|
358 |
+
print(f"\nQuestion IDs have been saved to {output_filename}")
|
359 |
+
|
360 |
+
|
361 |
+
if __name__ == "__main__":
|
362 |
+
parser = argparse.ArgumentParser(description="Analyze benchmark results")
|
363 |
+
parser.add_argument("results_file", help="Path to results file")
|
364 |
+
parser.add_argument("benchmark_dir", nargs="?", help="Path to benchmark questions directory")
|
365 |
+
parser.add_argument(
|
366 |
+
"--model",
|
367 |
+
choices=["llava-med", "chexagent", "llama", "gpt4", "medrax"],
|
368 |
+
default="gpt4",
|
369 |
+
help="Specify model format (default: gpt4)",
|
370 |
+
)
|
371 |
+
parser.add_argument("--max-questions", type=int, help="Maximum number of questions to analyze")
|
372 |
+
args = parser.parse_args()
|
373 |
+
|
374 |
+
if args.model == "gpt4":
|
375 |
+
results = analyze_gpt4_results(args.results_file, args.max_questions)
|
376 |
+
elif args.model == "llama":
|
377 |
+
results = analyze_llama_results(args.results_file, args.max_questions)
|
378 |
+
elif args.model == "chexagent":
|
379 |
+
results = analyze_chexagent_results(args.results_file, args.max_questions)
|
380 |
+
elif args.model == "medrax":
|
381 |
+
results = analyze_gpt4_results(args.results_file, args.max_questions)
|
382 |
+
else:
|
383 |
+
parser.error(f"Unsupported model: {args.model}")
|
384 |
+
|
385 |
+
print_analysis(*results, args.model)
|
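The letter-extraction logic in `extract_answer_letter` is the piece most likely to surprise: it accepts bare letters and `A)`-style prefixes but deliberately returns None for free-text answers, which are then excluded from the accuracy counts. A small sketch of its behavior on typical inputs (the example strings are illustrative):

```python
# Behaviour of extract_answer_letter on typical model outputs (illustrative inputs).
print(extract_answer_letter("B"))                       # 'B'  (single letter)
print(extract_answer_letter("c."))                      # 'C'  (letter followed by punctuation)
print(extract_answer_letter("A) Right pneumothorax"))   # 'A'  (letter plus option text)
print(extract_answer_letter("The answer is D"))         # None (free text is rejected)
print(extract_answer_letter(None))                      # None (missing answer)
```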
experiments/benchmark_chexagent.py
ADDED
@@ -0,0 +1,316 @@
1 |
+
import re
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import time
|
6 |
+
import logging
|
7 |
+
from datetime import datetime
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
# Configure model settings
|
14 |
+
MODEL_NAME = "StanfordAIMI/CheXagent-2-3b"
|
15 |
+
DTYPE = torch.bfloat16
|
16 |
+
DEVICE = "cuda"
|
17 |
+
|
18 |
+
# Configure logging
|
19 |
+
log_filename = f"model_inference_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
20 |
+
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
|
21 |
+
|
22 |
+
|
23 |
+
def initialize_model() -> tuple[AutoModelForCausalLM, AutoTokenizer]:
|
24 |
+
"""Initialize the CheXagent model and tokenizer.
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
tuple containing:
|
28 |
+
- AutoModelForCausalLM: The initialized CheXagent model
|
29 |
+
- AutoTokenizer: The initialized tokenizer
|
30 |
+
"""
|
31 |
+
print("Loading model and tokenizer...")
|
32 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
33 |
+
model = AutoModelForCausalLM.from_pretrained(
|
34 |
+
MODEL_NAME, device_map="auto", trust_remote_code=True
|
35 |
+
)
|
36 |
+
model = model.to(DTYPE)
|
37 |
+
model.eval()
|
38 |
+
return model, tokenizer
|
39 |
+
|
40 |
+
|
41 |
+
def create_inference_request(
|
42 |
+
question_data: dict,
|
43 |
+
case_details: dict,
|
44 |
+
case_id: str,
|
45 |
+
question_id: str,
|
46 |
+
model: AutoModelForCausalLM,
|
47 |
+
tokenizer: AutoTokenizer,
|
48 |
+
) -> str | None:
|
49 |
+
"""Create and execute an inference request for the CheXagent model.
|
50 |
+
|
51 |
+
Args:
|
52 |
+
question_data: Dictionary containing question details and metadata
|
53 |
+
case_details: Dictionary containing case information and image paths
|
54 |
+
case_id: Unique identifier for the medical case
|
55 |
+
question_id: Unique identifier for the question
|
56 |
+
model: The initialized CheXagent model
|
57 |
+
tokenizer: The initialized tokenizer
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
str | None: Single letter answer (A-F) if successful, None if failed
|
61 |
+
"""
|
62 |
+
system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
|
63 |
+
Rules:
|
64 |
+
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
|
65 |
+
2. Do not add periods, explanations, or any other text
|
66 |
+
3. Do not use markdown or formatting
|
67 |
+
4. Do not restate the question
|
68 |
+
5. Do not explain your reasoning
|
69 |
+
|
70 |
+
Examples of valid responses:
|
71 |
+
A
|
72 |
+
B
|
73 |
+
C
|
74 |
+
|
75 |
+
Examples of invalid responses:
|
76 |
+
"A."
|
77 |
+
"Answer: B"
|
78 |
+
"C) This shows..."
|
79 |
+
"The answer is D"
|
80 |
+
"""
|
81 |
+
|
82 |
+
prompt = f"""Given the following medical case:
|
83 |
+
Please answer this multiple choice question:
|
84 |
+
{question_data['question']}
|
85 |
+
Base your answer only on the provided images and case information."""
|
86 |
+
|
87 |
+
# Parse required figures
|
88 |
+
try:
|
89 |
+
if isinstance(question_data["figures"], str):
|
90 |
+
try:
|
91 |
+
required_figures = json.loads(question_data["figures"])
|
92 |
+
except json.JSONDecodeError:
|
93 |
+
required_figures = [question_data["figures"]]
|
94 |
+
elif isinstance(question_data["figures"], list):
|
95 |
+
required_figures = question_data["figures"]
|
96 |
+
else:
|
97 |
+
required_figures = [str(question_data["figures"])]
|
98 |
+
except Exception as e:
|
99 |
+
print(f"Error parsing figures: {e}")
|
100 |
+
required_figures = []
|
101 |
+
|
102 |
+
required_figures = [
|
103 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
104 |
+
]
|
105 |
+
|
106 |
+
# Get image paths
|
107 |
+
image_paths = []
|
108 |
+
for figure in required_figures:
|
109 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
110 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
111 |
+
|
112 |
+
matching_figures = [
|
113 |
+
case_figure
|
114 |
+
for case_figure in case_details.get("figures", [])
|
115 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
116 |
+
]
|
117 |
+
|
118 |
+
for case_figure in matching_figures:
|
119 |
+
subfigures = []
|
120 |
+
if figure_letter:
|
121 |
+
subfigures = [
|
122 |
+
subfig
|
123 |
+
for subfig in case_figure.get("subfigures", [])
|
124 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
125 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
126 |
+
]
|
127 |
+
else:
|
128 |
+
subfigures = case_figure.get("subfigures", [])
|
129 |
+
|
130 |
+
for subfig in subfigures:
|
131 |
+
if "local_path" in subfig:
|
132 |
+
image_paths.append("medrax/data/" + subfig["local_path"])
|
133 |
+
|
134 |
+
if not image_paths:
|
135 |
+
print(f"No local images found for case {case_id}, question {question_id}")
|
136 |
+
return None
|
137 |
+
|
138 |
+
try:
|
139 |
+
start_time = time.time()
|
140 |
+
|
141 |
+
# Prepare input for the model
|
142 |
+
query = tokenizer.from_list_format(
|
143 |
+
[*[{"image": path} for path in image_paths], {"text": prompt}]
|
144 |
+
)
|
145 |
+
conv = [{"from": "system", "value": system_prompt}, {"from": "human", "value": query}]
|
146 |
+
input_ids = tokenizer.apply_chat_template(
|
147 |
+
conv, add_generation_prompt=True, return_tensors="pt"
|
148 |
+
)
|
149 |
+
|
150 |
+
# Generate response
|
151 |
+
with torch.no_grad():
|
152 |
+
output = model.generate(
|
153 |
+
input_ids.to(DEVICE),
|
154 |
+
do_sample=False,
|
155 |
+
num_beams=1,
|
156 |
+
temperature=1.0,
|
157 |
+
top_p=1.0,
|
158 |
+
use_cache=True,
|
159 |
+
max_new_tokens=512,
|
160 |
+
)[0]
|
161 |
+
|
162 |
+
response = tokenizer.decode(output[input_ids.size(1) : -1])
|
163 |
+
duration = time.time() - start_time
|
164 |
+
|
165 |
+
# Clean response
|
166 |
+
clean_answer = validate_answer(response)
|
167 |
+
|
168 |
+
# Log response
|
169 |
+
log_entry = {
|
170 |
+
"case_id": case_id,
|
171 |
+
"question_id": question_id,
|
172 |
+
"timestamp": datetime.now().isoformat(),
|
173 |
+
"model": MODEL_NAME,
|
174 |
+
"duration": round(duration, 2),
|
175 |
+
"model_answer": clean_answer,
|
176 |
+
"correct_answer": question_data["answer"],
|
177 |
+
"input": {
|
178 |
+
"question_data": {
|
179 |
+
"question": question_data["question"],
|
180 |
+
"explanation": question_data["explanation"],
|
181 |
+
"metadata": question_data.get("metadata", {}),
|
182 |
+
"figures": question_data["figures"],
|
183 |
+
},
|
184 |
+
"image_paths": image_paths,
|
185 |
+
},
|
186 |
+
}
|
187 |
+
logging.info(json.dumps(log_entry))
|
188 |
+
return clean_answer
|
189 |
+
|
190 |
+
except Exception as e:
|
191 |
+
print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
|
192 |
+
log_entry = {
|
193 |
+
"case_id": case_id,
|
194 |
+
"question_id": question_id,
|
195 |
+
"timestamp": datetime.now().isoformat(),
|
196 |
+
"model": MODEL_NAME,
|
197 |
+
"status": "error",
|
198 |
+
"error": str(e),
|
199 |
+
"input": {
|
200 |
+
"question_data": {
|
201 |
+
"question": question_data["question"],
|
202 |
+
"explanation": question_data["explanation"],
|
203 |
+
"metadata": question_data.get("metadata", {}),
|
204 |
+
"figures": question_data["figures"],
|
205 |
+
},
|
206 |
+
"image_paths": image_paths,
|
207 |
+
},
|
208 |
+
}
|
209 |
+
logging.info(json.dumps(log_entry))
|
210 |
+
return None
|
211 |
+
|
212 |
+
|
213 |
+
def validate_answer(response_text: str) -> str | None:
|
214 |
+
"""Enforce strict single-letter response format.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
response_text: Raw response text from the model
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
str | None: Single uppercase letter (A-F) if valid, None if invalid
|
221 |
+
"""
|
222 |
+
if not response_text:
|
223 |
+
return None
|
224 |
+
|
225 |
+
# Remove all whitespace and convert to uppercase
|
226 |
+
cleaned = response_text.strip().upper()
|
227 |
+
|
228 |
+
# Check if it's exactly one valid letter
|
229 |
+
if len(cleaned) == 1 and cleaned in "ABCDEF":
|
230 |
+
return cleaned
|
231 |
+
|
232 |
+
# If not, try to extract just the letter
|
233 |
+
match = re.search(r"([A-F])", cleaned)
|
234 |
+
return match.group(1) if match else None
|
235 |
+
|
236 |
+
|
237 |
+
def load_benchmark_questions(case_id: str) -> list[str]:
|
238 |
+
"""Find all question files for a given case ID.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
case_id: Unique identifier for the medical case
|
242 |
+
|
243 |
+
Returns:
|
244 |
+
list[str]: List of paths to question JSON files
|
245 |
+
"""
|
246 |
+
benchmark_dir = "../benchmark/questions"
|
247 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
248 |
+
|
249 |
+
|
250 |
+
def count_total_questions() -> tuple[int, int]:
|
251 |
+
"""Count total number of cases and questions in benchmark.
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
tuple containing:
|
255 |
+
- int: Total number of cases
|
256 |
+
- int: Total number of questions
|
257 |
+
"""
|
258 |
+
total_cases = len(glob.glob("../benchmark/questions/*"))
|
259 |
+
total_questions = sum(
|
260 |
+
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
|
261 |
+
for case_id in os.listdir("../benchmark/questions")
|
262 |
+
)
|
263 |
+
return total_cases, total_questions
|
264 |
+
|
265 |
+
|
266 |
+
def main():
|
267 |
+
# Load the cases with local paths
|
268 |
+
with open("medrax/data/updated_cases.json", "r") as file:
|
269 |
+
data = json.load(file)
|
270 |
+
|
271 |
+
# Initialize model and tokenizer
|
272 |
+
model, tokenizer = initialize_model()
|
273 |
+
|
274 |
+
total_cases, total_questions = count_total_questions()
|
275 |
+
cases_processed = 0
|
276 |
+
questions_processed = 0
|
277 |
+
skipped_questions = 0
|
278 |
+
|
279 |
+
print(f"\nBeginning inference with {MODEL_NAME}")
|
280 |
+
print(f"Found {total_cases} cases with {total_questions} total questions")
|
281 |
+
|
282 |
+
# Process each case with progress bar
|
283 |
+
for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
|
284 |
+
question_files = load_benchmark_questions(case_id)
|
285 |
+
if not question_files:
|
286 |
+
continue
|
287 |
+
|
288 |
+
cases_processed += 1
|
289 |
+
for question_file in tqdm(
|
290 |
+
question_files, desc=f"Processing questions for case {case_id}", leave=False
|
291 |
+
):
|
292 |
+
with open(question_file, "r") as file:
|
293 |
+
question_data = json.load(file)
|
294 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
295 |
+
|
296 |
+
questions_processed += 1
|
297 |
+
answer = create_inference_request(
|
298 |
+
question_data, case_details, case_id, question_id, model, tokenizer
|
299 |
+
)
|
300 |
+
|
301 |
+
if answer is None:
|
302 |
+
skipped_questions += 1
|
303 |
+
continue
|
304 |
+
|
305 |
+
print(f"\nCase {case_id}, Question {question_id}")
|
306 |
+
print(f"Model Answer: {answer}")
|
307 |
+
print(f"Correct Answer: {question_data['answer']}")
|
308 |
+
|
309 |
+
print(f"\nInference Summary:")
|
310 |
+
print(f"Total Cases Processed: {cases_processed}")
|
311 |
+
print(f"Total Questions Processed: {questions_processed}")
|
312 |
+
print(f"Total Questions Skipped: {skipped_questions}")
|
313 |
+
|
314 |
+
|
315 |
+
if __name__ == "__main__":
|
316 |
+
main()
|
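One caveat about `validate_answer`: when CheXagent ignores the system prompt and replies in prose, the fallback regex grabs the first A–F character anywhere in the uppercased text, which may not be the intended option (for example, the "E" in "THE"). The sketch below restates the same logic so the normal path and that edge case can be seen side by side; the example strings are illustrative.

```python
import re

def validate_answer(response_text):
    # Same logic as above: accept a lone A-F letter, otherwise fall back to
    # the first A-F character found anywhere in the response.
    if not response_text:
        return None
    cleaned = response_text.strip().upper()
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned
    match = re.search(r"([A-F])", cleaned)
    return match.group(1) if match else None

print(validate_answer("B"))                # 'B'
print(validate_answer("d."))               # 'D'
print(validate_answer("The answer is D"))  # 'E' -- first A-F letter is the E in "THE"
```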
experiments/benchmark_gpt4o.py
ADDED
@@ -0,0 +1,331 @@
1 |
+
import json
|
2 |
+
import openai
|
3 |
+
import os
|
4 |
+
import glob
|
5 |
+
import time
|
6 |
+
import logging
|
7 |
+
from datetime import datetime
|
8 |
+
from tenacity import retry, wait_exponential, stop_after_attempt
|
9 |
+
|
10 |
+
model_name = "chatgpt-4o-latest"
|
11 |
+
temperature = 0.2
|
12 |
+
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
13 |
+
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
|
14 |
+
|
15 |
+
|
16 |
+
def calculate_cost(
|
17 |
+
prompt_tokens: int, completion_tokens: int, model: str = "chatgpt-4o-latest"
|
18 |
+
) -> float:
|
19 |
+
"""Calculate the cost of API usage based on token counts.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
prompt_tokens: Number of tokens in the prompt
|
23 |
+
completion_tokens: Number of tokens in the completion
|
24 |
+
model: Model name to use for pricing, defaults to chatgpt-4o-latest
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
float: Cost in USD
|
28 |
+
"""
|
29 |
+
pricing = {"chatgpt-4o-latest": {"prompt": 5.0, "completion": 15.0}}
|
30 |
+
rates = pricing.get(model, {"prompt": 5.0, "completion": 15.0})
|
31 |
+
return (prompt_tokens * rates["prompt"] + completion_tokens * rates["completion"]) / 1000000
|
32 |
+
|
33 |
+
|
34 |
+
@retry(wait=wait_exponential(multiplier=1, min=4, max=10), stop=stop_after_attempt(3))
|
35 |
+
def create_multimodal_request(
|
36 |
+
question_data: dict, case_details: dict, case_id: str, question_id: str, client: openai.OpenAI
|
37 |
+
) -> openai.types.chat.ChatCompletion:
|
38 |
+
"""Create and send a multimodal request to the OpenAI API.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
question_data: Dictionary containing question details and figures
|
42 |
+
case_details: Dictionary containing case information and figures
|
43 |
+
case_id: Identifier for the medical case
|
44 |
+
question_id: Identifier for the specific question
|
45 |
+
client: OpenAI client instance
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
openai.types.chat.ChatCompletion: API response object, or None if request fails
|
49 |
+
"""
|
50 |
+
prompt = f"""Given the following medical case:
|
51 |
+
Please answer this multiple choice question:
|
52 |
+
{question_data['question']}
|
53 |
+
Base your answer only on the provided images and case information."""
|
54 |
+
|
55 |
+
content = [{"type": "text", "text": prompt}]
|
56 |
+
|
57 |
+
# Parse required figures
|
58 |
+
try:
|
59 |
+
# Try multiple ways of parsing figures
|
60 |
+
if isinstance(question_data["figures"], str):
|
61 |
+
try:
|
62 |
+
required_figures = json.loads(question_data["figures"])
|
63 |
+
except json.JSONDecodeError:
|
64 |
+
required_figures = [question_data["figures"]]
|
65 |
+
elif isinstance(question_data["figures"], list):
|
66 |
+
required_figures = question_data["figures"]
|
67 |
+
else:
|
68 |
+
required_figures = [str(question_data["figures"])]
|
69 |
+
except Exception as e:
|
70 |
+
print(f"Error parsing figures: {e}")
|
71 |
+
required_figures = []
|
72 |
+
|
73 |
+
# Ensure each figure starts with "Figure "
|
74 |
+
required_figures = [
|
75 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
76 |
+
]
|
77 |
+
|
78 |
+
subfigures = []
|
79 |
+
for figure in required_figures:
|
80 |
+
# Handle both regular figures and those with letter suffixes
|
81 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
82 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
83 |
+
|
84 |
+
# Find matching figures in case details
|
85 |
+
matching_figures = [
|
86 |
+
case_figure
|
87 |
+
for case_figure in case_details.get("figures", [])
|
88 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
89 |
+
]
|
90 |
+
|
91 |
+
if not matching_figures:
|
92 |
+
print(f"No matching figure found for {figure} in case {case_id}")
|
93 |
+
continue
|
94 |
+
|
95 |
+
for case_figure in matching_figures:
|
96 |
+
# If a specific letter is specified, filter subfigures
|
97 |
+
if figure_letter:
|
98 |
+
matching_subfigures = [
|
99 |
+
subfig
|
100 |
+
for subfig in case_figure.get("subfigures", [])
|
101 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
102 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
103 |
+
]
|
104 |
+
subfigures.extend(matching_subfigures)
|
105 |
+
else:
|
106 |
+
# If no letter specified, add all subfigures
|
107 |
+
subfigures.extend(case_figure.get("subfigures", []))
|
108 |
+
|
109 |
+
# Add images to content
|
110 |
+
for subfig in subfigures:
|
111 |
+
if "url" in subfig:
|
112 |
+
content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
|
113 |
+
else:
|
114 |
+
print(f"Subfigure missing URL: {subfig}")
|
115 |
+
|
116 |
+
# If no images found, log and return None
|
117 |
+
if len(content) == 1: # Only the text prompt exists
|
118 |
+
print(f"No images found for case {case_id}, question {question_id}")
|
119 |
+
return None
|
120 |
+
|
121 |
+
messages = [
|
122 |
+
{
|
123 |
+
"role": "system",
|
124 |
+
"content": "You are a medical imaging expert. Provide only the letter corresponding to your answer choice (A/B/C/D/E/F).",
|
125 |
+
},
|
126 |
+
{"role": "user", "content": content},
|
127 |
+
]
|
128 |
+
|
129 |
+
if len(content) == 1: # Only the text prompt exists
|
130 |
+
print(f"No images found for case {case_id}, question {question_id}")
|
131 |
+
log_entry = {
|
132 |
+
"case_id": case_id,
|
133 |
+
"question_id": question_id,
|
134 |
+
"timestamp": datetime.now().isoformat(),
|
135 |
+
"model": model_name,
|
136 |
+
"temperature": temperature,
|
137 |
+
"status": "skipped",
|
138 |
+
"reason": "no_images",
|
139 |
+
"cost": 0,
|
140 |
+
"input": {
|
141 |
+
"messages": messages,
|
142 |
+
"question_data": {
|
143 |
+
"question": question_data["question"],
|
144 |
+
"explanation": question_data["explanation"],
|
145 |
+
"metadata": question_data.get("metadata", {}),
|
146 |
+
"figures": question_data["figures"],
|
147 |
+
},
|
148 |
+
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
|
149 |
+
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
|
150 |
+
},
|
151 |
+
}
|
152 |
+
logging.info(json.dumps(log_entry))
|
153 |
+
return None
|
154 |
+
|
155 |
+
try:
|
156 |
+
start_time = time.time()
|
157 |
+
|
158 |
+
response = client.chat.completions.create(
|
159 |
+
model=model_name, messages=messages, max_tokens=50, temperature=temperature
|
160 |
+
)
|
161 |
+
duration = time.time() - start_time
|
162 |
+
|
163 |
+
log_entry = {
|
164 |
+
"case_id": case_id,
|
165 |
+
"question_id": question_id,
|
166 |
+
"timestamp": datetime.now().isoformat(),
|
167 |
+
"model": model_name,
|
168 |
+
"temperature": temperature,
|
169 |
+
"duration": round(duration, 2),
|
170 |
+
"usage": {
|
171 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
172 |
+
"completion_tokens": response.usage.completion_tokens,
|
173 |
+
"total_tokens": response.usage.total_tokens,
|
174 |
+
},
|
175 |
+
"cost": calculate_cost(response.usage.prompt_tokens, response.usage.completion_tokens),
|
176 |
+
"model_answer": response.choices[0].message.content,
|
177 |
+
"correct_answer": question_data["answer"],
|
178 |
+
"input": {
|
179 |
+
"messages": messages,
|
180 |
+
"question_data": {
|
181 |
+
"question": question_data["question"],
|
182 |
+
"explanation": question_data["explanation"],
|
183 |
+
"metadata": question_data.get("metadata", {}),
|
184 |
+
"figures": question_data["figures"],
|
185 |
+
},
|
186 |
+
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
|
187 |
+
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
|
188 |
+
},
|
189 |
+
}
|
190 |
+
logging.info(json.dumps(log_entry))
|
191 |
+
return response
|
192 |
+
|
193 |
+
except openai.RateLimitError:
|
194 |
+
log_entry = {
|
195 |
+
"case_id": case_id,
|
196 |
+
"question_id": question_id,
|
197 |
+
"timestamp": datetime.now().isoformat(),
|
198 |
+
"model": model_name,
|
199 |
+
"temperature": temperature,
|
200 |
+
"status": "error",
|
201 |
+
"reason": "rate_limit",
|
202 |
+
"cost": 0,
|
203 |
+
"input": {
|
204 |
+
"messages": messages,
|
205 |
+
"question_data": {
|
206 |
+
"question": question_data["question"],
|
207 |
+
"explanation": question_data["explanation"],
|
208 |
+
"metadata": question_data.get("metadata", {}),
|
209 |
+
"figures": question_data["figures"],
|
210 |
+
},
|
211 |
+
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
|
212 |
+
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
|
213 |
+
},
|
214 |
+
}
|
215 |
+
logging.info(json.dumps(log_entry))
|
216 |
+
print(
|
217 |
+
f"\nRate limit hit for case {case_id}, question {question_id}. Waiting 20s...",
|
218 |
+
flush=True,
|
219 |
+
)
|
220 |
+
time.sleep(20)
|
221 |
+
raise
|
222 |
+
except Exception as e:
|
223 |
+
log_entry = {
|
224 |
+
"case_id": case_id,
|
225 |
+
"question_id": question_id,
|
226 |
+
"timestamp": datetime.now().isoformat(),
|
227 |
+
"model": model_name,
|
228 |
+
"temperature": temperature,
|
229 |
+
"status": "error",
|
230 |
+
"error": str(e),
|
231 |
+
"cost": 0,
|
232 |
+
"input": {
|
233 |
+
"messages": messages,
|
234 |
+
"question_data": {
|
235 |
+
"question": question_data["question"],
|
236 |
+
"explanation": question_data["explanation"],
|
237 |
+
"metadata": question_data.get("metadata", {}),
|
238 |
+
"figures": question_data["figures"],
|
239 |
+
},
|
240 |
+
"image_urls": [subfig["url"] for subfig in subfigures if "url" in subfig],
|
241 |
+
"image_captions": [subfig.get("caption", "") for subfig in subfigures],
|
242 |
+
},
|
243 |
+
}
|
244 |
+
logging.info(json.dumps(log_entry))
|
245 |
+
print(f"Error processing case {case_id}, question {question_id}: {str(e)}")
|
246 |
+
raise
|
247 |
+
|
248 |
+
|
249 |
+
def load_benchmark_questions(case_id: str) -> list:
|
250 |
+
"""Load benchmark questions for a given case.
|
251 |
+
|
252 |
+
Args:
|
253 |
+
case_id: Identifier for the medical case
|
254 |
+
|
255 |
+
Returns:
|
256 |
+
list: List of paths to question files
|
257 |
+
"""
|
258 |
+
benchmark_dir = "../benchmark/questions"
|
259 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
260 |
+
|
261 |
+
|
262 |
+
def count_total_questions() -> tuple[int, int]:
|
263 |
+
"""Count total number of cases and questions in benchmark.
|
264 |
+
|
265 |
+
Returns:
|
266 |
+
tuple: (total_cases, total_questions)
|
267 |
+
"""
|
268 |
+
total_cases = len(glob.glob("../benchmark/questions/*"))
|
269 |
+
total_questions = sum(
|
270 |
+
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
|
271 |
+
for case_id in os.listdir("../benchmark/questions")
|
272 |
+
)
|
273 |
+
return total_cases, total_questions
|
274 |
+
|
275 |
+
|
276 |
+
def main() -> None:
|
277 |
+
"""Main function to run the benchmark evaluation."""
|
278 |
+
with open("../data/eurorad_metadata.json", "r") as file:
|
279 |
+
data = json.load(file)
|
280 |
+
|
281 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
282 |
+
if not api_key:
|
283 |
+
raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
284 |
+
global client
|
285 |
+
client = openai.OpenAI(api_key=api_key)
|
286 |
+
|
287 |
+
total_cases, total_questions = count_total_questions()
|
288 |
+
cases_processed = 0
|
289 |
+
questions_processed = 0
|
290 |
+
skipped_questions = 0
|
291 |
+
|
292 |
+
print(f"Beginning benchmark evaluation for model {model_name} with temperature {temperature}")
|
293 |
+
|
294 |
+
for case_id, case_details in data.items():
|
295 |
+
question_files = load_benchmark_questions(case_id)
|
296 |
+
if not question_files:
|
297 |
+
continue
|
298 |
+
|
299 |
+
cases_processed += 1
|
300 |
+
for question_file in question_files:
|
301 |
+
with open(question_file, "r") as file:
|
302 |
+
question_data = json.load(file)
|
303 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
304 |
+
|
305 |
+
questions_processed += 1
|
306 |
+
response = create_multimodal_request(
|
307 |
+
question_data, case_details, case_id, question_id, client
|
308 |
+
)
|
309 |
+
|
310 |
+
# Handle cases where response is None
|
311 |
+
if response is None:
|
312 |
+
skipped_questions += 1
|
313 |
+
print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
|
314 |
+
continue
|
315 |
+
|
316 |
+
print(
|
317 |
+
f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
|
318 |
+
)
|
319 |
+
print(f"Case ID: {case_id}")
|
320 |
+
print(f"Question ID: {question_id}")
|
321 |
+
print(f"Model Answer: {response.choices[0].message.content}")
|
322 |
+
print(f"Correct Answer: {question_data['answer']}\n")
|
323 |
+
|
324 |
+
print(f"\nBenchmark Summary:")
|
325 |
+
print(f"Total Cases Processed: {cases_processed}")
|
326 |
+
print(f"Total Questions Processed: {questions_processed}")
|
327 |
+
print(f"Total Questions Skipped: {skipped_questions}")
|
328 |
+
|
329 |
+
|
330 |
+
if __name__ == "__main__":
|
331 |
+
main()
|
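As a sanity check on `calculate_cost`, the rates are expressed in USD per million tokens (prompt at $5.00, completion at $15.00 for the default model), so a request with 2,000 prompt tokens and 50 completion tokens works out as follows; the token counts are illustrative.

```python
# (2,000 * $5.00 + 50 * $15.00) / 1,000,000 = $0.01075
cost = calculate_cost(prompt_tokens=2000, completion_tokens=50)
print(f"${cost:.5f}")  # $0.01075
```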
experiments/benchmark_llama.py
ADDED
@@ -0,0 +1,443 @@
1 |
+
from typing import Dict, List, Optional, Any, Union
|
2 |
+
import re
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import glob
|
6 |
+
import time
|
7 |
+
import logging
|
8 |
+
import socket
|
9 |
+
import requests
|
10 |
+
import httpx
|
11 |
+
import backoff
|
12 |
+
from datetime import datetime
|
13 |
+
from tenacity import retry, wait_exponential, stop_after_attempt
|
14 |
+
from openai import OpenAI
|
15 |
+
|
16 |
+
# Configure model settings
|
17 |
+
MODEL_NAME = "meta-llama/llama-3.2-90b-vision-instruct"
|
18 |
+
temperature = 0.2
|
19 |
+
|
20 |
+
# Configure logging
|
21 |
+
log_filename = f"api_usage_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
22 |
+
logging.basicConfig(filename=log_filename, level=logging.INFO, format="%(message)s")
|
23 |
+
|
24 |
+
|
25 |
+
def verify_dns() -> bool:
|
26 |
+
"""Verify DNS resolution and connectivity.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
bool: True if DNS resolution succeeds, False otherwise
|
30 |
+
"""
|
31 |
+
try:
|
32 |
+
# Try to resolve openrouter.ai
|
33 |
+
socket.gethostbyname("openrouter.ai")
|
34 |
+
return True
|
35 |
+
except socket.gaierror:
|
36 |
+
print("DNS resolution failed. Trying to use Google DNS (8.8.8.8)...")
|
37 |
+
# Modify resolv.conf to use Google DNS
|
38 |
+
try:
|
39 |
+
with open("/etc/resolv.conf", "w") as f:
|
40 |
+
f.write("nameserver 8.8.8.8\n")
|
41 |
+
return True
|
42 |
+
except Exception as e:
|
43 |
+
print(f"Failed to update DNS settings: {e}")
|
44 |
+
return False
|
45 |
+
|
46 |
+
|
47 |
+
def verify_connection() -> bool:
|
48 |
+
"""Verify connection to OpenRouter API.
|
49 |
+
|
50 |
+
Returns:
|
51 |
+
bool: True if connection succeeds, False otherwise
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
response = requests.get("https://openrouter.ai/api/v1/status", timeout=10)
|
55 |
+
return response.status_code == 200
|
56 |
+
except Exception as e:
|
57 |
+
print(f"Connection test failed: {e}")
|
58 |
+
return False
|
59 |
+
|
60 |
+
|
61 |
+
def initialize_client() -> OpenAI:
|
62 |
+
"""Initialize the OpenRouter client with proper timeout settings and connection verification.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
OpenAI: Configured OpenAI client for OpenRouter
|
66 |
+
|
67 |
+
Raises:
|
68 |
+
ValueError: If OPENROUTER_API_KEY environment variable is not set
|
69 |
+
ConnectionError: If DNS verification or connection test fails
|
70 |
+
"""
|
71 |
+
api_key = os.getenv("OPENROUTER_API_KEY")
|
72 |
+
if not api_key:
|
73 |
+
raise ValueError("OPENROUTER_API_KEY environment variable is not set.")
|
74 |
+
|
75 |
+
# Configure timeout settings for the client
|
76 |
+
timeout_settings = 120 # Increased timeout for large images/responses
|
77 |
+
|
78 |
+
# Verify DNS and connection
|
79 |
+
if not verify_dns():
|
80 |
+
raise ConnectionError("DNS verification failed. Please check your network settings.")
|
81 |
+
|
82 |
+
if not verify_connection():
|
83 |
+
raise ConnectionError(
|
84 |
+
"Cannot connect to OpenRouter. Please check your internet connection."
|
85 |
+
)
|
86 |
+
|
87 |
+
# Set up client with retry and timeout settings
|
88 |
+
return OpenAI(
|
89 |
+
base_url="https://openrouter.ai/api/v1",
|
90 |
+
api_key=api_key,
|
91 |
+
timeout=timeout_settings,
|
92 |
+
http_client=httpx.Client(
|
93 |
+
timeout=timeout_settings, transport=httpx.HTTPTransport(retries=3)
|
94 |
+
),
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
@backoff.on_exception(
|
99 |
+
backoff.expo,
|
100 |
+
(ConnectionError, TimeoutError, socket.gaierror, httpx.ConnectError),
|
101 |
+
max_tries=5,
|
102 |
+
max_time=300, # Maximum total time to try in seconds
|
103 |
+
)
|
104 |
+
def create_multimodal_request(
|
105 |
+
question_data: Dict[str, Any],
|
106 |
+
case_details: Dict[str, Any],
|
107 |
+
case_id: str,
|
108 |
+
question_id: str,
|
109 |
+
client: OpenAI,
|
110 |
+
) -> Optional[Any]:
|
111 |
+
"""Create and send a multimodal request to the model.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
question_data: Dictionary containing question details
|
115 |
+
case_details: Dictionary containing case information
|
116 |
+
case_id: ID of the medical case
|
117 |
+
question_id: ID of the specific question
|
118 |
+
client: OpenAI client instance
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Optional[Any]: Model response if successful, None if skipped
|
122 |
+
|
123 |
+
Raises:
|
124 |
+
ConnectionError: If connection fails
|
125 |
+
TimeoutError: If request times out
|
126 |
+
Exception: For other errors
|
127 |
+
"""
|
128 |
+
|
129 |
+
system_prompt = """You are a medical imaging expert. Your task is to provide ONLY a single letter answer.
|
130 |
+
Rules:
|
131 |
+
1. Respond with exactly one uppercase letter (A/B/C/D/E/F)
|
132 |
+
2. Do not add periods, explanations, or any other text
|
133 |
+
3. Do not use markdown or formatting
|
134 |
+
4. Do not restate the question
|
135 |
+
5. Do not explain your reasoning
|
136 |
+
|
137 |
+
Examples of valid responses:
|
138 |
+
A
|
139 |
+
B
|
140 |
+
C
|
141 |
+
|
142 |
+
Examples of invalid responses:
|
143 |
+
"A."
|
144 |
+
"Answer: B"
|
145 |
+
"C) This shows..."
|
146 |
+
"The answer is D"
|
147 |
+
"""
|
148 |
+
|
149 |
+
prompt = f"""Given the following medical case:
|
150 |
+
Please answer this multiple choice question:
|
151 |
+
{question_data['question']}
|
152 |
+
Base your answer only on the provided images and case information."""
|
153 |
+
|
154 |
+
# Parse required figures
|
155 |
+
try:
|
156 |
+
if isinstance(question_data["figures"], str):
|
157 |
+
try:
|
158 |
+
required_figures = json.loads(question_data["figures"])
|
159 |
+
except json.JSONDecodeError:
|
160 |
+
required_figures = [question_data["figures"]]
|
161 |
+
elif isinstance(question_data["figures"], list):
|
162 |
+
required_figures = question_data["figures"]
|
163 |
+
else:
|
164 |
+
required_figures = [str(question_data["figures"])]
|
165 |
+
except Exception as e:
|
166 |
+
print(f"Error parsing figures: {e}")
|
167 |
+
required_figures = []
|
168 |
+
|
169 |
+
required_figures = [
|
170 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
171 |
+
]
|
172 |
+
|
173 |
+
# Process subfigures and prepare content
|
174 |
+
content = [{"type": "text", "text": prompt}]
|
175 |
+
image_urls = []
|
176 |
+
image_captions = []
|
177 |
+
|
178 |
+
for figure in required_figures:
|
179 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
180 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
181 |
+
|
182 |
+
matching_figures = [
|
183 |
+
case_figure
|
184 |
+
for case_figure in case_details.get("figures", [])
|
185 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
186 |
+
]
|
187 |
+
|
188 |
+
for case_figure in matching_figures:
|
189 |
+
subfigures = []
|
190 |
+
if figure_letter:
|
191 |
+
subfigures = [
|
192 |
+
subfig
|
193 |
+
for subfig in case_figure.get("subfigures", [])
|
194 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
195 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
196 |
+
]
|
197 |
+
else:
|
198 |
+
subfigures = case_figure.get("subfigures", [])
|
199 |
+
|
200 |
+
for subfig in subfigures:
|
201 |
+
if "url" in subfig:
|
202 |
+
content.append({"type": "image_url", "image_url": {"url": subfig["url"]}})
|
203 |
+
image_urls.append(subfig["url"])
|
204 |
+
image_captions.append(subfig.get("caption", ""))
|
205 |
+
|
206 |
+
if len(content) == 1: # Only the text prompt exists
|
207 |
+
print(f"No images found for case {case_id}, question {question_id}")
|
208 |
+
# Log the skipped question
|
209 |
+
log_entry = {
|
210 |
+
"case_id": case_id,
|
211 |
+
"question_id": question_id,
|
212 |
+
"timestamp": datetime.now().isoformat(),
|
213 |
+
"model": MODEL_NAME,
|
214 |
+
"status": "skipped",
|
215 |
+
"reason": "no_images",
|
216 |
+
"input": {
|
217 |
+
"question_data": {
|
218 |
+
"question": question_data["question"],
|
219 |
+
"explanation": question_data["explanation"],
|
220 |
+
"metadata": question_data.get("metadata", {}),
|
221 |
+
"figures": question_data["figures"],
|
222 |
+
},
|
223 |
+
"image_urls": image_urls,
|
224 |
+
},
|
225 |
+
}
|
226 |
+
logging.info(json.dumps(log_entry))
|
227 |
+
return None
|
228 |
+
|
229 |
+
try:
|
230 |
+
start_time = time.time()
|
231 |
+
|
232 |
+
response = client.chat.completions.create(
|
233 |
+
model=MODEL_NAME,
|
234 |
+
temperature=temperature,
|
235 |
+
messages=[
|
236 |
+
{"role": "system", "content": system_prompt},
|
237 |
+
{"role": "user", "content": content},
|
238 |
+
],
|
239 |
+
)
|
240 |
+
duration = time.time() - start_time
|
241 |
+
|
242 |
+
# Get raw response
|
243 |
+
raw_answer = response.choices[0].message.content
|
244 |
+
|
245 |
+
# Validate and clean
|
246 |
+
clean_answer = validate_answer(raw_answer)
|
247 |
+
|
248 |
+
if not clean_answer:
|
249 |
+
print(f"Warning: Invalid response format for case {case_id}, question {question_id}")
|
250 |
+
print(f"Raw response: {raw_answer}")
|
251 |
+
|
252 |
+
# Update response object with cleaned answer
|
253 |
+
response.choices[0].message.content = clean_answer
|
254 |
+
|
255 |
+
# Log response
|
256 |
+
log_entry = {
|
257 |
+
"case_id": case_id,
|
258 |
+
"question_id": question_id,
|
259 |
+
"timestamp": datetime.now().isoformat(),
|
260 |
+
"model": MODEL_NAME,
|
261 |
+
"temperature": temperature,
|
262 |
+
"duration": round(duration, 2),
|
263 |
+
"usage": {
|
264 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
265 |
+
"completion_tokens": response.usage.completion_tokens,
|
266 |
+
"total_tokens": response.usage.total_tokens,
|
267 |
+
},
|
268 |
+
"model_answer": response.choices[0].message.content,
|
269 |
+
"correct_answer": question_data["answer"],
|
270 |
+
"input": {
|
271 |
+
"question_data": {
|
272 |
+
"question": question_data["question"],
|
273 |
+
"explanation": question_data["explanation"],
|
274 |
+
"metadata": question_data.get("metadata", {}),
|
275 |
+
"figures": question_data["figures"],
|
276 |
+
},
|
277 |
+
"image_urls": image_urls,
|
278 |
+
},
|
279 |
+
}
|
280 |
+
logging.info(json.dumps(log_entry))
|
281 |
+
return response
|
282 |
+
|
283 |
+
except ConnectionError as e:
|
284 |
+
print(f"Connection error for case {case_id}, question {question_id}: {str(e)}")
|
285 |
+
print("Retrying after a longer delay...")
|
286 |
+
time.sleep(30) # Add a longer delay before retry
|
287 |
+
raise
|
288 |
+
except TimeoutError as e:
|
289 |
+
print(f"Timeout error for case {case_id}, question {question_id}: {str(e)}")
|
290 |
+
print("Retrying with increased timeout...")
|
291 |
+
raise
|
292 |
+
except Exception as e:
|
293 |
+
# Log failed requests too
|
294 |
+
log_entry = {
|
295 |
+
"case_id": case_id,
|
296 |
+
"question_id": question_id,
|
297 |
+
"timestamp": datetime.now().isoformat(),
|
298 |
+
"model": MODEL_NAME,
|
299 |
+
"temperature": temperature,
|
300 |
+
"status": "error",
|
301 |
+
"error": str(e),
|
302 |
+
"input": {
|
303 |
+
"question_data": {
|
304 |
+
"question": question_data["question"],
|
305 |
+
"explanation": question_data["explanation"],
|
306 |
+
"metadata": question_data.get("metadata", {}),
|
307 |
+
"figures": question_data["figures"],
|
308 |
+
},
|
309 |
+
"image_urls": image_urls,
|
310 |
+
},
|
311 |
+
}
|
312 |
+
logging.info(json.dumps(log_entry))
|
313 |
+
raise
|
314 |
+
|
315 |
+
|
316 |
+
def extract_answer(response_text: str) -> Optional[str]:
|
317 |
+
"""Extract single letter answer from model response.
|
318 |
+
|
319 |
+
Args:
|
320 |
+
response_text: Raw text response from model
|
321 |
+
|
322 |
+
Returns:
|
323 |
+
Optional[str]: Single letter answer if found, None otherwise
|
324 |
+
"""
|
325 |
+
# Convert to uppercase and remove periods
|
326 |
+
text = response_text.upper().replace(".", "")
|
327 |
+
|
328 |
+
# Look for common patterns
|
329 |
+
patterns = [
|
330 |
+
r"ANSWER:\s*([A-F])", # Matches "ANSWER: X"
|
331 |
+
r"OPTION\s*([A-F])", # Matches "OPTION X"
|
332 |
+
r"([A-F])\)", # Matches "X)"
|
333 |
+
r"\b([A-F])\b", # Matches single letter
|
334 |
+
]
|
335 |
+
|
336 |
+
for pattern in patterns:
|
337 |
+
matches = re.findall(pattern, text)
|
338 |
+
if matches:
|
339 |
+
return matches[0]
|
340 |
+
|
341 |
+
return None
|
342 |
+
|
343 |
+
|
344 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
345 |
+
"""Enforce strict single-letter response format.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
response_text: Raw text response from model
|
349 |
+
|
350 |
+
Returns:
|
351 |
+
Optional[str]: Valid single letter answer if found, None otherwise
|
352 |
+
"""
|
353 |
+
if not response_text:
|
354 |
+
return None
|
355 |
+
|
356 |
+
# Remove all whitespace and convert to uppercase
|
357 |
+
cleaned = response_text.strip().upper()
|
358 |
+
|
359 |
+
# Check if it's exactly one valid letter
|
360 |
+
if len(cleaned) == 1 and cleaned in "ABCDEF":
|
361 |
+
return cleaned
|
362 |
+
|
363 |
+
# If not, try to extract just the letter
|
364 |
+
match = re.search(r"([A-F])", cleaned)
|
365 |
+
return match.group(1) if match else None
|
366 |
+
|
367 |
+
|
368 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
369 |
+
"""Find all question files for a given case ID.
|
370 |
+
|
371 |
+
Args:
|
372 |
+
case_id: ID of the medical case
|
373 |
+
|
374 |
+
Returns:
|
375 |
+
List[str]: List of paths to question files
|
376 |
+
"""
|
377 |
+
benchmark_dir = "../benchmark/questions"
|
378 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
379 |
+
|
380 |
+
|
381 |
+
def count_total_questions() -> Tuple[int, int]:
|
382 |
+
"""Count total number of cases and questions.
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
Tuple[int, int]: (total_cases, total_questions)
|
386 |
+
"""
|
387 |
+
total_cases = len(glob.glob("../benchmark/questions/*"))
|
388 |
+
total_questions = sum(
|
389 |
+
len(glob.glob(f"../benchmark/questions/{case_id}/*.json"))
|
390 |
+
for case_id in os.listdir("../benchmark/questions")
|
391 |
+
)
|
392 |
+
return total_cases, total_questions
|
393 |
+
|
394 |
+
|
395 |
+
def main():
|
396 |
+
with open("../data/eurorad_metadata.json", "r") as file:
|
397 |
+
data = json.load(file)
|
398 |
+
|
399 |
+
client = initialize_client()
|
400 |
+
total_cases, total_questions = count_total_questions()
|
401 |
+
cases_processed = 0
|
402 |
+
questions_processed = 0
|
403 |
+
skipped_questions = 0
|
404 |
+
|
405 |
+
print(f"Beginning benchmark evaluation for {MODEL_NAME} with temperature {temperature}")
|
406 |
+
|
407 |
+
for case_id, case_details in data.items():
|
408 |
+
question_files = load_benchmark_questions(case_id)
|
409 |
+
if not question_files:
|
410 |
+
continue
|
411 |
+
|
412 |
+
cases_processed += 1
|
413 |
+
for question_file in question_files:
|
414 |
+
with open(question_file, "r") as file:
|
415 |
+
question_data = json.load(file)
|
416 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
417 |
+
|
418 |
+
questions_processed += 1
|
419 |
+
response = create_multimodal_request(
|
420 |
+
question_data, case_details, case_id, question_id, client
|
421 |
+
)
|
422 |
+
|
423 |
+
if response is None:
|
424 |
+
skipped_questions += 1
|
425 |
+
print(f"Skipped question: Case ID {case_id}, Question ID {question_id}")
|
426 |
+
continue
|
427 |
+
|
428 |
+
print(
|
429 |
+
f"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}"
|
430 |
+
)
|
431 |
+
print(f"Case ID: {case_id}")
|
432 |
+
print(f"Question ID: {question_id}")
|
433 |
+
print(f"Model Answer: {response.choices[0].message.content}")
|
434 |
+
print(f"Correct Answer: {question_data['answer']}\n")
|
435 |
+
|
436 |
+
print(f"\nBenchmark Summary:")
|
437 |
+
print(f"Total Cases Processed: {cases_processed}")
|
438 |
+
print(f"Total Questions Processed: {questions_processed}")
|
439 |
+
print(f"Total Questions Skipped: {skipped_questions}")
|
440 |
+
|
441 |
+
|
442 |
+
if __name__ == "__main__":
|
443 |
+
main()
|
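A minimal usage sketch of the single-letter validation logic above (illustrative only, not part of the committed file; the sample responses are made up):

import re
from typing import Optional

def validate_answer(response_text: str) -> Optional[str]:
    # Same logic as the validator in the script above, copied here for illustration.
    if not response_text:
        return None
    cleaned = response_text.strip().upper()
    if len(cleaned) == 1 and cleaned in "ABCDEF":
        return cleaned
    match = re.search(r"([A-F])", cleaned)
    return match.group(1) if match else None

print(validate_answer("B"))    # -> "B"
print(validate_answer(" c "))  # -> "C"  (whitespace stripped, upper-cased)
print(validate_answer("D."))   # -> "D"  (first A-F letter extracted)
print(validate_answer(""))     # -> None (empty responses are rejected)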
experiments/benchmark_llavamed.py
ADDED
@@ -0,0 +1,541 @@
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
+
import base64
|
5 |
+
from PIL import Image
|
6 |
+
from io import BytesIO
|
7 |
+
from llava.conversation import conv_templates
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import glob
|
11 |
+
import logging
|
12 |
+
from datetime import datetime
|
13 |
+
from tqdm import tqdm
|
14 |
+
import re
|
15 |
+
from typing import Dict, List, Optional, Union, Any, Tuple
|
16 |
+
|
17 |
+
|
18 |
+
def process_image(image_path: str, target_size: int = 640) -> Image.Image:
|
19 |
+
"""Process and resize an image to match model requirements.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
image_path: Path to the input image file
|
23 |
+
target_size: Target size for both width and height in pixels
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
PIL.Image: Processed and padded image with dimensions (target_size, target_size)
|
27 |
+
"""
|
28 |
+
image = Image.open(image_path)
|
29 |
+
if image.mode != "RGB":
|
30 |
+
image = image.convert("RGB")
|
31 |
+
|
32 |
+
# Calculate scaling to maintain aspect ratio
|
33 |
+
ratio = min(target_size / image.width, target_size / image.height)
|
34 |
+
new_size = (int(image.width * ratio), int(image.height * ratio))
|
35 |
+
|
36 |
+
# Resize image
|
37 |
+
image = image.resize(new_size, Image.LANCZOS)
|
38 |
+
|
39 |
+
# Create new image with padding
|
40 |
+
new_image = Image.new("RGB", (target_size, target_size), (0, 0, 0))
|
41 |
+
# Paste resized image in center
|
42 |
+
offset = ((target_size - new_size[0]) // 2, (target_size - new_size[1]) // 2)
|
43 |
+
new_image.paste(image, offset)
|
44 |
+
|
45 |
+
return new_image
|
46 |
+
|
47 |
+
|
48 |
+
def validate_answer(response_text: str) -> Optional[str]:
|
49 |
+
"""Extract and validate a single-letter response from the model's output.
|
50 |
+
Handles multiple response formats and edge cases.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
response_text: The full text output from the model
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
A single letter answer (A-F) or None if no valid answer found
|
57 |
+
"""
|
58 |
+
if not response_text:
|
59 |
+
return None
|
60 |
+
|
61 |
+
# Clean the response text
|
62 |
+
cleaned = response_text.strip()
|
63 |
+
|
64 |
+
# Comprehensive set of patterns to extract the answer
|
65 |
+
extraction_patterns = [
|
66 |
+
# Strict format with explicit letter answer
|
67 |
+
r"(?:THE\s*)?(?:SINGLE\s*)?LETTER\s*(?:ANSWER\s*)?(?:IS:?)\s*([A-F])\b",
|
68 |
+
# Patterns for extracting from longer descriptions
|
69 |
+
r"(?:correct\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
70 |
+
r"\b(?:answer|option)\s*([A-F])[):]\s*",
|
71 |
+
# Patterns for extracting from descriptive sentences
|
72 |
+
r"(?:most\s+likely\s+)?(?:answer|option)\s*(?:is\s*)?([A-F])\b",
|
73 |
+
r"suggest[s]?\s+(?:that\s+)?(?:the\s+)?(?:answer\s+)?(?:is\s*)?([A-F])\b",
|
74 |
+
# Patterns with contextual words
|
75 |
+
r"characteriz[e]?d?\s+by\s+([A-F])\b",
|
76 |
+
r"indicat[e]?s?\s+([A-F])\b",
|
77 |
+
# Fallback to Option X or Letter X formats
|
78 |
+
r"Option\s*([A-F])\b",
|
79 |
+
r"\b([A-F])\)\s*",
|
80 |
+
# Fallback to standalone letter
|
81 |
+
r"^\s*([A-F])\s*$",
|
82 |
+
]
|
83 |
+
|
84 |
+
# Try each pattern
|
85 |
+
for pattern in extraction_patterns:
|
86 |
+
matches = re.findall(pattern, cleaned, re.IGNORECASE)
|
87 |
+
for match in matches:
|
88 |
+
# Ensure match is a single valid letter
|
89 |
+
if isinstance(match, tuple):
|
90 |
+
match = match[0] if match[0] in "ABCDEF" else None
|
91 |
+
if match and match.upper() in "ABCDEF":
|
92 |
+
return match.upper()
|
93 |
+
|
94 |
+
# Final fallback: look for standalone letters in context
|
95 |
+
context_matches = re.findall(r"\b([A-F])\b", cleaned.upper())
|
96 |
+
context_letters = [m for m in context_matches if m in "ABCDEF"]
|
97 |
+
if context_letters:
|
98 |
+
return context_letters[0]
|
99 |
+
|
100 |
+
# No valid answer found
|
101 |
+
return None
|
102 |
+
|
103 |
+
|
104 |
+
def load_benchmark_questions(case_id: str) -> List[str]:
|
105 |
+
"""Find all question files for a given case ID.
|
106 |
+
|
107 |
+
Args:
|
108 |
+
case_id: The ID of the medical case
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
List of paths to question JSON files
|
112 |
+
"""
|
113 |
+
benchmark_dir = "MedMAX/benchmark/questions"
|
114 |
+
return glob.glob(f"{benchmark_dir}/{case_id}/{case_id}_*.json")
|
115 |
+
|
116 |
+
|
117 |
+
def count_total_questions() -> Tuple[int, int]:
|
118 |
+
"""Count total number of cases and questions in benchmark.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
Tuple containing (total_cases, total_questions)
|
122 |
+
"""
|
123 |
+
total_cases = len(glob.glob("MedMAX/benchmark/questions/*"))
|
124 |
+
total_questions = sum(
|
125 |
+
len(glob.glob(f"MedMAX/benchmark/questions/{case_id}/*.json"))
|
126 |
+
for case_id in os.listdir("MedMAX/benchmark/questions")
|
127 |
+
)
|
128 |
+
return total_cases, total_questions
|
129 |
+
|
130 |
+
|
131 |
+
def create_inference_request(
|
132 |
+
question_data: Dict[str, Any],
|
133 |
+
case_details: Dict[str, Any],
|
134 |
+
case_id: str,
|
135 |
+
question_id: str,
|
136 |
+
worker_addr: str,
|
137 |
+
model_name: str,
|
138 |
+
raw_output: bool = False,
|
139 |
+
) -> Union[Tuple[Optional[str], Optional[float]], Dict[str, Any]]:
|
140 |
+
"""Create and send inference request to worker.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
question_data: Dictionary containing question details and figures
|
144 |
+
case_details: Dictionary containing case information and figures
|
145 |
+
case_id: Identifier for the medical case
|
146 |
+
question_id: Identifier for the specific question
|
147 |
+
worker_addr: Address of the worker endpoint
|
148 |
+
model_name: Name of the model to use
|
149 |
+
raw_output: Whether to return raw model output
|
150 |
+
|
151 |
+
Returns:
|
152 |
+
If raw_output is False: Tuple of (validated_answer, duration)
|
153 |
+
If raw_output is True: Dictionary with full inference details
|
154 |
+
"""
|
155 |
+
system_prompt = """You are a medical imaging expert. Your answer MUST be a SINGLE LETTER (A/B/C/D/E/F), provided in this format: 'The SINGLE LETTER answer is: X'.
|
156 |
+
"""
|
157 |
+
|
158 |
+
prompt = f"""Given the following medical case:
|
159 |
+
Please answer this multiple choice question:
|
160 |
+
{question_data['question']}
|
161 |
+
Base your answer only on the provided images and case information. Respond with your SINGLE LETTER answer: """
|
162 |
+
|
163 |
+
try:
|
164 |
+
# Parse required figures
|
165 |
+
if isinstance(question_data["figures"], str):
|
166 |
+
try:
|
167 |
+
required_figures = json.loads(question_data["figures"])
|
168 |
+
except json.JSONDecodeError:
|
169 |
+
required_figures = [question_data["figures"]]
|
170 |
+
elif isinstance(question_data["figures"], list):
|
171 |
+
required_figures = question_data["figures"]
|
172 |
+
else:
|
173 |
+
required_figures = [str(question_data["figures"])]
|
174 |
+
except Exception as e:
|
175 |
+
print(f"Error parsing figures: {e}")
|
176 |
+
required_figures = []
|
177 |
+
|
178 |
+
required_figures = [
|
179 |
+
fig if fig.startswith("Figure ") else f"Figure {fig}" for fig in required_figures
|
180 |
+
]
|
181 |
+
|
182 |
+
# Get image paths
|
183 |
+
image_paths = []
|
184 |
+
for figure in required_figures:
|
185 |
+
base_figure_num = "".join(filter(str.isdigit, figure))
|
186 |
+
figure_letter = "".join(filter(str.isalpha, figure.split()[-1])) or None
|
187 |
+
|
188 |
+
matching_figures = [
|
189 |
+
case_figure
|
190 |
+
for case_figure in case_details.get("figures", [])
|
191 |
+
if case_figure["number"] == f"Figure {base_figure_num}"
|
192 |
+
]
|
193 |
+
|
194 |
+
for case_figure in matching_figures:
|
195 |
+
subfigures = []
|
196 |
+
if figure_letter:
|
197 |
+
subfigures = [
|
198 |
+
subfig
|
199 |
+
for subfig in case_figure.get("subfigures", [])
|
200 |
+
if subfig.get("number", "").lower().endswith(figure_letter.lower())
|
201 |
+
or subfig.get("label", "").lower() == figure_letter.lower()
|
202 |
+
]
|
203 |
+
else:
|
204 |
+
subfigures = case_figure.get("subfigures", [])
|
205 |
+
|
206 |
+
for subfig in subfigures:
|
207 |
+
if "local_path" in subfig:
|
208 |
+
image_paths.append("MedMAX/data/" + subfig["local_path"])
|
209 |
+
|
210 |
+
if not image_paths:
|
211 |
+
print(f"No local images found for case {case_id}, question {question_id}")
|
212 |
+
return "skipped", 0.0 # Return a special 'skipped' marker
|
213 |
+
|
214 |
+
try:
|
215 |
+
start_time = time.time()
|
216 |
+
|
217 |
+
# Process each image
|
218 |
+
processed_images = [process_image(path) for path in image_paths]
|
219 |
+
|
220 |
+
# Create conversation
|
221 |
+
conv = conv_templates["mistral_instruct"].copy()
|
222 |
+
|
223 |
+
# Add image and message
|
224 |
+
if "<image>" not in prompt:
|
225 |
+
text = prompt + "\n<image>"
|
226 |
+
else:
|
227 |
+
text = prompt
|
228 |
+
|
229 |
+
message = (text, processed_images[0], "Default") # Currently handling first image
|
230 |
+
conv.append_message(conv.roles[0], message)
|
231 |
+
conv.append_message(conv.roles[1], None)
|
232 |
+
|
233 |
+
prompt = conv.get_prompt()
|
234 |
+
headers = {"User-Agent": "LLaVA-Med Client"}
|
235 |
+
pload = {
|
236 |
+
"model": model_name,
|
237 |
+
"prompt": prompt,
|
238 |
+
"max_new_tokens": 150, # Reduce this since we only need one letter
|
239 |
+
"temperature": 0.5, # Lower temperature for more focused responses
|
240 |
+
"stop": conv.sep2,
|
241 |
+
"images": conv.get_images(),
|
242 |
+
"top_p": 1, # Lower top_p for more focused sampling
|
243 |
+
"frequency_penalty": 0.0,
|
244 |
+
"presence_penalty": 0.0,
|
245 |
+
}
|
246 |
+
|
247 |
+
max_retries = 3
|
248 |
+
retry_delay = 5
|
249 |
+
response_text = None
|
250 |
+
|
251 |
+
for attempt in range(max_retries):
|
252 |
+
try:
|
253 |
+
response = requests.post(
|
254 |
+
worker_addr + "/worker_generate_stream",
|
255 |
+
headers=headers,
|
256 |
+
json=pload,
|
257 |
+
stream=True,
|
258 |
+
timeout=30,
|
259 |
+
)
|
260 |
+
|
261 |
+
complete_output = ""
|
262 |
+
for chunk in response.iter_lines(
|
263 |
+
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
264 |
+
):
|
265 |
+
if chunk:
|
266 |
+
data = json.loads(chunk.decode("utf-8"))
|
267 |
+
if data["error_code"] == 0:
|
268 |
+
output = data["text"].split("[/INST]")[-1]
|
269 |
+
complete_output = output
|
270 |
+
else:
|
271 |
+
print(f"\nError: {data['text']} (error_code: {data['error_code']})")
|
272 |
+
if attempt < max_retries - 1:
|
273 |
+
time.sleep(retry_delay)
|
274 |
+
break
|
275 |
+
return None, None
|
276 |
+
|
277 |
+
if complete_output:
|
278 |
+
response_text = complete_output
|
279 |
+
break
|
280 |
+
|
281 |
+
except (requests.exceptions.RequestException, json.JSONDecodeError) as e:
|
282 |
+
if attempt < max_retries - 1:
|
283 |
+
print(f"\nNetwork error: {str(e)}. Retrying in {retry_delay} seconds...")
|
284 |
+
time.sleep(retry_delay)
|
285 |
+
else:
|
286 |
+
print(f"\nFailed after {max_retries} attempts: {str(e)}")
|
287 |
+
return None, None
|
288 |
+
|
289 |
+
duration = time.time() - start_time
|
290 |
+
|
291 |
+
if raw_output:
|
292 |
+
inference_details = {
|
293 |
+
"raw_output": response_text,
|
294 |
+
"validated_answer": validate_answer(response_text),
|
295 |
+
"duration": duration,
|
296 |
+
"prompt": prompt,
|
297 |
+
"system_prompt": system_prompt,
|
298 |
+
"image_paths": image_paths,
|
299 |
+
"payload": pload,
|
300 |
+
}
|
301 |
+
return inference_details
|
302 |
+
|
303 |
+
return validate_answer(response_text), duration
|
304 |
+
|
305 |
+
except Exception as e:
|
306 |
+
print(f"Error in inference request: {str(e)}")
|
307 |
+
return None, None
|
308 |
+
|
309 |
+
|
310 |
+
def clean_payload(payload: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
311 |
+
"""Remove image-related and large data from the payload to keep the log lean.
|
312 |
+
|
313 |
+
Args:
|
314 |
+
payload: Original request payload dictionary
|
315 |
+
|
316 |
+
Returns:
|
317 |
+
Cleaned payload dictionary with large data removed
|
318 |
+
"""
|
319 |
+
if not payload:
|
320 |
+
return None
|
321 |
+
|
322 |
+
# Create a copy of the payload to avoid modifying the original
|
323 |
+
cleaned_payload = payload.copy()
|
324 |
+
|
325 |
+
# Remove large or sensitive data
|
326 |
+
if "images" in cleaned_payload:
|
327 |
+
del cleaned_payload["images"]
|
328 |
+
|
329 |
+
return cleaned_payload
|
330 |
+
|
331 |
+
|
332 |
+
def main():
|
333 |
+
parser = argparse.ArgumentParser()
|
334 |
+
parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
|
335 |
+
parser.add_argument("--worker-address", type=str)
|
336 |
+
parser.add_argument("--model-name", type=str, default="llava-med-v1.5-mistral-7b")
|
337 |
+
parser.add_argument("--output-dir", type=str, default="benchmark_results")
|
338 |
+
parser.add_argument(
|
339 |
+
"--raw-output", action="store_true", help="Return raw model output without validation"
|
340 |
+
)
|
341 |
+
parser.add_argument(
|
342 |
+
"--num-cases",
|
343 |
+
type=int,
|
344 |
+
help="Number of cases to process if looking at raw outputs",
|
345 |
+
default=2,
|
346 |
+
)
|
347 |
+
args = parser.parse_args()
|
348 |
+
|
349 |
+
# Setup output directory
|
350 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
351 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
352 |
+
|
353 |
+
# Setup live logging files
|
354 |
+
live_log_filename = os.path.join(args.output_dir, f"live_benchmark_log_{timestamp}.json")
|
355 |
+
final_results_filename = os.path.join(args.output_dir, f"final_results_{timestamp}.json")
|
356 |
+
|
357 |
+
# Initialize live log file
|
358 |
+
with open(live_log_filename, "w") as live_log_file:
|
359 |
+
live_log_file.write("[\n") # Start of JSON array
|
360 |
+
|
361 |
+
# Setup logging
|
362 |
+
logging.basicConfig(
|
363 |
+
filename=os.path.join(args.output_dir, f"benchmark_{timestamp}.log"),
|
364 |
+
level=logging.INFO,
|
365 |
+
format="%(message)s",
|
366 |
+
)
|
367 |
+
|
368 |
+
# Get worker address
|
369 |
+
if args.worker_address:
|
370 |
+
worker_addr = args.worker_address
|
371 |
+
else:
|
372 |
+
try:
|
373 |
+
requests.post(args.controller_address + "/refresh_all_workers")
|
374 |
+
ret = requests.post(args.controller_address + "/list_models")
|
375 |
+
models = ret.json()["models"]
|
376 |
+
ret = requests.post(
|
377 |
+
args.controller_address + "/get_worker_address", json={"model": args.model_name}
|
378 |
+
)
|
379 |
+
worker_addr = ret.json()["address"]
|
380 |
+
print(f"Worker address: {worker_addr}")
|
381 |
+
except requests.exceptions.RequestException as e:
|
382 |
+
print(f"Failed to connect to controller: {e}")
|
383 |
+
return
|
384 |
+
|
385 |
+
if worker_addr == "":
|
386 |
+
print("No available worker")
|
387 |
+
return
|
388 |
+
|
389 |
+
# Load cases with local paths
|
390 |
+
with open("MedMAX/data/updated_cases.json", "r") as file:
|
391 |
+
data = json.load(file)
|
392 |
+
|
393 |
+
total_cases, total_questions = count_total_questions()
|
394 |
+
print(f"\nStarting benchmark with {args.model_name}")
|
395 |
+
print(f"Found {total_cases} cases with {total_questions} total questions")
|
396 |
+
|
397 |
+
results = {
|
398 |
+
"model": args.model_name,
|
399 |
+
"timestamp": datetime.now().isoformat(),
|
400 |
+
"total_cases": total_cases,
|
401 |
+
"total_questions": total_questions,
|
402 |
+
"results": [],
|
403 |
+
}
|
404 |
+
|
405 |
+
cases_processed = 0
|
406 |
+
questions_processed = 0
|
407 |
+
correct_answers = 0
|
408 |
+
skipped_questions = 0
|
409 |
+
total_processed_entries = 0
|
410 |
+
|
411 |
+
# Process each case
|
412 |
+
for case_id, case_details in tqdm(data.items(), desc="Processing cases"):
|
413 |
+
question_files = load_benchmark_questions(case_id)
|
414 |
+
if not question_files:
|
415 |
+
continue
|
416 |
+
|
417 |
+
cases_processed += 1
|
418 |
+
for question_file in tqdm(
|
419 |
+
question_files, desc=f"Processing questions for case {case_id}", leave=False
|
420 |
+
):
|
421 |
+
with open(question_file, "r") as file:
|
422 |
+
question_data = json.load(file)
|
423 |
+
question_id = os.path.basename(question_file).split(".")[0]
|
424 |
+
|
425 |
+
questions_processed += 1
|
426 |
+
|
427 |
+
# Get model's answer
|
428 |
+
inference_result = create_inference_request(
|
429 |
+
question_data,
|
430 |
+
case_details,
|
431 |
+
case_id,
|
432 |
+
question_id,
|
433 |
+
worker_addr,
|
434 |
+
args.model_name,
|
435 |
+
raw_output=True, # Always use raw output for detailed logging
|
436 |
+
)
|
437 |
+
|
438 |
+
# Handle skipped questions
|
439 |
+
if inference_result == ("skipped", 0.0):
|
440 |
+
skipped_questions += 1
|
441 |
+
print(f"\nCase {case_id}, Question {question_id}: Skipped (No images)")
|
442 |
+
|
443 |
+
# Log skipped question
|
444 |
+
skipped_entry = {
|
445 |
+
"case_id": case_id,
|
446 |
+
"question_id": question_id,
|
447 |
+
"status": "skipped",
|
448 |
+
"reason": "No images found",
|
449 |
+
}
|
450 |
+
with open(live_log_filename, "a") as live_log_file:
|
451 |
+
json.dump(skipped_entry, live_log_file, indent=2)
|
452 |
+
live_log_file.write(",\n") # Add comma for next entry
|
453 |
+
|
454 |
+
continue
|
455 |
+
|
456 |
+
# Extract information
|
457 |
+
answer = inference_result["validated_answer"]
|
458 |
+
duration = inference_result["duration"]
|
459 |
+
|
460 |
+
# Prepare detailed logging entry
|
461 |
+
log_entry = {
|
462 |
+
"case_id": case_id,
|
463 |
+
"question_id": question_id,
|
464 |
+
"question": question_data["question"],
|
465 |
+
"correct_answer": question_data["answer"],
|
466 |
+
"raw_output": inference_result["raw_output"],
|
467 |
+
"validated_answer": answer,
|
468 |
+
"model_answer": answer,
|
469 |
+
"is_correct": answer == question_data["answer"] if answer else False,
|
470 |
+
"duration": duration,
|
471 |
+
"system_prompt": inference_result["system_prompt"],
|
472 |
+
"input_prompt": inference_result["prompt"],
|
473 |
+
"image_paths": inference_result["image_paths"],
|
474 |
+
"payload": clean_payload(inference_result["payload"]),
|
475 |
+
}
|
476 |
+
|
477 |
+
# Write to live log file
|
478 |
+
with open(live_log_filename, "a") as live_log_file:
|
479 |
+
json.dump(log_entry, live_log_file, indent=2)
|
480 |
+
live_log_file.write(",\n") # Add comma for next entry
|
481 |
+
|
482 |
+
# Print to console
|
483 |
+
print(f"\nCase {case_id}, Question {question_id}")
|
484 |
+
print(f"Model Answer: {answer}")
|
485 |
+
print(f"Correct Answer: {question_data['answer']}")
|
486 |
+
print(f"Time taken: {duration:.2f}s")
|
487 |
+
|
488 |
+
# Track correct answers
|
489 |
+
if answer == question_data["answer"]:
|
490 |
+
correct_answers += 1
|
491 |
+
|
492 |
+
# Append to results
|
493 |
+
results["results"].append(log_entry)
|
494 |
+
total_processed_entries += 1
|
495 |
+
|
496 |
+
# Optional: break if reached specified number of cases
|
497 |
+
if args.raw_output and cases_processed == args.num_cases:
|
498 |
+
break
|
499 |
+
|
500 |
+
# Optional: break if reached specified number of cases
|
501 |
+
if args.raw_output and cases_processed == args.num_cases:
|
502 |
+
break
|
503 |
+
|
504 |
+
# Close live log file
|
505 |
+
with open(live_log_filename, "a") as live_log_file:
|
506 |
+
# Remove trailing comma and close JSON array
|
507 |
+
live_log_file.seek(live_log_file.tell() - 2, 0) # Go back 2 chars to remove ',\n'
|
508 |
+
live_log_file.write("\n]")
|
509 |
+
|
510 |
+
# Calculate final statistics
|
511 |
+
results["summary"] = {
|
512 |
+
"cases_processed": cases_processed,
|
513 |
+
"questions_processed": questions_processed,
|
514 |
+
"total_processed_entries": total_processed_entries,
|
515 |
+
"correct_answers": correct_answers,
|
516 |
+
"skipped_questions": skipped_questions,
|
517 |
+
"accuracy": (
|
518 |
+
correct_answers / (questions_processed - skipped_questions)
|
519 |
+
if (questions_processed - skipped_questions) > 0
|
520 |
+
else 0
|
521 |
+
),
|
522 |
+
}
|
523 |
+
|
524 |
+
# Save final results
|
525 |
+
with open(final_results_filename, "w") as f:
|
526 |
+
json.dump(results, f, indent=2)
|
527 |
+
|
528 |
+
print(f"\nBenchmark Summary:")
|
529 |
+
print(f"Total Cases Processed: {cases_processed}")
|
530 |
+
print(f"Total Questions Processed: {questions_processed}")
|
531 |
+
print(f"Total Processed Entries: {total_processed_entries}")
|
532 |
+
print(f"Correct Answers: {correct_answers}")
|
533 |
+
print(f"Skipped Questions: {skipped_questions}")
|
534 |
+
print(f"Accuracy: {(correct_answers / (questions_processed - skipped_questions) * 100):.2f}%")
|
535 |
+
print(f"\nResults saved to {args.output_dir}")
|
536 |
+
print(f"Live log: {live_log_filename}")
|
537 |
+
print(f"Final results: {final_results_filename}")
|
538 |
+
|
539 |
+
|
540 |
+
if __name__ == "__main__":
|
541 |
+
main()
|
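An illustrative sketch of the letterbox resize performed by process_image above, using an in-memory image instead of a file path (the 800x400 input size is a made-up example):

from PIL import Image

target = 640
src = Image.new("RGB", (800, 400), (128, 128, 128))      # stand-in for a radiograph
ratio = min(target / src.width, target / src.height)      # 0.8 for this input
resized = src.resize((int(src.width * ratio), int(src.height * ratio)), Image.LANCZOS)
canvas = Image.new("RGB", (target, target), (0, 0, 0))    # black padding
canvas.paste(resized, ((target - resized.width) // 2, (target - resized.height) // 2))
print(canvas.size)  # (640, 640): aspect ratio preserved, image centred on the canvas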
experiments/benchmark_medrax.ipynb
ADDED
@@ -0,0 +1,374 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [],
|
8 |
+
"source": [
|
9 |
+
"import operator\n",
|
10 |
+
"import warnings\n",
|
11 |
+
"from typing import *\n",
|
12 |
+
"import traceback\n",
|
13 |
+
"\n",
|
14 |
+
"import os\n",
|
15 |
+
"import torch\n",
|
16 |
+
"from dotenv import load_dotenv\n",
|
17 |
+
"from IPython.display import Image\n",
|
18 |
+
"from langgraph.checkpoint.memory import MemorySaver\n",
|
19 |
+
"from langgraph.graph import END, StateGraph\n",
|
20 |
+
"from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage\n",
|
21 |
+
"from langchain_openai import ChatOpenAI\n",
|
22 |
+
"from transformers import logging\n",
|
23 |
+
"import matplotlib.pyplot as plt\n",
|
24 |
+
"import numpy as np\n",
|
25 |
+
"import re\n",
|
26 |
+
"\n",
|
27 |
+
"from medrax.agent import *\n",
|
28 |
+
"from medrax.tools import *\n",
|
29 |
+
"from medrax.utils import *\n",
|
30 |
+
"\n",
|
31 |
+
"import json\n",
|
32 |
+
"import openai\n",
|
33 |
+
"import os\n",
|
34 |
+
"import glob\n",
|
35 |
+
"import time\n",
|
36 |
+
"import logging\n",
|
37 |
+
"from datetime import datetime\n",
|
38 |
+
"from tenacity import retry, wait_exponential, stop_after_attempt\n",
|
39 |
+
"\n",
|
40 |
+
"warnings.filterwarnings(\"ignore\")\n",
|
41 |
+
"_ = load_dotenv()\n",
|
42 |
+
"\n",
|
43 |
+
"\n",
|
44 |
+
"# Setup directory paths\n",
|
45 |
+
"ROOT = \"set this directory to where MedRAX is, .e.g /home/MedRAX\"\n",
|
46 |
+
"PROMPT_FILE = f\"{ROOT}/medrax/docs/system_prompts.txt\"\n",
|
47 |
+
"BENCHMARK_FILE = f\"{ROOT}/benchmark/questions\"\n",
|
48 |
+
"MODEL_DIR = f\"set this to where the tool models are, e.g /home/models\"\n",
|
49 |
+
"FIGURES_DIR = f\"{ROOT}/benchmark/figures\"\n",
|
50 |
+
"\n",
|
51 |
+
"model_name = \"medrax\"\n",
|
52 |
+
"temperature = 0.2\n",
|
53 |
+
"medrax_logs = f\"{ROOT}/experiments/medrax_logs\"\n",
|
54 |
+
"log_filename = f\"{medrax_logs}/{model_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json\"\n",
|
55 |
+
"logging.basicConfig(filename=log_filename, level=logging.INFO, format=\"%(message)s\", force=True)\n",
|
56 |
+
"device = \"cuda\""
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 2,
|
62 |
+
"metadata": {},
|
63 |
+
"outputs": [],
|
64 |
+
"source": [
|
65 |
+
"def get_tools():\n",
|
66 |
+
" report_tool = ChestXRayReportGeneratorTool(cache_dir=MODEL_DIR, device=device)\n",
|
67 |
+
" xray_classification_tool = ChestXRayClassifierTool(device=device)\n",
|
68 |
+
" segmentation_tool = ChestXRaySegmentationTool(device=device)\n",
|
69 |
+
" grounding_tool = XRayPhraseGroundingTool(\n",
|
70 |
+
" cache_dir=MODEL_DIR, temp_dir=\"temp\", device=device, load_in_8bit=True\n",
|
71 |
+
" )\n",
|
72 |
+
" xray_vqa_tool = XRayVQATool(cache_dir=MODEL_DIR, device=device)\n",
|
73 |
+
" llava_med_tool = LlavaMedTool(cache_dir=MODEL_DIR, device=device, load_in_8bit=True)\n",
|
74 |
+
"\n",
|
75 |
+
" return [\n",
|
76 |
+
" report_tool,\n",
|
77 |
+
" xray_classification_tool,\n",
|
78 |
+
" segmentation_tool,\n",
|
79 |
+
" grounding_tool,\n",
|
80 |
+
" xray_vqa_tool,\n",
|
81 |
+
" llava_med_tool,\n",
|
82 |
+
" ]\n",
|
83 |
+
"\n",
|
84 |
+
"\n",
|
85 |
+
"def get_agent(tools):\n",
|
86 |
+
" prompts = load_prompts_from_file(PROMPT_FILE)\n",
|
87 |
+
" prompt = prompts[\"MEDICAL_ASSISTANT\"]\n",
|
88 |
+
"\n",
|
89 |
+
" checkpointer = MemorySaver()\n",
|
90 |
+
" model = ChatOpenAI(model=\"gpt-4o\", temperature=temperature, top_p=0.95)\n",
|
91 |
+
" agent = Agent(\n",
|
92 |
+
" model,\n",
|
93 |
+
" tools=tools,\n",
|
94 |
+
" log_tools=True,\n",
|
95 |
+
" log_dir=\"logs\",\n",
|
96 |
+
" system_prompt=prompt,\n",
|
97 |
+
" checkpointer=checkpointer,\n",
|
98 |
+
" )\n",
|
99 |
+
" thread = {\"configurable\": {\"thread_id\": \"1\"}}\n",
|
100 |
+
" return agent, thread\n",
|
101 |
+
"\n",
|
102 |
+
"\n",
|
103 |
+
"def run_medrax(agent, thread, prompt, image_urls=[]):\n",
|
104 |
+
" messages = [\n",
|
105 |
+
" HumanMessage(\n",
|
106 |
+
" content=[\n",
|
107 |
+
" {\"type\": \"text\", \"text\": prompt},\n",
|
108 |
+
" ]\n",
|
109 |
+
" + [{\"type\": \"image_url\", \"image_url\": {\"url\": image_url}} for image_url in image_urls]\n",
|
110 |
+
" )\n",
|
111 |
+
" ]\n",
|
112 |
+
"\n",
|
113 |
+
" final_response = None\n",
|
114 |
+
" for event in agent.workflow.stream({\"messages\": messages}, thread):\n",
|
115 |
+
" for v in event.values():\n",
|
116 |
+
" final_response = v\n",
|
117 |
+
"\n",
|
118 |
+
" final_response = final_response[\"messages\"][-1].content.strip()\n",
|
119 |
+
" agent_state = agent.workflow.get_state(thread)\n",
|
120 |
+
"\n",
|
121 |
+
" return final_response, str(agent_state)"
|
122 |
+
]
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"cell_type": "code",
|
126 |
+
"execution_count": 3,
|
127 |
+
"metadata": {},
|
128 |
+
"outputs": [],
|
129 |
+
"source": [
|
130 |
+
"def create_multimodal_request(question_data, case_details, case_id, question_id, agent, thread):\n",
|
131 |
+
" # Parse required figures\n",
|
132 |
+
" try:\n",
|
133 |
+
" # Try multiple ways of parsing figures\n",
|
134 |
+
" if isinstance(question_data[\"figures\"], str):\n",
|
135 |
+
" try:\n",
|
136 |
+
" required_figures = json.loads(question_data[\"figures\"])\n",
|
137 |
+
" except json.JSONDecodeError:\n",
|
138 |
+
" required_figures = [question_data[\"figures\"]]\n",
|
139 |
+
" elif isinstance(question_data[\"figures\"], list):\n",
|
140 |
+
" required_figures = question_data[\"figures\"]\n",
|
141 |
+
" else:\n",
|
142 |
+
" required_figures = [str(question_data[\"figures\"])]\n",
|
143 |
+
" except Exception as e:\n",
|
144 |
+
" print(f\"Error parsing figures: {e}\")\n",
|
145 |
+
" required_figures = []\n",
|
146 |
+
"\n",
|
147 |
+
" # Ensure each figure starts with \"Figure \"\n",
|
148 |
+
" required_figures = [\n",
|
149 |
+
" fig if fig.startswith(\"Figure \") else f\"Figure {fig}\" for fig in required_figures\n",
|
150 |
+
" ]\n",
|
151 |
+
"\n",
|
152 |
+
" subfigures = []\n",
|
153 |
+
" for figure in required_figures:\n",
|
154 |
+
" # Handle both regular figures and those with letter suffixes\n",
|
155 |
+
" base_figure_num = \"\".join(filter(str.isdigit, figure))\n",
|
156 |
+
" figure_letter = \"\".join(filter(str.isalpha, figure.split()[-1])) or None\n",
|
157 |
+
"\n",
|
158 |
+
" # Find matching figures in case details\n",
|
159 |
+
" matching_figures = [\n",
|
160 |
+
" case_figure\n",
|
161 |
+
" for case_figure in case_details.get(\"figures\", [])\n",
|
162 |
+
" if case_figure[\"number\"] == f\"Figure {base_figure_num}\"\n",
|
163 |
+
" ]\n",
|
164 |
+
"\n",
|
165 |
+
" if not matching_figures:\n",
|
166 |
+
" print(f\"No matching figure found for {figure} in case {case_id}\")\n",
|
167 |
+
" continue\n",
|
168 |
+
"\n",
|
169 |
+
" for case_figure in matching_figures:\n",
|
170 |
+
" # If a specific letter is specified, filter subfigures\n",
|
171 |
+
" if figure_letter:\n",
|
172 |
+
" matching_subfigures = [\n",
|
173 |
+
" subfig\n",
|
174 |
+
" for subfig in case_figure.get(\"subfigures\", [])\n",
|
175 |
+
" if subfig.get(\"number\", \"\").lower().endswith(figure_letter.lower())\n",
|
176 |
+
" or subfig.get(\"label\", \"\").lower() == figure_letter.lower()\n",
|
177 |
+
" ]\n",
|
178 |
+
" subfigures.extend(matching_subfigures)\n",
|
179 |
+
" else:\n",
|
180 |
+
" # If no letter specified, add all subfigures\n",
|
181 |
+
" subfigures.extend(case_figure.get(\"subfigures\", []))\n",
|
182 |
+
"\n",
|
183 |
+
" # Add images to content\n",
|
184 |
+
" figure_prompt = \"\"\n",
|
185 |
+
" image_urls = []\n",
|
186 |
+
"\n",
|
187 |
+
" for subfig in subfigures:\n",
|
188 |
+
" if \"number\" in subfig:\n",
|
189 |
+
" subfig_number = subfig[\"number\"].lower().strip().replace(\" \", \"_\") + \".jpg\"\n",
|
190 |
+
" subfig_path = os.path.join(FIGURES_DIR, case_id, subfig_number)\n",
|
191 |
+
" figure_prompt += f\"{subfig_number} located at {subfig_path}\\n\"\n",
|
192 |
+
" if \"url\" in subfig:\n",
|
193 |
+
" image_urls.append(subfig[\"url\"])\n",
|
194 |
+
" else:\n",
|
195 |
+
" print(f\"Subfigure missing URL: {subfig}\")\n",
|
196 |
+
"\n",
|
197 |
+
" prompt = (\n",
|
198 |
+
" f\"Answer this question correctly using chain of thought reasoning and \"\n",
|
199 |
+
" \"carefully evaluating choices. Solve using our own vision and reasoning and then\"\n",
|
200 |
+
" \"use tools to complement your reasoning. Trust your own judgement over any tools.\\n\"\n",
|
201 |
+
" f\"{question_data['question']}\\n{figure_prompt}\"\n",
|
202 |
+
" )\n",
|
203 |
+
"\n",
|
204 |
+
" try:\n",
|
205 |
+
" start_time = time.time()\n",
|
206 |
+
"\n",
|
207 |
+
" final_response, agent_state = run_medrax(\n",
|
208 |
+
" agent=agent, thread=thread, prompt=prompt, image_urls=image_urls\n",
|
209 |
+
" )\n",
|
210 |
+
" model_answer, agent_state = run_medrax(\n",
|
211 |
+
" agent=agent,\n",
|
212 |
+
" thread=thread,\n",
|
213 |
+
" prompt=\"If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)\",\n",
|
214 |
+
" )\n",
|
215 |
+
" duration = time.time() - start_time\n",
|
216 |
+
"\n",
|
217 |
+
" log_entry = {\n",
|
218 |
+
" \"case_id\": case_id,\n",
|
219 |
+
" \"question_id\": question_id,\n",
|
220 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
221 |
+
" \"model\": model_name,\n",
|
222 |
+
" \"temperature\": temperature,\n",
|
223 |
+
" \"duration\": round(duration, 2),\n",
|
224 |
+
" \"usage\": \"\",\n",
|
225 |
+
" \"cost\": 0,\n",
|
226 |
+
" \"raw_response\": final_response,\n",
|
227 |
+
" \"model_answer\": model_answer.strip(),\n",
|
228 |
+
" \"correct_answer\": question_data[\"answer\"][0],\n",
|
229 |
+
" \"input\": {\n",
|
230 |
+
" \"messages\": prompt,\n",
|
231 |
+
" \"question_data\": {\n",
|
232 |
+
" \"question\": question_data[\"question\"],\n",
|
233 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
234 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
235 |
+
" \"figures\": question_data[\"figures\"],\n",
|
236 |
+
" },\n",
|
237 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
238 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
239 |
+
" },\n",
|
240 |
+
" \"agent_state\": agent_state,\n",
|
241 |
+
" }\n",
|
242 |
+
" logging.info(json.dumps(log_entry))\n",
|
243 |
+
" return final_response, model_answer.strip()\n",
|
244 |
+
"\n",
|
245 |
+
" except Exception as e:\n",
|
246 |
+
" log_entry = {\n",
|
247 |
+
" \"case_id\": case_id,\n",
|
248 |
+
" \"question_id\": question_id,\n",
|
249 |
+
" \"timestamp\": datetime.now().isoformat(),\n",
|
250 |
+
" \"model\": model_name,\n",
|
251 |
+
" \"temperature\": temperature,\n",
|
252 |
+
" \"status\": \"error\",\n",
|
253 |
+
" \"error\": str(e),\n",
|
254 |
+
" \"cost\": 0,\n",
|
255 |
+
" \"input\": {\n",
|
256 |
+
" \"messages\": prompt,\n",
|
257 |
+
" \"question_data\": {\n",
|
258 |
+
" \"question\": question_data[\"question\"],\n",
|
259 |
+
" \"explanation\": question_data[\"explanation\"],\n",
|
260 |
+
" \"metadata\": question_data.get(\"metadata\", {}),\n",
|
261 |
+
" \"figures\": question_data[\"figures\"],\n",
|
262 |
+
" },\n",
|
263 |
+
" \"image_urls\": [subfig[\"url\"] for subfig in subfigures if \"url\" in subfig],\n",
|
264 |
+
" \"image_captions\": [subfig.get(\"caption\", \"\") for subfig in subfigures],\n",
|
265 |
+
" },\n",
|
266 |
+
" }\n",
|
267 |
+
" logging.info(json.dumps(log_entry))\n",
|
268 |
+
" print(f\"Error processing case {case_id}, question {question_id}: {str(e)}\")\n",
|
269 |
+
" return \"\", \"\"\n",
|
270 |
+
"\n",
|
271 |
+
"\n",
|
272 |
+
"def load_benchmark_questions(case_id):\n",
|
273 |
+
" benchmark_dir = \"../benchmark/questions\"\n",
|
274 |
+
" return glob.glob(f\"{benchmark_dir}/{case_id}/{case_id}_*.json\")\n",
|
275 |
+
"\n",
|
276 |
+
"\n",
|
277 |
+
"def count_total_questions():\n",
|
278 |
+
" total_cases = len(glob.glob(\"../benchmark/questions/*\"))\n",
|
279 |
+
" total_questions = sum(\n",
|
280 |
+
" len(glob.glob(f\"../benchmark/questions/{case_id}/*.json\"))\n",
|
281 |
+
" for case_id in os.listdir(\"../benchmark/questions\")\n",
|
282 |
+
" )\n",
|
283 |
+
" return total_cases, total_questions\n",
|
284 |
+
"\n",
|
285 |
+
"\n",
|
286 |
+
"def main(tools):\n",
|
287 |
+
" with open(\"../data/eurorad_metadata.json\", \"r\") as file:\n",
|
288 |
+
" data = json.load(file)\n",
|
289 |
+
"\n",
|
290 |
+
" total_cases, total_questions = count_total_questions()\n",
|
291 |
+
" cases_processed = 0\n",
|
292 |
+
" questions_processed = 0\n",
|
293 |
+
" skipped_questions = 0\n",
|
294 |
+
"\n",
|
295 |
+
" print(f\"Beginning benchmark evaluation for model {model_name} with temperature {temperature}\\n\")\n",
|
296 |
+
"\n",
|
297 |
+
" for case_id, case_details in data.items():\n",
|
298 |
+
" if int(case_details[\"case_id\"]) <= 17158:\n",
|
299 |
+
" continue\n",
|
300 |
+
"\n",
|
301 |
+
" print(f\"----------------------------------------------------------------\")\n",
|
302 |
+
" agent, thread = get_agent(tools)\n",
|
303 |
+
"\n",
|
304 |
+
" question_files = load_benchmark_questions(case_id)\n",
|
305 |
+
" if not question_files:\n",
|
306 |
+
" continue\n",
|
307 |
+
"\n",
|
308 |
+
" cases_processed += 1\n",
|
309 |
+
" for question_file in question_files:\n",
|
310 |
+
" with open(question_file, \"r\") as file:\n",
|
311 |
+
" question_data = json.load(file)\n",
|
312 |
+
" question_id = os.path.basename(question_file).split(\".\")[0]\n",
|
313 |
+
"\n",
|
314 |
+
" # agent, thread = get_agent(tools)\n",
|
315 |
+
" questions_processed += 1\n",
|
316 |
+
" final_response, model_answer = create_multimodal_request(\n",
|
317 |
+
" question_data, case_details, case_id, question_id, agent, thread\n",
|
318 |
+
" )\n",
|
319 |
+
"\n",
|
320 |
+
" # Handle cases where response is None\n",
|
321 |
+
" if final_response is None:\n",
|
322 |
+
" skipped_questions += 1\n",
|
323 |
+
" print(f\"Skipped question: Case ID {case_id}, Question ID {question_id}\")\n",
|
324 |
+
" continue\n",
|
325 |
+
"\n",
|
326 |
+
" print(\n",
|
327 |
+
" f\"Progress: Case {cases_processed}/{total_cases}, Question {questions_processed}/{total_questions}\"\n",
|
328 |
+
" )\n",
|
329 |
+
" print(f\"Case ID: {case_id}\")\n",
|
330 |
+
" print(f\"Question ID: {question_id}\")\n",
|
331 |
+
" print(f\"Final Response: {final_response}\")\n",
|
332 |
+
" print(f\"Model Answer: {model_answer}\")\n",
|
333 |
+
" print(f\"Correct Answer: {question_data['answer']}\")\n",
|
334 |
+
" print(f\"----------------------------------------------------------------\\n\")\n",
|
335 |
+
"\n",
|
336 |
+
" print(f\"\\nBenchmark Summary:\")\n",
|
337 |
+
" print(f\"Total Cases Processed: {cases_processed}\")\n",
|
338 |
+
" print(f\"Total Questions Processed: {questions_processed}\")\n",
|
339 |
+
" print(f\"Total Questions Skipped: {skipped_questions}\")"
|
340 |
+
]
|
341 |
+
},
|
342 |
+
{
|
343 |
+
"cell_type": "code",
|
344 |
+
"execution_count": null,
|
345 |
+
"metadata": {},
|
346 |
+
"outputs": [],
|
347 |
+
"source": [
|
348 |
+
"tools = get_tools()\n",
|
349 |
+
"main(tools)"
|
350 |
+
]
|
351 |
+
}
|
352 |
+
],
|
353 |
+
"metadata": {
|
354 |
+
"kernelspec": {
|
355 |
+
"display_name": "medmax",
|
356 |
+
"language": "python",
|
357 |
+
"name": "python3"
|
358 |
+
},
|
359 |
+
"language_info": {
|
360 |
+
"codemirror_mode": {
|
361 |
+
"name": "ipython",
|
362 |
+
"version": 3
|
363 |
+
},
|
364 |
+
"file_extension": ".py",
|
365 |
+
"mimetype": "text/x-python",
|
366 |
+
"name": "python",
|
367 |
+
"nbconvert_exporter": "python",
|
368 |
+
"pygments_lexer": "ipython3",
|
369 |
+
"version": "3.10.16"
|
370 |
+
}
|
371 |
+
},
|
372 |
+
"nbformat": 4,
|
373 |
+
"nbformat_minor": 2
|
374 |
+
}
|
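A minimal sketch of the two-turn prompting pattern used in the notebook above (free-form reasoning first, then a single-letter verdict on the same conversation). It assumes an OpenAI-compatible client with a valid API key; the model name and question text are placeholders:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
history = [{"role": "user", "content": "<question text and figure paths go here>"}]
reasoning = client.chat.completions.create(model="gpt-4o", messages=history)
history.append({"role": "assistant", "content": reasoning.choices[0].message.content})
history.append({
    "role": "user",
    "content": "If you had to choose the best option, only respond with the letter of choice (only one of A, B, C, D, E, F)",
})
verdict = client.chat.completions.create(model="gpt-4o", messages=history)
print(verdict.choices[0].message.content.strip())  # e.g. "C"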
experiments/chexbench_gpt4.py
ADDED
@@ -0,0 +1,405 @@
1 |
+
import json
|
2 |
+
import openai
|
3 |
+
import os
|
4 |
+
from datetime import datetime
|
5 |
+
import base64
|
6 |
+
import logging
|
7 |
+
from pathlib import Path
|
8 |
+
import time
|
9 |
+
from tqdm import tqdm
|
10 |
+
from typing import Dict, List, Optional, Union, Any
|
11 |
+
|
12 |
+
# Configuration constants
|
13 |
+
DEBUG_MODE = False
|
14 |
+
OUTPUT_DIR = "results"
|
15 |
+
MODEL_NAME = "gpt-4o-2024-05-13"
|
16 |
+
TEMPERATURE = 0.2
|
17 |
+
SUBSET = "Visual Question Answering"
|
18 |
+
|
19 |
+
# Set up logging configuration
|
20 |
+
logging_level = logging.DEBUG if DEBUG_MODE else logging.INFO
|
21 |
+
logging.basicConfig(level=logging_level, format="%(asctime)s - %(levelname)s - %(message)s")
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
def get_mime_type(file_path: str) -> str:
|
26 |
+
"""
|
27 |
+
Determine MIME type based on file extension.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
file_path (str): Path to the file
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
str: MIME type string for the file
|
34 |
+
"""
|
35 |
+
extension = os.path.splitext(file_path)[1].lower()
|
36 |
+
mime_types = {
|
37 |
+
".png": "image/png",
|
38 |
+
".jpg": "image/jpeg",
|
39 |
+
".jpeg": "image/jpeg",
|
40 |
+
".gif": "image/gif",
|
41 |
+
}
|
42 |
+
return mime_types.get(extension, "application/octet-stream")
|
43 |
+
|
44 |
+
|
45 |
+
def encode_image(image_path: str) -> str:
|
46 |
+
"""
|
47 |
+
Encode image to base64 with extensive error checking.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
image_path (str): Path to the image file
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
str: Base64 encoded image string
|
54 |
+
|
55 |
+
Raises:
|
56 |
+
FileNotFoundError: If image file does not exist
|
57 |
+
ValueError: If image file is empty or too large
|
58 |
+
Exception: For other image processing errors
|
59 |
+
"""
|
60 |
+
logger.debug(f"Attempting to read image from: {image_path}")
|
61 |
+
if not os.path.exists(image_path):
|
62 |
+
raise FileNotFoundError(f"Image file not found: {image_path}")
|
63 |
+
|
64 |
+
# Add check for file size
|
65 |
+
file_size = os.path.getsize(image_path)
|
66 |
+
if file_size > 20 * 1024 * 1024: # 20MB limit
|
67 |
+
raise ValueError("Image file size exceeds 20MB limit")
|
68 |
+
if file_size == 0:
|
69 |
+
raise ValueError("Image file is empty")
|
70 |
+
logger.debug(f"Image file size: {file_size / 1024:.2f} KB")
|
71 |
+
|
72 |
+
try:
|
73 |
+
from PIL import Image
|
74 |
+
|
75 |
+
# Try to open and verify the image
|
76 |
+
with Image.open(image_path) as img:
|
77 |
+
# Get image details
|
78 |
+
width, height = img.size
|
79 |
+
format = img.format
|
80 |
+
mode = img.mode
|
81 |
+
logger.debug(
|
82 |
+
f"Image verification - Format: {format}, Size: {width}x{height}, Mode: {mode}"
|
83 |
+
)
|
84 |
+
|
85 |
+
if format not in ["PNG", "JPEG", "GIF"]:
|
86 |
+
raise ValueError(f"Unsupported image format: {format}")
|
87 |
+
|
88 |
+
with open(image_path, "rb") as image_file:
|
89 |
+
# Read the first few bytes to verify it's a valid PNG
|
90 |
+
header = image_file.read(8)
|
91 |
+
# if header != b'\x89PNG\r\n\x1a\n':
|
92 |
+
# logger.warning("File does not have a valid PNG signature")
|
93 |
+
|
94 |
+
# Reset file pointer and read entire file
|
95 |
+
image_file.seek(0)
|
96 |
+
encoded = base64.b64encode(image_file.read()).decode("utf-8")
|
97 |
+
encoded_length = len(encoded)
|
98 |
+
logger.debug(f"Base64 encoded length: {encoded_length} characters")
|
99 |
+
|
100 |
+
# Verify the encoded string is not empty and starts correctly
|
101 |
+
if encoded_length == 0:
|
102 |
+
raise ValueError("Base64 encoding produced empty string")
|
103 |
+
if not encoded.startswith("/9j/") and not encoded.startswith("iVBOR"):
|
104 |
+
logger.warning("Base64 string doesn't start with expected JPEG or PNG header")
|
105 |
+
|
106 |
+
return encoded
|
107 |
+
except Exception as e:
|
108 |
+
logger.error(f"Error reading/encoding image: {str(e)}")
|
109 |
+
raise
|
110 |
+
|
111 |
+
|
112 |
+
def create_single_request(
|
113 |
+
image_path: str, question: str, options: Dict[str, str]
|
114 |
+
) -> List[Dict[str, Any]]:
|
115 |
+
"""
|
116 |
+
Create a single API request with image and question.
|
117 |
+
|
118 |
+
Args:
|
119 |
+
image_path (str): Path to the image file
|
120 |
+
question (str): Question text
|
121 |
+
options (Dict[str, str]): Dictionary containing options with keys 'option_0' and 'option_1'
|
122 |
+
|
123 |
+
Returns:
|
124 |
+
List[Dict[str, Any]]: List of message dictionaries for the API request
|
125 |
+
|
126 |
+
Raises:
|
127 |
+
Exception: For errors in request creation
|
128 |
+
"""
|
129 |
+
if DEBUG_MODE:
|
130 |
+
logger.debug("Creating API request...")
|
131 |
+
|
132 |
+
prompt = f"""Given the following medical examination question:
|
133 |
+
Please answer this multiple choice question:
|
134 |
+
|
135 |
+
Question: {question}
|
136 |
+
|
137 |
+
Options:
|
138 |
+
A) {options['option_0']}
|
139 |
+
B) {options['option_1']}
|
140 |
+
|
141 |
+
Base your answer only on the provided image and select either A or B."""
|
142 |
+
|
143 |
+
try:
|
144 |
+
encoded_image = encode_image(image_path)
|
145 |
+
mime_type = get_mime_type(image_path)
|
146 |
+
|
147 |
+
if DEBUG_MODE:
|
148 |
+
logger.debug(f"Image encoded with MIME type: {mime_type}")
|
149 |
+
|
150 |
+
messages = [
|
151 |
+
{
|
152 |
+
"role": "system",
|
153 |
+
"content": "You are taking a medical exam. Answer ONLY with the letter (A/B) corresponding to your answer.",
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"role": "user",
|
157 |
+
"content": [
|
158 |
+
{"type": "text", "text": prompt},
|
159 |
+
{
|
160 |
+
"type": "image_url",
|
161 |
+
"image_url": {"url": f"data:{mime_type};base64,{encoded_image}"},
|
162 |
+
},
|
163 |
+
],
|
164 |
+
},
|
165 |
+
]
|
166 |
+
|
167 |
+
if DEBUG_MODE:
|
168 |
+
log_messages = json.loads(json.dumps(messages))
|
169 |
+
log_messages[1]["content"][1]["image_url"][
|
170 |
+
"url"
|
171 |
+
] = f"data:{mime_type};base64,[BASE64_IMAGE_TRUNCATED]"
|
172 |
+
logger.debug(f"Complete API request payload:\n{json.dumps(log_messages, indent=2)}")
|
173 |
+
|
174 |
+
return messages
|
175 |
+
|
176 |
+
except Exception as e:
|
177 |
+
logger.error(f"Error creating request: {str(e)}")
|
178 |
+
raise
|
179 |
+
|
180 |
+
|
181 |
+
def check_answer(model_answer: str, correct_answer: int) -> bool:
|
182 |
+
"""
|
183 |
+
Check if the model's answer matches the correct answer.
|
184 |
+
|
185 |
+
Args:
|
186 |
+
model_answer (str): The model's answer (A or B)
|
187 |
+
correct_answer (int): The correct answer index (0 for A, 1 for B)
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
bool: True if answer is correct, False otherwise
|
191 |
+
"""
|
192 |
+
if not isinstance(model_answer, str):
|
193 |
+
return False
|
194 |
+
|
195 |
+
# Clean the model answer to get just the letter
|
196 |
+
model_letter = model_answer.strip().upper()
|
197 |
+
if model_letter.startswith("A"):
|
198 |
+
model_index = 0
|
199 |
+
elif model_letter.startswith("B"):
|
200 |
+
model_index = 1
|
201 |
+
else:
|
202 |
+
return False
|
203 |
+
|
204 |
+
return model_index == correct_answer
|
205 |
+
|
206 |
+
|
207 |
+
def save_results_to_json(results: List[Dict[str, Any]], output_dir: str) -> str:
|
208 |
+
"""
|
209 |
+
Save results to a JSON file with timestamp.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
213 |
+
output_dir (str): Directory to save results
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
str: Path to the saved file
|
217 |
+
"""
|
218 |
+
Path(output_dir).mkdir(parents=True, exist_ok=True)
|
219 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
220 |
+
output_file = os.path.join(output_dir, f"batch_results_{timestamp}.json")
|
221 |
+
|
222 |
+
with open(output_file, "w") as f:
|
223 |
+
json.dump(results, f, indent=2)
|
224 |
+
|
225 |
+
logger.info(f"Batch results saved to {output_file}")
|
226 |
+
return output_file
|
227 |
+
|
228 |
+
|
229 |
+
def calculate_accuracy(results: List[Dict[str, Any]]) -> tuple[float, int, int]:
|
230 |
+
"""
|
231 |
+
Calculate accuracy from results, handling error cases.
|
232 |
+
|
233 |
+
Args:
|
234 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
235 |
+
|
236 |
+
Returns:
|
237 |
+
tuple[float, int, int]: Tuple containing (accuracy percentage, number correct, total)
|
238 |
+
"""
|
239 |
+
if not results:
|
240 |
+
return 0.0, 0, 0
|
241 |
+
|
242 |
+
total = len(results)
|
243 |
+
valid_results = [r for r in results if "output" in r]
|
244 |
+
correct = sum(
|
245 |
+
1 for result in valid_results if result.get("output", {}).get("is_correct", False)
|
246 |
+
)
|
247 |
+
|
248 |
+
accuracy = (correct / total * 100) if total > 0 else 0
|
249 |
+
return accuracy, correct, total
|
250 |
+
|
251 |
+
|
252 |
+
def calculate_batch_accuracy(results: List[Dict[str, Any]]) -> float:
|
253 |
+
"""
|
254 |
+
Calculate accuracy for the current batch.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
results (List[Dict[str, Any]]): List of result dictionaries
|
258 |
+
|
259 |
+
Returns:
|
260 |
+
float: Accuracy percentage for the batch
|
261 |
+
"""
|
262 |
+
valid_results = [r for r in results if "output" in r]
|
263 |
+
if not valid_results:
|
264 |
+
return 0.0
|
265 |
+
return sum(1 for r in valid_results if r["output"]["is_correct"]) / len(valid_results) * 100
|
266 |
+
|
267 |
+
|
268 |
+
def process_batch(
|
269 |
+
data: List[Dict[str, Any]], client: openai.OpenAI, start_idx: int = 0, batch_size: int = 50
|
270 |
+
) -> List[Dict[str, Any]]:
|
271 |
+
"""
|
272 |
+
Process a batch of examples and return results.
|
273 |
+
|
274 |
+
Args:
|
275 |
+
data (List[Dict[str, Any]]): List of data items to process
|
276 |
+
client (openai.OpenAI): OpenAI client instance
|
277 |
+
start_idx (int, optional): Starting index for batch. Defaults to 0
|
278 |
+
batch_size (int, optional): Size of batch to process. Defaults to 50
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
List[Dict[str, Any]]: List of processed results
|
282 |
+
"""
|
283 |
+
batch_results = []
|
284 |
+
end_idx = min(start_idx + batch_size, len(data))
|
285 |
+
|
286 |
+
pbar = tqdm(
|
287 |
+
range(start_idx, end_idx),
|
288 |
+
desc=f"Processing batch {start_idx//batch_size + 1}",
|
289 |
+
unit="example",
|
290 |
+
)
|
291 |
+
|
292 |
+
for index in pbar:
|
293 |
+
vqa_item = data[index]
|
294 |
+
options = {"option_0": vqa_item["option_0"], "option_1": vqa_item["option_1"]}
|
295 |
+
|
296 |
+
try:
|
297 |
+
messages = create_single_request(
|
298 |
+
image_path=vqa_item["image_path"], question=vqa_item["question"], options=options
|
299 |
+
)
|
300 |
+
|
301 |
+
response = client.chat.completions.create(
|
302 |
+
model=MODEL_NAME, messages=messages, max_tokens=50, temperature=TEMPERATURE
|
303 |
+
)
|
304 |
+
|
305 |
+
model_answer = response.choices[0].message.content.strip()
|
306 |
+
is_correct = check_answer(model_answer, vqa_item["answer"])
|
307 |
+
|
308 |
+
result = {
|
309 |
+
"timestamp": datetime.now().isoformat(),
|
310 |
+
"example_index": index,
|
311 |
+
"input": {
|
312 |
+
"question": vqa_item["question"],
|
313 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
314 |
+
"image_path": vqa_item["image_path"],
|
315 |
+
},
|
316 |
+
"output": {
|
317 |
+
"model_answer": model_answer,
|
318 |
+
"correct_answer": "A" if vqa_item["answer"] == 0 else "B",
|
319 |
+
"is_correct": is_correct,
|
320 |
+
"usage": {
|
321 |
+
"prompt_tokens": response.usage.prompt_tokens,
|
322 |
+
"completion_tokens": response.usage.completion_tokens,
|
323 |
+
"total_tokens": response.usage.total_tokens,
|
324 |
+
},
|
325 |
+
},
|
326 |
+
}
|
327 |
+
batch_results.append(result)
|
328 |
+
|
329 |
+
# Update progress bar with current accuracy
|
330 |
+
current_accuracy = calculate_batch_accuracy(batch_results)
|
331 |
+
pbar.set_description(
|
332 |
+
f"Batch {start_idx//batch_size + 1} - Accuracy: {current_accuracy:.2f}% "
|
333 |
+
f"({len(batch_results)}/{index-start_idx+1} examples)"
|
334 |
+
)
|
335 |
+
|
336 |
+
except Exception as e:
|
337 |
+
error_result = {
|
338 |
+
"timestamp": datetime.now().isoformat(),
|
339 |
+
"example_index": index,
|
340 |
+
"error": str(e),
|
341 |
+
"input": {
|
342 |
+
"question": vqa_item["question"],
|
343 |
+
"options": {"A": vqa_item["option_0"], "B": vqa_item["option_1"]},
|
344 |
+
"image_path": vqa_item["image_path"],
|
345 |
+
},
|
346 |
+
}
|
347 |
+
batch_results.append(error_result)
|
348 |
+
if DEBUG_MODE:
|
349 |
+
pbar.write(f"Error processing example {index}: {str(e)}")
|
350 |
+
|
351 |
+
time.sleep(1) # Rate limiting
|
352 |
+
|
353 |
+
return batch_results
|
354 |
+
|
355 |
+
|
356 |
+
def main() -> None:
|
357 |
+
"""
|
358 |
+
Main function to process the entire dataset.
|
359 |
+
|
360 |
+
Raises:
|
361 |
+
ValueError: If OPENAI_API_KEY is not set
|
362 |
+
Exception: For other processing errors
|
363 |
+
"""
|
364 |
+
logger.info("Starting full dataset processing...")
|
365 |
+
json_path = "../data/chexbench_updated.json"
|
366 |
+
|
367 |
+
try:
|
368 |
+
api_key = os.getenv("OPENAI_API_KEY")
|
369 |
+
if not api_key:
|
370 |
+
raise ValueError("OPENAI_API_KEY environment variable is not set.")
|
371 |
+
client = openai.OpenAI(api_key=api_key)
|
372 |
+
|
373 |
+
with open(json_path, "r") as f:
|
374 |
+
data = json.load(f)
|
375 |
+
|
376 |
+
subset_data = data[SUBSET]
|
377 |
+
total_examples = len(subset_data)
|
378 |
+
logger.info(f"Found {total_examples} examples in {SUBSET} subset")
|
379 |
+
|
380 |
+
all_results = []
|
381 |
+
batch_size = 50 # Process in batches of 50 examples
|
382 |
+
|
383 |
+
# Process all examples in batches
|
384 |
+
for start_idx in range(0, total_examples, batch_size):
|
385 |
+
batch_results = process_batch(subset_data, client, start_idx, batch_size)
|
386 |
+
all_results.extend(batch_results)
|
387 |
+
|
388 |
+
# Save intermediate results after each batch
|
389 |
+
output_file = save_results_to_json(all_results, OUTPUT_DIR)
|
390 |
+
|
391 |
+
# Calculate and log overall progress
|
392 |
+
overall_accuracy, correct, total = calculate_accuracy(all_results)
|
393 |
+
logger.info(f"Overall Progress: {len(all_results)}/{total_examples} examples processed")
|
394 |
+
logger.info(f"Current Accuracy: {overall_accuracy:.2f}% ({correct}/{total} correct)")
|
395 |
+
|
396 |
+
logger.info("Processing completed!")
|
397 |
+
logger.info(f"Final results saved to: {output_file}")
|
398 |
+
|
399 |
+
except Exception as e:
|
400 |
+
logger.error(f"Fatal error: {str(e)}")
|
401 |
+
raise
|
402 |
+
|
403 |
+
|
404 |
+
if __name__ == "__main__":
|
405 |
+
main()
|
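A minimal sketch of how the answer helpers above behave; the literal replies and the three-entry results list are hypothetical, used only to illustrate the expected normalisation:

# check_answer only inspects the leading letter of the model's reply.
assert check_answer("A", 0) is True
assert check_answer("B) Pleural effusion", 1) is True
assert check_answer("unsure", 0) is False  # no leading A/B counts as incorrect

# calculate_batch_accuracy skips error entries (no "output" key) when averaging.
results = [
    {"output": {"is_correct": True}},
    {"output": {"is_correct": False}},
    {"error": "timeout"},
]
assert calculate_batch_accuracy(results) == 50.0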
experiments/compare_runs.py
ADDED
@@ -0,0 +1,290 @@
1 |
+
import json
|
2 |
+
import argparse
|
3 |
+
import random
|
4 |
+
from typing import List, Dict, Any, Tuple
|
5 |
+
import re
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
# Define category order
|
9 |
+
CATEGORY_ORDER = [
|
10 |
+
"detection",
|
11 |
+
"classification",
|
12 |
+
"localization",
|
13 |
+
"comparison",
|
14 |
+
"relationship",
|
15 |
+
"diagnosis",
|
16 |
+
"characterization",
|
17 |
+
]
|
18 |
+
|
19 |
+
|
20 |
+
def extract_letter_answer(answer: str) -> str:
|
21 |
+
"""Extract just the letter answer from various answer formats.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
answer: The answer string to extract a letter from
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
str: The extracted letter in uppercase, or empty string if no letter found
|
28 |
+
"""
|
29 |
+
if not answer:
|
30 |
+
return ""
|
31 |
+
|
32 |
+
# Convert to string and clean
|
33 |
+
answer = str(answer).strip()
|
34 |
+
|
35 |
+
# If it's just a single letter A-F, return it
|
36 |
+
if len(answer) == 1 and answer.upper() in "ABCDEF":
|
37 |
+
return answer.upper()
|
38 |
+
|
39 |
+
# Try to match patterns like "A)", "A.", "A ", etc.
|
40 |
+
match = re.match(r"^([A-F])[).\s]", answer, re.IGNORECASE)
|
41 |
+
if match:
|
42 |
+
return match.group(1).upper()
|
43 |
+
|
44 |
+
# Try to find any standalone A-F letters preceded by space or start of string
|
45 |
+
# and followed by space, period, parenthesis or end of string
|
46 |
+
matches = re.findall(r"(?:^|\s)([A-F])(?:[).\s]|$)", answer, re.IGNORECASE)
|
47 |
+
if matches:
|
48 |
+
return matches[0].upper()
|
49 |
+
|
50 |
+
# Last resort: just find any A-F letter
|
51 |
+
letters = re.findall(r"[A-F]", answer, re.IGNORECASE)
|
52 |
+
if letters:
|
53 |
+
return letters[0].upper()
|
54 |
+
|
55 |
+
# If no letter found, return original (cleaned)
|
56 |
+
return answer.strip().upper()
|
57 |
+
|
58 |
+
|
59 |
+
def parse_json_lines(file_path: str) -> Tuple[str, List[Dict[str, Any]]]:
|
60 |
+
"""Parse JSON Lines file and extract valid predictions.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
file_path: Path to the JSON Lines file to parse
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
Tuple containing:
|
67 |
+
- str: Model name or file path if model name not found
|
68 |
+
- List[Dict[str, Any]]: List of valid prediction entries
|
69 |
+
"""
|
70 |
+
valid_predictions = []
|
71 |
+
model_name = None
|
72 |
+
|
73 |
+
# First try to parse as LLaVA format
|
74 |
+
try:
|
75 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
76 |
+
data = json.load(f)
|
77 |
+
if data.get("model") == "llava-med-v1.5-mistral-7b":
|
78 |
+
model_name = data["model"]
|
79 |
+
for result in data.get("results", []):
|
80 |
+
if all(k in result for k in ["case_id", "question_id", "correct_answer"]):
|
81 |
+
# Extract answer with priority: model_answer > validated_answer > raw_output
|
82 |
+
model_answer = (
|
83 |
+
result.get("model_answer")
|
84 |
+
or result.get("validated_answer")
|
85 |
+
or result.get("raw_output", "")
|
86 |
+
)
|
87 |
+
|
88 |
+
# Add default categories for LLaVA results
|
89 |
+
prediction = {
|
90 |
+
"case_id": result["case_id"],
|
91 |
+
"question_id": result["question_id"],
|
92 |
+
"model_answer": model_answer,
|
93 |
+
"correct_answer": result["correct_answer"],
|
94 |
+
"input": {
|
95 |
+
"question_data": {
|
96 |
+
"metadata": {
|
97 |
+
"categories": [
|
98 |
+
"detection",
|
99 |
+
"classification",
|
100 |
+
"localization",
|
101 |
+
"comparison",
|
102 |
+
"relationship",
|
103 |
+
"diagnosis",
|
104 |
+
"characterization",
|
105 |
+
]
|
106 |
+
}
|
107 |
+
}
|
108 |
+
},
|
109 |
+
}
|
110 |
+
valid_predictions.append(prediction)
|
111 |
+
return model_name, valid_predictions
|
112 |
+
except (json.JSONDecodeError, KeyError):
|
113 |
+
pass
|
114 |
+
|
115 |
+
# If not LLaVA format, process as original format
|
116 |
+
with open(file_path, "r", encoding="utf-8") as f:
|
117 |
+
for line in f:
|
118 |
+
if line.startswith("HTTP Request:"):
|
119 |
+
continue
|
120 |
+
try:
|
121 |
+
data = json.loads(line.strip())
|
122 |
+
if "model" in data:
|
123 |
+
model_name = data["model"]
|
124 |
+
if all(
|
125 |
+
k in data for k in ["model_answer", "correct_answer", "case_id", "question_id"]
|
126 |
+
):
|
127 |
+
valid_predictions.append(data)
|
128 |
+
except json.JSONDecodeError:
|
129 |
+
continue
|
130 |
+
|
131 |
+
return model_name if model_name else file_path, valid_predictions
|
132 |
+
|
133 |
+
|
134 |
+
def filter_common_questions(
|
135 |
+
predictions_list: List[List[Dict[str, Any]]]
|
136 |
+
) -> List[List[Dict[str, Any]]]:
|
137 |
+
"""Ensure only questions that exist across all models are evaluated.
|
138 |
+
|
139 |
+
Args:
|
140 |
+
predictions_list: List of prediction lists from different models
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
List[List[Dict[str, Any]]]: Filtered predictions containing only common questions
|
144 |
+
"""
|
145 |
+
question_sets = [
|
146 |
+
set((p["case_id"], p["question_id"]) for p in preds) for preds in predictions_list
|
147 |
+
]
|
148 |
+
common_questions = set.intersection(*question_sets)
|
149 |
+
|
150 |
+
return [
|
151 |
+
[p for p in preds if (p["case_id"], p["question_id"]) in common_questions]
|
152 |
+
for preds in predictions_list
|
153 |
+
]
|
154 |
+
|
155 |
+
|
156 |
+
def calculate_accuracy(
|
157 |
+
predictions: List[Dict[str, Any]]
|
158 |
+
) -> Tuple[float, int, int, Dict[str, Dict[str, float]]]:
|
159 |
+
"""Compute overall and category-level accuracy.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
predictions: List of prediction entries to analyze
|
163 |
+
|
164 |
+
Returns:
|
165 |
+
Tuple containing:
|
166 |
+
- float: Overall accuracy percentage
|
167 |
+
- int: Number of correct predictions
|
168 |
+
- int: Total number of predictions
|
169 |
+
- Dict[str, Dict[str, float]]: Category-level accuracy statistics
|
170 |
+
"""
|
171 |
+
if not predictions:
|
172 |
+
return 0.0, 0, 0, {}
|
173 |
+
|
174 |
+
category_performance = defaultdict(lambda: {"total": 0, "correct": 0})
|
175 |
+
correct = 0
|
176 |
+
total = 0
|
177 |
+
sample_size = min(5, len(predictions))
|
178 |
+
sampled_indices = random.sample(range(len(predictions)), sample_size)
|
179 |
+
|
180 |
+
print("\nSample extracted answers:")
|
181 |
+
for i in sampled_indices:
|
182 |
+
pred = predictions[i]
|
183 |
+
model_ans = extract_letter_answer(pred["model_answer"])
|
184 |
+
correct_ans = extract_letter_answer(pred["correct_answer"])
|
185 |
+
print(f"QID: {pred['question_id']}")
|
186 |
+
print(f" Raw Model Answer: {pred['model_answer']}")
|
187 |
+
print(f" Extracted Model Answer: {model_ans}")
|
188 |
+
print(f" Raw Correct Answer: {pred['correct_answer']}")
|
189 |
+
print(f" Extracted Correct Answer: {correct_ans}")
|
190 |
+
print("-" * 80)
|
191 |
+
|
192 |
+
for pred in predictions:
|
193 |
+
try:
|
194 |
+
model_ans = extract_letter_answer(pred["model_answer"])
|
195 |
+
correct_ans = extract_letter_answer(pred["correct_answer"])
|
196 |
+
categories = (
|
197 |
+
pred.get("input", {})
|
198 |
+
.get("question_data", {})
|
199 |
+
.get("metadata", {})
|
200 |
+
.get("categories", [])
|
201 |
+
)
|
202 |
+
|
203 |
+
if model_ans and correct_ans:
|
204 |
+
total += 1
|
205 |
+
is_correct = model_ans == correct_ans
|
206 |
+
if is_correct:
|
207 |
+
correct += 1
|
208 |
+
|
209 |
+
for category in categories:
|
210 |
+
category_performance[category]["total"] += 1
|
211 |
+
if is_correct:
|
212 |
+
category_performance[category]["correct"] += 1
|
213 |
+
|
214 |
+
except KeyError:
|
215 |
+
continue
|
216 |
+
|
217 |
+
category_accuracies = {
|
218 |
+
category: {
|
219 |
+
"accuracy": (stats["correct"] / stats["total"]) * 100 if stats["total"] > 0 else 0,
|
220 |
+
"total": stats["total"],
|
221 |
+
"correct": stats["correct"],
|
222 |
+
}
|
223 |
+
for category, stats in category_performance.items()
|
224 |
+
}
|
225 |
+
|
226 |
+
return (correct / total * 100 if total > 0 else 0.0, correct, total, category_accuracies)
|
227 |
+
|
228 |
+
|
229 |
+
def compare_models(file_paths: List[str]) -> None:
|
230 |
+
"""Compare accuracy between multiple model prediction files.
|
231 |
+
|
232 |
+
Args:
|
233 |
+
file_paths: List of paths to model prediction files to compare
|
234 |
+
"""
|
235 |
+
# Parse all files
|
236 |
+
parsed_results = [parse_json_lines(file_path) for file_path in file_paths]
|
237 |
+
model_names, predictions_list = zip(*parsed_results)
|
238 |
+
|
239 |
+
# Get initial stats
|
240 |
+
print(f"\n📊 **Initial Accuracy**:")
|
241 |
+
results = []
|
242 |
+
category_results = []
|
243 |
+
|
244 |
+
for preds, name in zip(predictions_list, model_names):
|
245 |
+
acc, correct, total, category_acc = calculate_accuracy(preds)
|
246 |
+
results.append((acc, correct, total, name))
|
247 |
+
category_results.append(category_acc)
|
248 |
+
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
|
249 |
+
|
250 |
+
# Get common questions across all models
|
251 |
+
filtered_predictions = filter_common_questions(predictions_list)
|
252 |
+
print(
|
253 |
+
f"\nQuestions per model after ensuring common questions: {[len(p) for p in filtered_predictions]}"
|
254 |
+
)
|
255 |
+
|
256 |
+
# Compute accuracy on common questions
|
257 |
+
print(f"\n📊 **Accuracy on Common Questions**:")
|
258 |
+
filtered_results = []
|
259 |
+
filtered_category_results = []
|
260 |
+
|
261 |
+
for preds, name in zip(filtered_predictions, model_names):
|
262 |
+
acc, correct, total, category_acc = calculate_accuracy(preds)
|
263 |
+
filtered_results.append((acc, correct, total, name))
|
264 |
+
filtered_category_results.append(category_acc)
|
265 |
+
print(f"{name}: Accuracy = {acc:.2f}% ({correct}/{total} correct)")
|
266 |
+
|
267 |
+
# Print category-wise accuracy
|
268 |
+
print("\nCategory Performance (Common Questions):")
|
269 |
+
for category in CATEGORY_ORDER:
|
270 |
+
print(f"\n{category.capitalize()}:")
|
271 |
+
for model_name, category_acc in zip(model_names, filtered_category_results):
|
272 |
+
stats = category_acc.get(category, {"accuracy": 0, "total": 0, "correct": 0})
|
273 |
+
print(f" {model_name}: {stats['accuracy']:.2f}% ({stats['correct']}/{stats['total']})")
|
274 |
+
|
275 |
+
|
276 |
+
def main():
|
277 |
+
parser = argparse.ArgumentParser(
|
278 |
+
description="Compare accuracy across multiple model prediction files"
|
279 |
+
)
|
280 |
+
parser.add_argument("files", nargs="+", help="Paths to model prediction files")
|
281 |
+
parser.add_argument("--seed", type=int, default=42, help="Random seed for sampling")
|
282 |
+
|
283 |
+
args = parser.parse_args()
|
284 |
+
random.seed(args.seed)
|
285 |
+
|
286 |
+
compare_models(args.files)
|
287 |
+
|
288 |
+
|
289 |
+
if __name__ == "__main__":
|
290 |
+
main()
|
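A small sketch of the normalisation done by extract_letter_answer, followed by the kind of invocation the argparse setup expects; the prediction file paths are hypothetical:

assert extract_letter_answer("b") == "B"
assert extract_letter_answer("C) Pneumothorax") == "C"
assert extract_letter_answer("answer: A") == "A"
# The last-resort branch picks the first A-F character anywhere in the string,
# so long free-text replies can normalise to an unintended letter.

# Compare two runs on their common questions:
#   python experiments/compare_runs.py logs/gpt4o_run.json logs/llama_run.json --seed 42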
experiments/inspect_logs.py
ADDED
@@ -0,0 +1,210 @@
1 |
+
from typing import Optional, List
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import glob
|
5 |
+
from pathlib import Path
|
6 |
+
from datetime import datetime
|
7 |
+
|
8 |
+
|
9 |
+
def get_latest_log() -> str:
|
10 |
+
"""Find the most recently modified log file in the current directory.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
str: Path to the most recently modified log file
|
14 |
+
|
15 |
+
Raises:
|
16 |
+
FileNotFoundError: If no log files are found in the current directory
|
17 |
+
"""
|
18 |
+
logs = list(Path(".").glob("api_usage_*.json"))
|
19 |
+
if not logs:
|
20 |
+
raise FileNotFoundError("No log files found in the current directory.")
|
21 |
+
return str(max(logs, key=lambda p: p.stat().st_mtime))
|
22 |
+
|
23 |
+
|
24 |
+
def format_cost(entry: dict) -> str:
|
25 |
+
"""Format cost if available, otherwise return 'N/A'
|
26 |
+
|
27 |
+
Args:
|
28 |
+
entry: Log entry dictionary containing cost information
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
str: Formatted cost string with $ and 4 decimal places, or 'N/A' if cost not found
|
32 |
+
"""
|
33 |
+
return f"${entry.get('cost', 'N/A'):.4f}" if "cost" in entry else "N/A"
|
34 |
+
|
35 |
+
|
36 |
+
def print_gpt4_entry(entry: dict) -> None:
|
37 |
+
"""Print entry for GPT-4 format
|
38 |
+
|
39 |
+
Args:
|
40 |
+
entry: Log entry dictionary in GPT-4 format containing model info, inputs and outputs
|
41 |
+
"""
|
42 |
+
print("\n=== Log Entry ===")
|
43 |
+
print(f"Model: {entry['model']}")
|
44 |
+
print(f"Case ID: {entry['case_id']}")
|
45 |
+
print(f"Question ID: {entry['question_id']}")
|
46 |
+
|
47 |
+
print("\n=== Model Input ===")
|
48 |
+
messages = entry["input"]["messages"]
|
49 |
+
print("System message:", messages[0]["content"])
|
50 |
+
user_content = messages[1]["content"]
|
51 |
+
print("\nUser prompt:", user_content[0]["text"])
|
52 |
+
print("\nImages provided:")
|
53 |
+
for content in user_content[1:]:
|
54 |
+
print(f" - {content['image_url']['url']}")
|
55 |
+
|
56 |
+
print("\n=== Model Output ===")
|
57 |
+
print(f"Answer: {entry['model_answer']}")
|
58 |
+
print(f"Correct: {entry['correct_answer']}")
|
59 |
+
|
60 |
+
print("\n=== Usage Stats ===")
|
61 |
+
print(f"Duration: {entry['duration']}s")
|
62 |
+
print(f"Cost: {format_cost(entry)}")
|
63 |
+
print(
|
64 |
+
f"Tokens: {entry['usage']['total_tokens']}",
|
65 |
+
f"(prompt: {entry['usage']['prompt_tokens']},",
|
66 |
+
f"completion: {entry['usage']['completion_tokens']})",
|
67 |
+
)
|
68 |
+
|
69 |
+
|
70 |
+
def print_llama_entry(entry: dict) -> None:
|
71 |
+
"""Print entry for Llama-3.2 format
|
72 |
+
|
73 |
+
Args:
|
74 |
+
entry: Log entry dictionary in Llama format containing model info, inputs and outputs
|
75 |
+
"""
|
76 |
+
print("\n=== Log Entry ===")
|
77 |
+
print(f"Model: {entry['model']}")
|
78 |
+
print(f"Case ID: {entry['case_id']}")
|
79 |
+
print(f"Question ID: {entry['question_id']}")
|
80 |
+
|
81 |
+
print("\n=== Model Input ===")
|
82 |
+
print(f"Question: {entry['input']['question_data']['question']}")
|
83 |
+
print("\nImages provided:")
|
84 |
+
for url in entry["input"]["image_urls"]:
|
85 |
+
print(f" - {url}")
|
86 |
+
if entry["input"]["image_captions"]:
|
87 |
+
print("\nImage captions:")
|
88 |
+
for caption in entry["input"]["image_captions"]:
|
89 |
+
if caption:
|
90 |
+
print(f" - {caption}")
|
91 |
+
|
92 |
+
print("\n=== Model Output ===")
|
93 |
+
print(f"Answer: {entry['model_answer']}")
|
94 |
+
print(f"Correct: {entry['correct_answer']}")
|
95 |
+
|
96 |
+
print("\n=== Usage Stats ===")
|
97 |
+
print(f"Duration: {entry['duration']}s")
|
98 |
+
if "usage" in entry:
|
99 |
+
print(
|
100 |
+
f"Tokens: {entry['usage']['total_tokens']}",
|
101 |
+
f"(prompt: {entry['usage']['prompt_tokens']},",
|
102 |
+
f"completion: {entry['usage']['completion_tokens']})",
|
103 |
+
)
|
104 |
+
|
105 |
+
|
106 |
+
def determine_model_type(entry: dict) -> str:
|
107 |
+
"""Determine the model type from the entry
|
108 |
+
|
109 |
+
Args:
|
110 |
+
entry: Log entry dictionary containing model information
|
111 |
+
|
112 |
+
Returns:
|
113 |
+
str: Model type - 'gpt4', 'llama', 'chexagent', 'medrax', or 'unknown'
|
114 |
+
"""
|
115 |
+
model = entry.get("model", "").lower()
|
116 |
+
if "gpt-4" in model:
|
117 |
+
return "gpt4"
|
118 |
+
elif "llama" in model:
|
119 |
+
return "llama"
|
120 |
+
elif "chexagent" in model:
|
121 |
+
return "chexagent"
|
122 |
+
elif "medrax" in model:
|
123 |
+
return "medrax"
|
124 |
+
else:
|
125 |
+
return "unknown"
|
126 |
+
|
127 |
+
|
128 |
+
def print_log_entry(
|
129 |
+
log_file: Optional[str] = None,
|
130 |
+
num_entries: Optional[int] = None,
|
131 |
+
model_filter: Optional[str] = None,
|
132 |
+
) -> None:
|
133 |
+
"""Print log entries from the specified log file or the latest log file.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
log_file: Path to the log file. If None, uses the latest log file.
|
137 |
+
num_entries: Number of entries to print. If None, prints all entries.
|
138 |
+
model_filter: Filter entries by model type ('gpt4' or 'llama'). If None, prints all.
|
139 |
+
"""
|
140 |
+
if log_file is None:
|
141 |
+
log_file = get_latest_log()
|
142 |
+
print(f"Using latest log file: {log_file}")
|
143 |
+
|
144 |
+
entries_printed = 0
|
145 |
+
total_entries = 0
|
146 |
+
filtered_entries = 0
|
147 |
+
|
148 |
+
with open(log_file, "r") as f:
|
149 |
+
for line in f:
|
150 |
+
if line.startswith("HTTP"):
|
151 |
+
continue
|
152 |
+
try:
|
153 |
+
total_entries += 1
|
154 |
+
entry = json.loads(line)
|
155 |
+
|
156 |
+
# Apply model filter if specified
|
157 |
+
model_type = determine_model_type(entry)
|
158 |
+
if model_filter and model_type != model_filter:
|
159 |
+
filtered_entries += 1
|
160 |
+
continue
|
161 |
+
|
162 |
+
if model_type == "gpt4":
|
163 |
+
print_gpt4_entry(entry)
|
164 |
+
elif model_type == "llama":
|
165 |
+
print_llama_entry(entry)
|
166 |
+
else:
|
167 |
+
print(f"Unknown model type in entry: {entry['model']}")
|
168 |
+
continue
|
169 |
+
|
170 |
+
print("=" * 50)
|
171 |
+
entries_printed += 1
|
172 |
+
if num_entries and entries_printed >= num_entries:
|
173 |
+
break
|
174 |
+
|
175 |
+
except (json.JSONDecodeError, KeyError) as e:
|
176 |
+
print(f"Error processing entry: {e}")
|
177 |
+
continue
|
178 |
+
|
179 |
+
print(f"\nSummary:")
|
180 |
+
print(f"Total entries: {total_entries}")
|
181 |
+
print(f"Entries printed: {entries_printed}")
|
182 |
+
if model_filter:
|
183 |
+
print(f"Entries filtered: {filtered_entries}")
|
184 |
+
|
185 |
+
|
186 |
+
def main() -> None:
|
187 |
+
"""Main entry point for the script"""
|
188 |
+
parser = argparse.ArgumentParser(
|
189 |
+
description="Parse and display log entries from API usage logs."
|
190 |
+
)
|
191 |
+
parser.add_argument("-l", "--log_file", nargs="?", help="Path to the log file (optional)")
|
192 |
+
parser.add_argument("-n", "--num_entries", type=int, help="Number of entries to display")
|
193 |
+
parser.add_argument(
|
194 |
+
"-m",
|
195 |
+
"--model",
|
196 |
+
choices=["gpt4", "llama"],
|
197 |
+
default="gpt4",
|
198 |
+
help="Model type to display (default: gpt4)",
|
199 |
+
)
|
200 |
+
args = parser.parse_args()
|
201 |
+
|
202 |
+
try:
|
203 |
+
print_log_entry(args.log_file, args.num_entries, args.model)
|
204 |
+
except FileNotFoundError as e:
|
205 |
+
print(f"Error: {e}")
|
206 |
+
exit(1)
|
207 |
+
|
208 |
+
|
209 |
+
if __name__ == "__main__":
|
210 |
+
main()
|
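Typical ways to drive the inspector above, assuming api_usage_*.json logs in the working directory; the file name and entry count are hypothetical:

#   python experiments/inspect_logs.py -n 5 -m gpt4
#   python experiments/inspect_logs.py -l api_usage_20250101.json -m llama
# Or programmatically, using the newest log file:
print_log_entry(log_file=None, num_entries=5, model_filter="gpt4")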
experiments/validate_logs.py
ADDED
@@ -0,0 +1,162 @@
1 |
+
from typing import Dict, List, Tuple, Optional
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import glob
|
5 |
+
from pathlib import Path
|
6 |
+
from collections import defaultdict
|
7 |
+
|
8 |
+
|
9 |
+
def get_latest_log() -> str:
|
10 |
+
"""Find the most recently modified log file in the current directory.
|
11 |
+
|
12 |
+
Returns:
|
13 |
+
str: Path to the most recently modified log file
|
14 |
+
|
15 |
+
Raises:
|
16 |
+
SystemExit: If no log files are found in current directory
|
17 |
+
"""
|
18 |
+
log_pattern = "api_usage_*.json"
|
19 |
+
logs = list(Path(".").glob(log_pattern))
|
20 |
+
if not logs:
|
21 |
+
print(f"No files matching pattern '{log_pattern}' found in current directory")
|
22 |
+
sys.exit(1)
|
23 |
+
return str(max(logs, key=lambda p: p.stat().st_mtime))
|
24 |
+
|
25 |
+
|
26 |
+
def analyze_log_file(filename: str) -> Tuple[List[Dict], List[Dict], Dict[str, List[str]]]:
|
27 |
+
"""Analyze a log file for entries missing images and errors.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
filename: Path to the log file to analyze
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
Tuple containing:
|
34 |
+
- List of entries with no images
|
35 |
+
- List of skipped/error entries
|
36 |
+
- Dict of processing errors by type
|
37 |
+
|
38 |
+
Raises:
|
39 |
+
SystemExit: If file cannot be found or read
|
40 |
+
"""
|
41 |
+
no_images = []
|
42 |
+
errors = defaultdict(list)
|
43 |
+
skipped = []
|
44 |
+
|
45 |
+
try:
|
46 |
+
with open(filename, "r") as f:
|
47 |
+
for line_num, line in enumerate(f, 1):
|
48 |
+
# Skip HTTP request logs
|
49 |
+
if line.startswith("HTTP Request:") or line.strip() == "":
|
50 |
+
continue
|
51 |
+
try:
|
52 |
+
# Try to parse the JSON line
|
53 |
+
if not line.strip().startswith("{"):
|
54 |
+
continue
|
55 |
+
entry = json.loads(line.strip())
|
56 |
+
case_id = entry.get("case_id")
|
57 |
+
question_id = entry.get("question_id")
|
58 |
+
|
59 |
+
# Skip if we can't identify the question
|
60 |
+
if not case_id or not question_id:
|
61 |
+
continue
|
62 |
+
|
63 |
+
# Check for explicit skip/error status
|
64 |
+
if entry.get("status") in ["skipped", "error"]:
|
65 |
+
skipped.append(
|
66 |
+
{
|
67 |
+
"case_id": case_id,
|
68 |
+
"question_id": question_id,
|
69 |
+
"reason": entry.get("reason"),
|
70 |
+
"status": entry.get("status"),
|
71 |
+
}
|
72 |
+
)
|
73 |
+
continue
|
74 |
+
|
75 |
+
# Check user content for images
|
76 |
+
messages = entry.get("input", {}).get("messages", [])
|
77 |
+
has_image = False
|
78 |
+
for msg in messages:
|
79 |
+
content = msg.get("content", [])
|
80 |
+
if isinstance(content, list):
|
81 |
+
for item in content:
|
82 |
+
if isinstance(item, dict) and item.get("type") == "image_url":
|
83 |
+
has_image = True
|
84 |
+
break
|
85 |
+
if not has_image:
|
86 |
+
no_images.append(
|
87 |
+
{
|
88 |
+
"case_id": case_id,
|
89 |
+
"question_id": question_id,
|
90 |
+
"question": entry.get("input", {})
|
91 |
+
.get("question_data", {})
|
92 |
+
.get("question", "")[:100]
|
93 |
+
+ "...", # First 100 chars of question
|
94 |
+
}
|
95 |
+
)
|
96 |
+
except json.JSONDecodeError:
|
97 |
+
errors["json_decode"].append(f"Line {line_num}: Invalid JSON")
|
98 |
+
continue
|
99 |
+
except Exception as e:
|
100 |
+
errors["other"].append(f"Line {line_num}: Error processing entry: {str(e)}")
|
101 |
+
except FileNotFoundError:
|
102 |
+
print(f"Error: Could not find log file: {filename}")
|
103 |
+
sys.exit(1)
|
104 |
+
except Exception as e:
|
105 |
+
print(f"Error reading file {filename}: {str(e)}")
|
106 |
+
sys.exit(1)
|
107 |
+
|
108 |
+
return no_images, skipped, errors
|
109 |
+
|
110 |
+
|
111 |
+
def print_results(
|
112 |
+
filename: str, no_images: List[Dict], skipped: List[Dict], errors: Dict[str, List[str]]
|
113 |
+
) -> None:
|
114 |
+
"""Print analysis results.
|
115 |
+
|
116 |
+
Args:
|
117 |
+
filename: Name of the analyzed log file
|
118 |
+
no_images: List of entries with no images
|
119 |
+
skipped: List of skipped/error entries
|
120 |
+
errors: Dict of processing errors by type
|
121 |
+
"""
|
122 |
+
print(f"\nAnalyzing log file: {filename}")
|
123 |
+
print("\n=== Questions with No Images ===")
|
124 |
+
if no_images:
|
125 |
+
for entry in no_images:
|
126 |
+
print(f"\nCase ID: {entry['case_id']}")
|
127 |
+
print(f"Question ID: {entry['question_id']}")
|
128 |
+
print(f"Question Preview: {entry['question']}")
|
129 |
+
print(f"\nTotal questions without images: {len(no_images)}")
|
130 |
+
|
131 |
+
print("\n=== Skipped/Error Questions ===")
|
132 |
+
if skipped:
|
133 |
+
for entry in skipped:
|
134 |
+
print(f"\nCase ID: {entry['case_id']}")
|
135 |
+
print(f"Question ID: {entry['question_id']}")
|
136 |
+
print(f"Status: {entry['status']}")
|
137 |
+
print(f"Reason: {entry.get('reason', 'unknown')}")
|
138 |
+
print(f"\nTotal skipped/error questions: {len(skipped)}")
|
139 |
+
|
140 |
+
if errors:
|
141 |
+
print("\n=== Processing Errors ===")
|
142 |
+
for error_type, messages in errors.items():
|
143 |
+
if messages:
|
144 |
+
print(f"\n{error_type}:")
|
145 |
+
for msg in messages:
|
146 |
+
print(f" {msg}")
|
147 |
+
|
148 |
+
|
149 |
+
def main() -> None:
|
150 |
+
"""Main entry point for log validation script."""
|
151 |
+
# If a file is specified as an argument, use it; otherwise find the latest log
|
152 |
+
if len(sys.argv) > 1:
|
153 |
+
log_file = sys.argv[1]
|
154 |
+
else:
|
155 |
+
log_file = get_latest_log()
|
156 |
+
|
157 |
+
no_images, skipped, errors = analyze_log_file(log_file)
|
158 |
+
print_results(log_file, no_images, skipped, errors)
|
159 |
+
|
160 |
+
|
161 |
+
if __name__ == "__main__":
|
162 |
+
main()
|
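The validator is run the same way, against either the newest api_usage_*.json or an explicit file; the path below is hypothetical:

#   python experiments/validate_logs.py
#   python experiments/validate_logs.py api_usage_run1.json
no_images, skipped, errors = analyze_log_file("api_usage_run1.json")
# no_images: requests without an image_url part; skipped: entries marked skipped/error;
# errors: lines that failed to parse as JSON or raised while being processed.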
interface.py
ADDED
@@ -0,0 +1,279 @@
1 |
+
import re
|
2 |
+
import base64
|
3 |
+
import gradio as gr
|
4 |
+
from pathlib import Path
|
5 |
+
import time
|
6 |
+
import shutil
|
7 |
+
from typing import AsyncGenerator, List, Optional, Tuple
|
8 |
+
from gradio import ChatMessage
|
9 |
+
|
10 |
+
|
11 |
+
class ChatInterface:
|
12 |
+
"""
|
13 |
+
A chat interface for interacting with a medical AI agent through Gradio.
|
14 |
+
|
15 |
+
Handles file uploads, message processing, and chat history management.
|
16 |
+
Supports both regular image files and DICOM medical imaging files.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, agent, tools_dict):
|
20 |
+
"""
|
21 |
+
Initialize the chat interface.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
agent: The medical AI agent to handle requests
|
25 |
+
tools_dict (dict): Dictionary of available tools for image processing
|
26 |
+
"""
|
27 |
+
self.agent = agent
|
28 |
+
self.tools_dict = tools_dict
|
29 |
+
self.upload_dir = Path("temp")
|
30 |
+
self.upload_dir.mkdir(exist_ok=True)
|
31 |
+
self.current_thread_id = None
|
32 |
+
# Separate storage for original and display paths
|
33 |
+
self.original_file_path = None # For LLM (.dcm or other)
|
34 |
+
self.display_file_path = None # For UI (always viewable format)
|
35 |
+
|
36 |
+
def handle_upload(self, file_path: str) -> str:
|
37 |
+
"""
|
38 |
+
Handle new file upload and set appropriate paths.
|
39 |
+
|
40 |
+
Args:
|
41 |
+
file_path (str): Path to the uploaded file
|
42 |
+
|
43 |
+
Returns:
|
44 |
+
str: Display path for UI, or None if no file uploaded
|
45 |
+
"""
|
46 |
+
if not file_path:
|
47 |
+
return None
|
48 |
+
|
49 |
+
source = Path(file_path)
|
50 |
+
timestamp = int(time.time())
|
51 |
+
|
52 |
+
# Save original file with proper suffix
|
53 |
+
suffix = source.suffix.lower()
|
54 |
+
saved_path = self.upload_dir / f"upload_{timestamp}{suffix}"
|
55 |
+
shutil.copy2(file_path, saved_path) # Use file_path directly instead of source
|
56 |
+
self.original_file_path = str(saved_path)
|
57 |
+
|
58 |
+
# Handle DICOM conversion for display only
|
59 |
+
if suffix == ".dcm":
|
60 |
+
output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
|
61 |
+
self.display_file_path = output["image_path"]
|
62 |
+
else:
|
63 |
+
self.display_file_path = str(saved_path)
|
64 |
+
|
65 |
+
return self.display_file_path
|
66 |
+
|
67 |
+
def add_message(
|
68 |
+
self, message: str, display_image: str, history: List[dict]
|
69 |
+
) -> Tuple[List[dict], gr.Textbox]:
|
70 |
+
"""
|
71 |
+
Add a new message to the chat history.
|
72 |
+
|
73 |
+
Args:
|
74 |
+
message (str): Text message to add
|
75 |
+
display_image (str): Path to image being displayed
|
76 |
+
history (List[dict]): Current chat history
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
Tuple[List[dict], gr.Textbox]: Updated history and textbox component
|
80 |
+
"""
|
81 |
+
image_path = self.original_file_path or display_image
|
82 |
+
if image_path is not None:
|
83 |
+
history.append({"role": "user", "content": {"path": image_path}})
|
84 |
+
if message is not None:
|
85 |
+
history.append({"role": "user", "content": message})
|
86 |
+
return history, gr.Textbox(value=message, interactive=False)
|
87 |
+
|
88 |
+
async def process_message(
|
89 |
+
self, message: str, display_image: Optional[str], chat_history: List[ChatMessage]
|
90 |
+
) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
|
91 |
+
"""
|
92 |
+
Process a message and generate responses.
|
93 |
+
|
94 |
+
Args:
|
95 |
+
message (str): User message to process
|
96 |
+
display_image (Optional[str]): Path to currently displayed image
|
97 |
+
chat_history (List[ChatMessage]): Current chat history
|
98 |
+
|
99 |
+
Yields:
|
100 |
+
Tuple[List[ChatMessage], Optional[str], str]: Updated chat history, display path, and empty string
|
101 |
+
"""
|
102 |
+
chat_history = chat_history or []
|
103 |
+
|
104 |
+
# Initialize thread if needed
|
105 |
+
if not self.current_thread_id:
|
106 |
+
self.current_thread_id = str(time.time())
|
107 |
+
|
108 |
+
messages = []
|
109 |
+
image_path = self.original_file_path or display_image
|
110 |
+
|
111 |
+
if image_path is not None:
|
112 |
+
# Send path for tools
|
113 |
+
messages.append({"role": "user", "content": f"image_path: {image_path}"})
|
114 |
+
|
115 |
+
# Load and encode image for multimodal
|
116 |
+
with open(image_path, "rb") as img_file:
|
117 |
+
img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
|
118 |
+
|
119 |
+
messages.append(
|
120 |
+
{
|
121 |
+
"role": "user",
|
122 |
+
"content": [
|
123 |
+
{
|
124 |
+
"type": "image_url",
|
125 |
+
"image_url": {"url": f"data:image/jpeg;base64,{img_base64}"},
|
126 |
+
}
|
127 |
+
],
|
128 |
+
}
|
129 |
+
)
|
130 |
+
|
131 |
+
if message is not None:
|
132 |
+
messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
|
133 |
+
|
134 |
+
try:
|
135 |
+
for event in self.agent.workflow.stream(
|
136 |
+
{"messages": messages}, {"configurable": {"thread_id": self.current_thread_id}}
|
137 |
+
):
|
138 |
+
if isinstance(event, dict):
|
139 |
+
if "process" in event:
|
140 |
+
content = event["process"]["messages"][-1].content
|
141 |
+
if content:
|
142 |
+
content = re.sub(r"temp/[^\s]*", "", content)
|
143 |
+
chat_history.append(ChatMessage(role="assistant", content=content))
|
144 |
+
yield chat_history, self.display_file_path, ""
|
145 |
+
|
146 |
+
elif "execute" in event:
|
147 |
+
for message in event["execute"]["messages"]:
|
148 |
+
tool_name = message.name
|
149 |
+
tool_result = eval(message.content)[0]
|
150 |
+
|
151 |
+
if tool_result:
|
152 |
+
metadata = {"title": f"🖼️ Image from tool: {tool_name}"}
|
153 |
+
formatted_result = " ".join(
|
154 |
+
line.strip() for line in str(tool_result).splitlines()
|
155 |
+
).strip()
|
156 |
+
metadata["description"] = formatted_result
|
157 |
+
chat_history.append(
|
158 |
+
ChatMessage(
|
159 |
+
role="assistant",
|
160 |
+
content=formatted_result,
|
161 |
+
metadata=metadata,
|
162 |
+
)
|
163 |
+
)
|
164 |
+
|
165 |
+
# For image_visualizer, use display path
|
166 |
+
if tool_name == "image_visualizer":
|
167 |
+
self.display_file_path = tool_result["image_path"]
|
168 |
+
chat_history.append(
|
169 |
+
ChatMessage(
|
170 |
+
role="assistant",
|
171 |
+
# content=gr.Image(value=self.display_file_path),
|
172 |
+
content={"path": self.display_file_path},
|
173 |
+
)
|
174 |
+
)
|
175 |
+
|
176 |
+
yield chat_history, self.display_file_path, ""
|
177 |
+
|
178 |
+
except Exception as e:
|
179 |
+
chat_history.append(
|
180 |
+
ChatMessage(
|
181 |
+
role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"}
|
182 |
+
)
|
183 |
+
)
|
184 |
+
yield chat_history, self.display_file_path, ""
|
185 |
+
|
186 |
+
|
187 |
+
def create_demo(agent, tools_dict):
|
188 |
+
"""
|
189 |
+
Create a Gradio demo interface for the medical AI agent.
|
190 |
+
|
191 |
+
Args:
|
192 |
+
agent: The medical AI agent to handle requests
|
193 |
+
tools_dict (dict): Dictionary of available tools for image processing
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
gr.Blocks: Gradio Blocks interface
|
197 |
+
"""
|
198 |
+
interface = ChatInterface(agent, tools_dict)
|
199 |
+
|
200 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
201 |
+
with gr.Column():
|
202 |
+
gr.Markdown(
|
203 |
+
"""
|
204 |
+
# 🏥 MedRAX
|
205 |
+
Medical Reasoning Agent for Chest X-ray
|
206 |
+
"""
|
207 |
+
)
|
208 |
+
|
209 |
+
with gr.Row():
|
210 |
+
with gr.Column(scale=3):
|
211 |
+
chatbot = gr.Chatbot(
|
212 |
+
[],
|
213 |
+
height=800,
|
214 |
+
container=True,
|
215 |
+
show_label=True,
|
216 |
+
elem_classes="chat-box",
|
217 |
+
type="messages",
|
218 |
+
label="Agent",
|
219 |
+
avatar_images=(
|
220 |
+
None,
|
221 |
+
"assets/medrax_logo.jpg",
|
222 |
+
),
|
223 |
+
)
|
224 |
+
with gr.Row():
|
225 |
+
with gr.Column(scale=3):
|
226 |
+
txt = gr.Textbox(
|
227 |
+
show_label=False,
|
228 |
+
placeholder="Ask about the X-ray...",
|
229 |
+
container=False,
|
230 |
+
)
|
231 |
+
|
232 |
+
with gr.Column(scale=3):
|
233 |
+
image_display = gr.Image(
|
234 |
+
label="Image", type="filepath", height=700, container=True
|
235 |
+
)
|
236 |
+
with gr.Row():
|
237 |
+
upload_button = gr.UploadButton(
|
238 |
+
"📎 Upload X-Ray",
|
239 |
+
file_types=["image"],
|
240 |
+
)
|
241 |
+
dicom_upload = gr.UploadButton(
|
242 |
+
"📄 Upload DICOM",
|
243 |
+
file_types=["file"],
|
244 |
+
)
|
245 |
+
with gr.Row():
|
246 |
+
clear_btn = gr.Button("Clear Chat")
|
247 |
+
new_thread_btn = gr.Button("New Thread")
|
248 |
+
|
249 |
+
# Event handlers
|
250 |
+
def clear_chat():
|
251 |
+
interface.original_file_path = None
|
252 |
+
interface.display_file_path = None
|
253 |
+
return [], None
|
254 |
+
|
255 |
+
def new_thread():
|
256 |
+
interface.current_thread_id = str(time.time())
|
257 |
+
return [], interface.display_file_path
|
258 |
+
|
259 |
+
def handle_file_upload(file):
|
260 |
+
return interface.handle_upload(file.name)
|
261 |
+
|
262 |
+
chat_msg = txt.submit(
|
263 |
+
interface.add_message, inputs=[txt, image_display, chatbot], outputs=[chatbot, txt]
|
264 |
+
)
|
265 |
+
bot_msg = chat_msg.then(
|
266 |
+
interface.process_message,
|
267 |
+
inputs=[txt, image_display, chatbot],
|
268 |
+
outputs=[chatbot, image_display, txt],
|
269 |
+
)
|
270 |
+
bot_msg.then(lambda: gr.Textbox(interactive=True), None, [txt])
|
271 |
+
|
272 |
+
upload_button.upload(handle_file_upload, inputs=upload_button, outputs=image_display)
|
273 |
+
|
274 |
+
dicom_upload.upload(handle_file_upload, inputs=dicom_upload, outputs=image_display)
|
275 |
+
|
276 |
+
clear_btn.click(clear_chat, outputs=[chatbot, image_display])
|
277 |
+
new_thread_btn.click(new_thread, outputs=[chatbot, image_display])
|
278 |
+
|
279 |
+
return demo
|
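A short sketch of how ChatInterface tracks the two paths per upload; the DICOM file name is hypothetical, and agent/tools_dict are the objects produced by initialize_agent in main.py:

ui = ChatInterface(agent, tools_dict)
display_path = ui.handle_upload("study_001.dcm")
# ui.original_file_path -> temp/upload_<timestamp>.dcm, sent to tools as "image_path: ..."
# display_path          -> viewable image written by DicomProcessorTool, shown in gr.Image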
main.py
ADDED
@@ -0,0 +1,141 @@
1 |
+
import os
|
2 |
+
import warnings
|
3 |
+
from typing import *
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from transformers import logging
|
6 |
+
|
7 |
+
from langgraph.checkpoint.memory import MemorySaver
|
8 |
+
from langchain_openai import ChatOpenAI
|
11 |
+
|
12 |
+
from interface import create_demo
|
13 |
+
from medrax.agent import *
|
14 |
+
from medrax.tools import *
|
15 |
+
from medrax.utils import *
|
16 |
+
|
17 |
+
warnings.filterwarnings("ignore")
|
18 |
+
logging.set_verbosity_error()
|
19 |
+
_ = load_dotenv()
|
20 |
+
|
21 |
+
|
22 |
+
def initialize_agent(
|
23 |
+
prompt_file,
|
24 |
+
tools_to_use=None,
|
25 |
+
model_dir="/model-weights",
|
26 |
+
temp_dir="temp",
|
27 |
+
device="cuda",
|
28 |
+
model="google/gemini-2.5-pro-exp-03-25:free",
|
29 |
+
temperature=0.7,
|
30 |
+
top_p=0.95,
|
31 |
+
openai_kwargs=None,
):
|
34 |
+
"""Initialize the MedRAX agent with specified tools and configuration.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
prompt_file (str): Path to file containing system prompts
|
38 |
+
tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
|
39 |
+
model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
|
40 |
+
temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
|
41 |
+
device (str, optional): Device to run models on. Defaults to "cuda".
|
42 |
+
model (str, optional): Model to use. Defaults to "google/gemini-2.5-pro-exp-03-25:free".
|
43 |
+
temperature (float, optional): Temperature for the model. Defaults to 0.7.
|
44 |
+
top_p (float, optional): Top P for the model. Defaults to 0.95.
|
45 |
+
openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
|
49 |
+
"""
|
50 |
+
prompts = load_prompts_from_file(prompt_file)
|
51 |
+
prompt = prompts["MEDICAL_ASSISTANT"]
|
52 |
+
|
53 |
+
all_tools = {
|
54 |
+
"ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
|
55 |
+
"ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
|
56 |
+
"LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
|
57 |
+
"XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
|
58 |
+
"ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
|
59 |
+
cache_dir=model_dir, device=device
|
60 |
+
),
|
61 |
+
"XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
|
62 |
+
cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
|
63 |
+
),
|
64 |
+
"ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
|
65 |
+
model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
|
66 |
+
),
|
67 |
+
"ImageVisualizerTool": lambda: ImageVisualizerTool(),
|
68 |
+
"DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
|
69 |
+
}
|
70 |
+
|
71 |
+
# Initialize only selected tools or all if none specified
|
72 |
+
tools_dict = {}
|
73 |
+
tools_to_use = tools_to_use or all_tools.keys()
|
74 |
+
for tool_name in tools_to_use:
|
75 |
+
if tool_name in all_tools:
|
76 |
+
tools_dict[tool_name] = all_tools[tool_name]()
|
77 |
+
|
78 |
+
checkpointer = MemorySaver()
|
79 |
+
model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **(openai_kwargs or {}))
|
80 |
+
agent = Agent(
|
81 |
+
model,
|
82 |
+
tools=list(tools_dict.values()),
|
83 |
+
log_tools=True,
|
84 |
+
log_dir="logs",
|
85 |
+
system_prompt=prompt,
|
86 |
+
checkpointer=checkpointer,
|
87 |
+
)
|
88 |
+
|
89 |
+
print("Agent initialized")
|
90 |
+
return agent, tools_dict
|
91 |
+
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
"""
|
95 |
+
This is the main entry point for the MedRAX application.
|
96 |
+
It initializes the agent with the selected tools and creates the demo.
|
97 |
+
"""
|
98 |
+
print("Starting server...")
|
99 |
+
|
100 |
+
# Example: initialize with only specific tools
|
101 |
+
# Here three tools are commented out, you can uncomment them to use them
|
102 |
+
selected_tools = [
|
103 |
+
"ImageVisualizerTool",
|
104 |
+
"DicomProcessorTool",
|
105 |
+
"ChestXRayClassifierTool",
|
106 |
+
"ChestXRaySegmentationTool",
|
107 |
+
"ChestXRayReportGeneratorTool",
|
108 |
+
"XRayVQATool",
|
109 |
+
# "LlavaMedTool",
|
110 |
+
# "XRayPhraseGroundingTool",
|
111 |
+
# "ChestXRayGeneratorTool",
|
112 |
+
]
|
113 |
+
|
114 |
+
# Collect the ENV variables
|
115 |
+
openai_kwargs = {}
|
116 |
+
if api_key := os.getenv("OPENAI_API_KEY"):
|
117 |
+
openai_kwargs["api_key"] = api_key
|
118 |
+
|
119 |
+
if base_url := os.getenv("OPENAI_BASE_URL"):
|
120 |
+
openai_kwargs["base_url"] = base_url
|
121 |
+
|
122 |
+
# openai_kwargs = {
|
123 |
+
# "openai_api_key": os.environ.get("OPENAI_API_KEY"),
|
124 |
+
# "openai_api_base": os.environ.get("OPENAI_BASE_URL"),
|
125 |
+
# }
|
126 |
+
|
127 |
+
|
128 |
+
agent, tools_dict = initialize_agent(
|
129 |
+
"medrax/docs/system_prompts.txt",
|
130 |
+
tools_to_use=selected_tools,
|
131 |
+
model_dir="/model-weights", # Change this to the path of the model weights
|
132 |
+
temp_dir="temp", # Change this to the path of the temporary directory
|
133 |
+
device="cuda", # Change this to the device you want to use
|
134 |
+
model="gpt-4o", # Change this to the model you want to use, e.g. gpt-4o-mini
|
135 |
+
temperature=0.7,
|
136 |
+
top_p=0.95,
|
137 |
+
openai_kwargs=openai_kwargs
|
138 |
+
)
|
139 |
+
demo = create_demo(agent, tools_dict)
|
140 |
+
|
141 |
+
demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
|
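main.py pulls credentials through load_dotenv(); a minimal .env might look like the sketch below, with placeholder values:

# OPENAI_API_KEY=sk-...          # required; read via os.getenv("OPENAI_API_KEY")
# OPENAI_BASE_URL=https://...    # optional; only needed for an OpenAI-compatible proxy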
medrax/__init__.py
ADDED
File without changes
|
medrax/agent/__init__.py
ADDED
@@ -0,0 +1 @@
1 |
+
from .agent import AgentState, Agent
|
medrax/agent/agent.py
ADDED
@@ -0,0 +1,193 @@
1 |
+
import json
|
2 |
+
import operator
|
3 |
+
from pathlib import Path
|
4 |
+
from dotenv import load_dotenv
|
5 |
+
from datetime import datetime
|
6 |
+
from typing import List, Dict, Any, TypedDict, Annotated, Optional
|
7 |
+
|
8 |
+
from langgraph.graph import StateGraph, END
|
9 |
+
from langchain_core.messages import AnyMessage, SystemMessage, ToolMessage
|
10 |
+
from langchain_core.language_models import BaseLanguageModel
|
11 |
+
from langchain_core.tools import BaseTool
|
12 |
+
|
13 |
+
_ = load_dotenv()
|
14 |
+
|
15 |
+
|
16 |
+
class ToolCallLog(TypedDict):
|
17 |
+
"""
|
18 |
+
A TypedDict representing a log entry for a tool call.
|
19 |
+
|
20 |
+
Attributes:
|
21 |
+
timestamp (str): The timestamp of when the tool call was made.
|
22 |
+
tool_call_id (str): The unique identifier for the tool call.
|
23 |
+
name (str): The name of the tool that was called.
|
24 |
+
args (Any): The arguments passed to the tool.
|
25 |
+
content (str): The content or result of the tool call.
|
26 |
+
"""
|
27 |
+
|
28 |
+
timestamp: str
|
29 |
+
tool_call_id: str
|
30 |
+
name: str
|
31 |
+
args: Any
|
32 |
+
content: str
|
33 |
+
|
34 |
+
|
35 |
+
class AgentState(TypedDict):
|
36 |
+
"""
|
37 |
+
A TypedDict representing the state of an agent.
|
38 |
+
|
39 |
+
Attributes:
|
40 |
+
messages (Annotated[List[AnyMessage], operator.add]): A list of messages
|
41 |
+
representing the conversation history. The operator.add annotation
|
42 |
+
indicates that new messages should be appended to this list.
|
43 |
+
"""
|
44 |
+
|
45 |
+
messages: Annotated[List[AnyMessage], operator.add]
|
46 |
+
|
47 |
+
|
48 |
+
class Agent:
|
49 |
+
"""
|
50 |
+
A class representing an agent that processes requests and executes tools based on
|
51 |
+
language model responses.
|
52 |
+
|
53 |
+
Attributes:
|
54 |
+
model (BaseLanguageModel): The language model used for processing.
|
55 |
+
tools (Dict[str, BaseTool]): A dictionary of available tools.
|
56 |
+
checkpointer (Any): Manages and persists the agent's state.
|
57 |
+
system_prompt (str): The system instructions for the agent.
|
58 |
+
workflow (StateGraph): The compiled workflow for the agent's processing.
|
59 |
+
log_tools (bool): Whether to log tool calls.
|
60 |
+
log_path (Path): Path to save tool call logs.
|
61 |
+
"""
|
62 |
+
|
63 |
+
def __init__(
|
64 |
+
self,
|
65 |
+
model: BaseLanguageModel,
|
66 |
+
tools: List[BaseTool],
|
67 |
+
checkpointer: Any = None,
|
68 |
+
system_prompt: str = "",
|
69 |
+
log_tools: bool = True,
|
70 |
+
log_dir: Optional[str] = "logs",
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Initialize the Agent.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
model (BaseLanguageModel): The language model to use.
|
77 |
+
tools (List[BaseTool]): A list of available tools.
|
78 |
+
checkpointer (Any, optional): State persistence manager. Defaults to None.
|
79 |
+
system_prompt (str, optional): System instructions. Defaults to "".
|
80 |
+
log_tools (bool, optional): Whether to log tool calls. Defaults to True.
|
81 |
+
log_dir (str, optional): Directory to save logs. Defaults to 'logs'.
|
82 |
+
"""
|
83 |
+
self.system_prompt = system_prompt
|
84 |
+
self.log_tools = log_tools
|
85 |
+
|
86 |
+
if self.log_tools:
|
87 |
+
self.log_path = Path(log_dir or "logs")
|
88 |
+
self.log_path.mkdir(exist_ok=True)
|
89 |
+
|
90 |
+
# Define the agent workflow
|
91 |
+
workflow = StateGraph(AgentState)
|
92 |
+
workflow.add_node("process", self.process_request)
|
93 |
+
workflow.add_node("execute", self.execute_tools)
|
94 |
+
workflow.add_conditional_edges(
|
95 |
+
"process", self.has_tool_calls, {True: "execute", False: END}
|
96 |
+
)
|
97 |
+
workflow.add_edge("execute", "process")
|
98 |
+
workflow.set_entry_point("process")
|
99 |
+
|
100 |
+
self.workflow = workflow.compile(checkpointer=checkpointer)
|
101 |
+
self.tools = {t.name: t for t in tools}
|
102 |
+
self.model = model.bind_tools(tools)
|
103 |
+
|
104 |
+
def process_request(self, state: AgentState) -> Dict[str, List[AnyMessage]]:
|
105 |
+
"""
|
106 |
+
Process the request using the language model.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
state (AgentState): The current state of the agent.
|
110 |
+
|
111 |
+
Returns:
|
112 |
+
Dict[str, List[AnyMessage]]: A dictionary containing the model's response.
|
113 |
+
"""
|
114 |
+
messages = state["messages"]
|
115 |
+
if self.system_prompt:
|
116 |
+
messages = [SystemMessage(content=self.system_prompt)] + messages
|
117 |
+
response = self.model.invoke(messages)
|
118 |
+
return {"messages": [response]}
|
119 |
+
|
120 |
+
def has_tool_calls(self, state: AgentState) -> bool:
|
121 |
+
"""
|
122 |
+
Check if the response contains any tool calls.
|
123 |
+
|
124 |
+
Args:
|
125 |
+
state (AgentState): The current state of the agent.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
bool: True if tool calls exist, False otherwise.
|
129 |
+
"""
|
130 |
+
response = state["messages"][-1]
|
131 |
+
return len(response.tool_calls) > 0
|
132 |
+
|
133 |
+
def execute_tools(self, state: AgentState) -> Dict[str, List[ToolMessage]]:
|
134 |
+
"""
|
135 |
+
Execute tool calls from the model's response.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
state (AgentState): The current state of the agent.
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
Dict[str, List[ToolMessage]]: A dictionary containing tool execution results.
|
142 |
+
"""
|
143 |
+
tool_calls = state["messages"][-1].tool_calls
|
144 |
+
results = []
|
145 |
+
|
146 |
+
for call in tool_calls:
|
147 |
+
print(f"Executing tool: {call}")
|
148 |
+
if call["name"] not in self.tools:
|
149 |
+
print("\n....invalid tool....")
|
150 |
+
result = "invalid tool, please retry"
|
151 |
+
else:
|
152 |
+
result = self.tools[call["name"]].invoke(call["args"])
|
153 |
+
|
154 |
+
results.append(
|
155 |
+
ToolMessage(
|
156 |
+
tool_call_id=call["id"],
|
157 |
+
name=call["name"],
|
158 |
+
args=call["args"],
|
159 |
+
content=str(result),
|
160 |
+
)
|
161 |
+
)
|
162 |
+
|
163 |
+
self._save_tool_calls(results)
|
164 |
+
print("Returning to model processing!")
|
165 |
+
|
166 |
+
return {"messages": results}
|
167 |
+
|
168 |
+
def _save_tool_calls(self, tool_calls: List[ToolMessage]) -> None:
|
169 |
+
"""
|
170 |
+
Save tool calls to a JSON file with timestamp-based naming.
|
171 |
+
|
172 |
+
Args:
|
173 |
+
tool_calls (List[ToolMessage]): List of tool calls to save.
|
174 |
+
"""
|
175 |
+
if not self.log_tools:
|
176 |
+
return
|
177 |
+
|
178 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
179 |
+
filename = self.log_path / f"tool_calls_{timestamp}.json"
|
180 |
+
|
181 |
+
logs: List[ToolCallLog] = []
|
182 |
+
for call in tool_calls:
|
183 |
+
log_entry = {
|
184 |
+
"tool_call_id": call.tool_call_id,
|
185 |
+
"name": call.name,
|
186 |
+
"args": call.args,
|
187 |
+
"content": call.content,
|
188 |
+
"timestamp": datetime.now().isoformat(),
|
189 |
+
}
|
190 |
+
logs.append(log_entry)
|
191 |
+
|
192 |
+
with open(filename, "w") as f:
|
193 |
+
json.dump(logs, f, indent=4)
|
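
A minimal sketch of how this Agent class can be driven, assuming a ChatOpenAI chat model and a toy tool (both are placeholder assumptions, not part of this commit; the repo wires in its own tool set elsewhere):

# Sketch only: the model id and the toy tool below are placeholder assumptions.
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool

from medrax.agent import Agent


@tool
def lookup_finding(term: str) -> str:
    """Toy placeholder tool that 'looks up' a radiology term."""
    return f"No reference entry found for {term}."


llm = ChatOpenAI(model="gpt-4o", temperature=0)  # assumed chat model
agent = Agent(llm, tools=[lookup_finding], system_prompt="You are an expert medical AI assistant.")

result = agent.workflow.invoke(
    {"messages": [HumanMessage(content="Does this report suggest cardiomegaly?")]}
)
print(result["messages"][-1].content)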
medrax/docs/system_prompts.txt
ADDED
@@ -0,0 +1,9 @@
[MEDICAL_ASSISTANT]
You are an expert medical AI assistant who can answer any medical questions and analyze medical images similar to a doctor.
Solve using your own vision and reasoning and use tools to complement your reasoning.
Make multiple tool calls in parallel or sequence as needed for comprehensive answers.
Critically think about and criticize the tool outputs.
If you need to look up some information before asking a follow up question, you are allowed to do that.

[GENERAL_ASSISTANT]
You are a helpful AI assistant. Your role is to assist users with a wide range of tasks and questions, providing accurate and useful information on various topics.
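
The bracketed headers act as section names, so a caller can pull out a single prompt by section. A hypothetical helper (not part of this commit) that does this might look like:

# Hypothetical helper: return the text under one [SECTION] header of system_prompts.txt.
def load_prompt_section(path: str, section: str) -> str:
    lines, current = [], None
    with open(path) as f:
        for raw in f:
            line = raw.strip()
            if line.startswith("[") and line.endswith("]"):
                current = line.strip("[]")
            elif current == section and line:
                lines.append(line)
    return " ".join(lines)


medical_prompt = load_prompt_section("medrax/docs/system_prompts.txt", "MEDICAL_ASSISTANT")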
medrax/llava/__init__.py
ADDED
File without changes
medrax/llava/constants.py
ADDED
@@ -0,0 +1,13 @@
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
medrax/llava/conversation.py
ADDED
@@ -0,0 +1,448 @@
import dataclasses
from enum import auto, Enum
from typing import List


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    MISTRAL = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if "mmtag" in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    sep = seps[i % 2]
                    sep = "{0} ".format(self.sep2) if sep == self.sep2 else self.sep
                    ret += role + ": " + message.strip() + sep
                else:
                    ret += role + ":"
            ret = ret.strip()
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        elif self.sep_style == SeparatorStyle.MISTRAL:
            # reference: https://docs.mistral.ai/models/
            wrap_sys = lambda msg: f"{msg}</s>"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = self.system + " " + message.strip()
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += message
                    else:
                        ret += wrap_sys(message)
                else:
                    ret += ""
            # wrap_sys = lambda msg: f"\n{msg}\n\n"
            # wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            # ret = ""
            # for i, (role, message) in enumerate(messages):
            #     if i == 0:
            #         assert message, "first message should not be none"
            #         assert role == self.roles[0], "first message should come from user"
            #     if message:
            #         if type(message) is tuple:
            #             message, _, _ = message
            #         if i == 0: message = wrap_sys(self.system) + message
            #         if i % 2 == 0:
            #             message = wrap_inst(message)
            #             ret += message if i != 0 else self.sep + message
            #         else:
            #             # NOTE-JW: we need to add " " to strictly follow Mistral Instruction Format
            #             ret += " " + message + " " + self.sep2
            #             # ret += " " + wrap_sys(message)
            #     else:
            #         ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = (
                        f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    )
                    msg = img_str + msg.replace("<image>", "").strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

conv_mistral_instruct = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="",
    sep2="</s>",
)

default_conversation = conv_vicuna_v1
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "mistral_instruct": conv_mistral_instruct,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
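
The __main__ block above only prints the default template; a slightly fuller sketch of how these templates are consumed when building a prompt (the question text is made up) is:

# Sketch: build a vicuna_v1-style prompt around an image question.
from medrax.llava.conversation import conv_templates

conv = conv_templates["vicuna_v1"].copy()  # copy() so the shared module-level template is not mutated
conv.append_message(conv.roles[0], "<image>\nWhat does this chest X-ray show?")
conv.append_message(conv.roles[1], None)   # leave the assistant turn open for generation
print(conv.get_prompt())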
medrax/llava/eval/eval_multimodal_chat_gpt_score.py
ADDED
@@ -0,0 +1,143 @@
import os
import json
import argparse
from copy import deepcopy
from pathlib import Path
from tqdm import tqdm

import llm
import util


INSTRUCT_PROMPT = """We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with caption describing the same image.
Please rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.
Please first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space. In the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."""
ROLE = "Assistant"


# Generate instruction for GPT-4 to score the two answers.
def conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2):
    return (
        f"[Context]\n"
        f"Figure Caption:\n{fig_label}: {fig_caption}\n\n"
        f"Figure Context:\n\t- {fig_context}\n\n"
        f"[Question]\n{question}\n\n"
        f"[{ROLE} 1]\n{ans1}\n\n[End of {ROLE} 1]\n\n"
        f"[{ROLE} 2]\n{ans2}\n\n[End of {ROLE} 2]\n\n"
        f"[System]\n{INSTRUCT_PROMPT}\n\n"
    )


def compare_messages_gen(fig_label, fig_caption, fig_context, question, ans1, ans2):
    messages = [
        {
            "role": "system",
            "content": """'You are a helpful and precise assistant for checking the quality of the answer.""",
        },
    ]
    messages.append(
        {
            "role": "user",
            "content": conv_to_str(fig_label, fig_caption, fig_context, question, ans1, ans2),
        }
    )
    return messages


def sum_list_list(x):
    return sum(item for inner_list in x for item in inner_list)


def chunk(lst, n):
    for i in range(0, len(lst), n):
        if i + (1.5 * n) < len(lst):
            end = i + n
        else:
            end = len(lst)
        yield lst[i:end]
        if end == len(lst):
            return


def infer(samples):
    model_inst = llm.GPT("gpt-4-0314")

    BATCH_SIZE = 1
    batch_samples = []
    results = []
    batch = []

    print("Starting Multimodal Chat GPT Scoring Eval")

    for sample in tqdm(samples):
        sample_copy = deepcopy(sample)
        input_msg = compare_messages_gen(
            sample_copy["fig_label"],
            sample_copy["fig_caption"],
            sample_copy["in_text_mention"],
            sample_copy["question"],
            sample_copy["ans1"],
            sample_copy["ans2"],
        )
        batch.append(input_msg)
        batch_samples.append(sample_copy)
        if len(batch) >= BATCH_SIZE:
            inference_results = [
                x.strip()
                for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE)
                for x in model_inst.infer(chunk_messages)
            ]
            for item, inference_result in zip(batch_samples, inference_results):
                item["gpt_eval"] = inference_result
            results.extend(batch_samples)
            batch = []
            batch_samples = []
    inference_results = [
        x.strip()
        for chunk_messages in chunk([x for x in batch if x], BATCH_SIZE)
        for x in model_inst.infer(chunk_messages)
    ]
    for item, inference_result in zip(batch_samples, inference_results):
        item["gpt_eval"] = inference_result
    results.extend(batch_samples)
    print(f"Result Size: {len(results)}")
    return results


def main(args):
    answer_data = util.load_file_jsonl(args.answers_file)
    question_data = util.load_file_jsonl(args.question_file)

    samples = []
    for question, answer in zip(question_data, answer_data):
        question_copy = deepcopy(question)
        question["question"] = question_copy["text"]
        question["ans1"] = question_copy.pop("gpt4_answer")
        question["ans2"] = answer["text"]
        samples.append(question)

    results = infer(samples)

    # Create parent directory of output score files if it doesn't exist
    os.makedirs(Path(args.scores_file).parent, exist_ok=True)

    with open(args.scores_file, "w") as f:
        for row in results:
            f.write(json.dumps(row) + "\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Scoring", add_help=True)
    parser.add_argument(
        "--answers-file", default="", metavar="FILE", help="path to model answer file"
    )
    parser.add_argument(
        "--question-file",
        default="data/questions/llava_med_eval_qa50_qa.jsonl",
        metavar="FILE",
        help="path to multichat questions file",
    )
    parser.add_argument(
        "--scores-file", default="", metavar="FILE", help="path to save gpt-4 score file"
    )
    args = parser.parse_args()
    main(args)
medrax/llava/eval/llm.py
ADDED
@@ -0,0 +1,154 @@
import abc
import asyncio
from abc import abstractmethod
import math

import tiktoken
import openai
import backoff


class LLM(abc.ABC):

    prompt_percent = 0.9

    @abstractmethod
    def __init__(self):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def infer(self, prompts):
        raise NotImplementedError("Subclasses should implement this!")

    @abstractmethod
    def split_input(
        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
    ):
        raise NotImplementedError("Subclasses should implement this!")


class GPT(LLM):

    prompt_percent = 0.8

    openai_cxn_dict = {
        "default": {
            "endpoint": "INSERT YOUR AZURE OPENAI ENDPOINT HERE",
            "api_key": "INSERT YOUR AZURE OPENAI API KEY HERE",
        },
    }

    deployment_max_length_dict = {
        "gpt-4": 8192,
        "gpt-4-0314": 8192,
        "gpt-4-32k": 32768,
        "gpt-35-turbo": 4096,
        "gpt-35-turbo-16k": 16385,
    }

    def __init__(self, model_id):
        self.temperature = 0.0
        self.top_k = 1
        self.encoding = tiktoken.encoding_for_model(
            "-".join(model_id.split("-", 2)[:2]).replace("5", ".5")
        )
        self.openai_api = "default"
        self.model_id = model_id
        self.max_length = self.deployment_max_length_dict[model_id]
        self.client = openai.AsyncAzureOpenAI(
            api_key=self.openai_cxn_dict[self.openai_api]["api_key"],
            api_version="2023-12-01-preview",
            azure_endpoint=self.openai_cxn_dict[self.openai_api]["endpoint"],
        )

    def gen_messages(
        self, fixed_instruction, few_shot_examples, input, input_header, output_header
    ):
        messages = [
            {
                "role": "system",
                "content": fixed_instruction,
            },
        ]
        for example in few_shot_examples:
            messages.extend(
                [
                    {
                        "role": "user",
                        "content": input_header + "\n" + example["user"] + "\n\n" + output_header,
                    },
                    {
                        "role": "assistant",
                        "content": example["assistant"],
                    },
                ]
            )
        messages.extend(
            [
                {
                    "role": "user",
                    "content": input_header + "\n" + input + "\n\n" + output_header,
                },
            ]
        )
        return messages

    # Define the coroutine for making API calls to GPT
    @backoff.on_exception(backoff.expo, openai.RateLimitError)
    async def make_api_call_to_gpt(self, messages):
        response = await self.client.chat.completions.create(
            model=self.model_id,
            messages=messages,
            temperature=self.temperature,
        )
        return response.choices[0].message.content

    async def dispatch_openai_requests(
        self,
        messages_list,
    ):
        # Asynchronously call the function for each prompt
        tasks = [self.make_api_call_to_gpt(messages) for messages in messages_list]

        # Gather and run the tasks concurrently
        results = await asyncio.gather(*tasks)
        return results

    def infer(
        self,
        messages_list,
    ):
        return asyncio.run(self.dispatch_openai_requests(messages_list))

    def split_input(
        self, fixed_instruction, few_shot_examples, splittable_input, input_header, output_header
    ):
        # Tokenize fixed_prompt
        fixed_token_ids = self.encoding.encode(
            fixed_instruction
            + " ".join([x["user"] + " " + x["assistant"] for x in few_shot_examples])
        )
        # Calculate remaining token length
        remaining_token_len = math.ceil(
            (self.prompt_percent * self.max_length) - len(fixed_token_ids)
        )

        # Tokenize splittable_input
        split_token_ids = self.encoding.encode(splittable_input)

        # Split tokenized split_prompt into list of individual inputs strings. Uses tokens to calculate length
        split_token_ids_list = [
            split_token_ids[i : i + remaining_token_len + 10]
            for i in range(0, len(split_token_ids), remaining_token_len)
        ]
        split_input_list = [
            self.encoding.decode(split_token_ids) for split_token_ids in split_token_ids_list
        ]

        # Take the fixed_prompt, few_shot_examples, splitted inputs, and input/output headers and generate list of prompt strings.
        return [
            self.gen_messages(
                fixed_instruction, few_shot_examples, split_input, input_header, output_header
            )
            for split_input in split_input_list
        ]
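
A minimal sketch of driving this wrapper (run from medrax/llava/eval, after replacing the placeholder Azure endpoint and key in GPT.openai_cxn_dict; the message content here is illustrative):

# Sketch only: will not run until real Azure OpenAI credentials are configured above.
from llm import GPT

scorer = GPT("gpt-4-0314")
messages_list = [
    [
        {"role": "system", "content": "You are a precise grader."},
        {"role": "user", "content": "Score the two answers above from 1 to 10."},
    ]
]
print(scorer.infer(messages_list)[0])  # infer() takes a list of chat message lists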
medrax/llava/eval/model_vqa.py
ADDED
@@ -0,0 +1,133 @@
import argparse
import torch
import os
import json
from tqdm import tqdm
import shortuuid

from medrax.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from medrax.llava.conversation import conv_templates, SeparatorStyle
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.utils import disable_torch_init
from medrax.llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
    process_images,
)

from PIL import Image
import math
from transformers import set_seed, logging

logging.set_verbosity_error()


def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks"""
    chunk_size = math.ceil(len(lst) / n)  # integer division
    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]


def get_chunk(lst, n, k):
    chunks = split_list(lst, n)
    return chunks[k]


def eval_model(args):
    set_seed(0)
    # Model
    disable_torch_init()
    model_path = os.path.expanduser(args.model_path)
    model_name = get_model_name_from_path(model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, args.model_base, model_name
    )

    questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
    questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
    answers_file = os.path.expanduser(args.answers_file)
    os.makedirs(os.path.dirname(answers_file), exist_ok=True)
    ans_file = open(answers_file, "w")
    for line in tqdm(questions):
        idx = line["question_id"]
        image_file = line["image"]
        qs = line["text"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
        cur_prompt = qs
        if model.config.mm_use_im_start_end:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
        else:
            qs = DEFAULT_IMAGE_TOKEN + "\n" + qs

        conv = conv_templates[args.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

        image = Image.open(os.path.join(args.image_folder, image_file))
        image_tensor = process_images([image], image_processor, model.config)[0]

        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor.unsqueeze(0).half().cuda(),
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                top_p=args.top_p,
                num_beams=args.num_beams,
                # no_repeat_ngram_size=3,
                max_new_tokens=1024,
                use_cache=True,
            )

        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

        ans_id = shortuuid.uuid()
        ans_file.write(
            json.dumps(
                {
                    "question_id": idx,
                    "prompt": cur_prompt,
                    "text": outputs,
                    "answer_id": ans_id,
                    "model_id": model_name,
                    "metadata": {},
                }
            )
            + "\n"
        )
        ans_file.flush()
    ans_file.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-folder", type=str, default="")
    parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
    parser.add_argument("--answers-file", type=str, default="answer.jsonl")
    parser.add_argument("--conv-mode", type=str, default="vicuna_v1")
    parser.add_argument("--num-chunks", type=int, default=1)
    parser.add_argument("--chunk-idx", type=int, default=0)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    args = parser.parse_args()

    eval_model(args)
medrax/llava/eval/summarize_gpt_review.py
ADDED
@@ -0,0 +1,62 @@
import argparse
import util
from collections import defaultdict
import pandas as pd


def get_domain(x):
    for domain in ["chest_xray", "mri", "histology", "gross", "ct_scan"]:
        in_domain = x["domain"][domain]
        if in_domain:
            return domain


def main(args):
    scores_data = util.load_file_jsonl(args.scores_file)
    predictions = [
        (x["question_id"], x["type"], get_domain(x), x["gpt_eval"].split("\n")[0].split(" "))
        for x in scores_data
    ]

    score_type_dict = defaultdict(lambda: defaultdict(list))
    for q_id, q_type, domain, (a1_score, a2_score) in predictions:
        score_type_dict[q_type][1].append(a1_score)
        score_type_dict[q_type][2].append(a2_score)
        score_type_dict["overall"][1].append(a1_score)
        score_type_dict["overall"][2].append(a2_score)
        score_type_dict[domain][1].append(a1_score)
        score_type_dict[domain][2].append(a2_score)

    result = defaultdict(dict)

    for q_type, score_dict in score_type_dict.items():
        result[q_type]["gpt4_score"] = util.get_avg(score_dict[1])
        result[q_type]["pred_score"] = util.get_avg(score_dict[2])
        result[q_type]["pred_relative_score"] = (
            util.get_avg([float(s2) / float(s1) for s1, s2 in zip(score_dict[1], score_dict[2])])
            * 100
        )
        result[q_type]["data_size"] = len(score_dict[1])

    df = pd.DataFrame.from_dict(result).filter(
        [
            "conversation",
            "detailed_description",
            "chest_xray",
            "mri",
            "histology",
            "gross",
            "ct_scan",
            "overall",
        ]
    )
    print(df)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("GPT-4 Multimodal Chat Eval Postprocessing", add_help=True)
    parser.add_argument(
        "--scores-file", default="", metavar="FILE", help="input path to gpt-4 score file"
    )
    args = parser.parse_args()
    main(args)
medrax/llava/eval/util.py
ADDED
@@ -0,0 +1,10 @@
import json


def load_file_jsonl(path):
    with open(path) as f:
        return [json.loads(row) for row in f]


def get_avg(x):
    return sum([float(y) for y in x]) / len(x)
medrax/llava/mm_utils.py
ADDED
@@ -0,0 +1,121 @@
from PIL import Image
from io import BytesIO
import base64
import random
import torch
from transformers import StoppingCriteria
from medrax.llava.constants import IMAGE_TOKEN_INDEX


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        # sample a random between 0 and (width - height) // 2
        y_start = random.randint((width - height) // 2, (width - height) // 2 + 1)
        result.paste(pil_img, (0, y_start))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        # sample a random between 0 and (height - width) // 2
        x_start = random.randint((height - width) // 2, (height - width) // 2 + 1)
        result.paste(pil_img, (x_start, 0))
        return result


def process_images(images, image_processor, model_cfg):
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    for image in images:
        if image_aspect_ratio == "pad":
            if image.mode == "L":
                background_color = int(
                    255 * sum(image_processor.image_mean) / len(image_processor.image_mean)
                )
            else:
                background_color = tuple(int(x * 255) for x in image_processor.image_mean)
            image = expand2square(image, background_color)
        image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
        new_images.append(image)
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(
    prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None
):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if (
        len(prompt_chunks) > 0
        and len(prompt_chunks[0]) > 0
        and prompt_chunks[0][0] == tokenizer.bos_token_id
    ):
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith("checkpoint-"):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(
        self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
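
To make the <image> handling concrete, a small sketch of tokenizer_image_token (the tokenizer checkpoint used here is an assumption, not fixed by this commit):

# Sketch: every "<image>" in the prompt becomes IMAGE_TOKEN_INDEX (-200) in the returned ids.
from transformers import AutoTokenizer
from medrax.llava.mm_utils import tokenizer_image_token

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")  # assumed tokenizer
ids = tokenizer_image_token("USER: <image>\nWhat does this X-ray show? ASSISTANT:", tok, return_tensors="pt")
print(ids)  # 1-D LongTensor; callers unsqueeze(0) before passing it to model.generate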
medrax/llava/model/__init__.py
ADDED
@@ -0,0 +1 @@
from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
medrax/llava/model/builder.py
ADDED
@@ -0,0 +1,134 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from medrax.llava.model import LlavaMistralForCausalLM
from medrax.llava.constants import (
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_in_8bit=False,
    load_in_4bit=True,
    device="cuda",
    cache_dir: str = "/model-weights",
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
):

    kwargs = {}

    if device != "cuda":
        kwargs["device_map"] = {"": device}
    # else:
    #     kwargs["device_map"] = "auto"

    if load_in_8bit:
        kwargs["load_in_8bit"] = True
    elif load_in_4bit:
        # kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    # else:
    #     kwargs["torch_dtype"] = torch_dtype

    if "llava" in model_name.lower():
        # Load LLaVA model
        if "mistral" in model_name.lower():
            tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
            model = LlavaMistralForCausalLM.from_pretrained(
                model_path,
                low_cpu_mem_usage=low_cpu_mem_usage,
                use_flash_attention_2=False,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )

    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(
                model_base, use_fast=False, cache_dir=cache_dir
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_base,
                low_cpu_mem_usage=True,
                cache_dir=cache_dir,
                torch_dtype=torch_dtype,
                **kwargs,
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch_dtype)
        else:
            use_fast = False
            if "mpt" in model_name.lower():
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=True, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    trust_remote_code=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )
            else:
                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, use_fast=False, cache_dir=cache_dir
                )
                model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    low_cpu_mem_usage=True,
                    cache_dir=cache_dir,
                    torch_dtype=torch_dtype,
                    **kwargs,
                )

    image_processor = None

    if "llava" in model_name.lower():  # or 'mistral' in model_name.lower():
        mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
        mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
        if mm_use_im_patch_token:
            tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
        if mm_use_im_start_end:
            tokenizer.add_tokens(
                [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
            )
        model.resize_token_embeddings(len(tokenizer))

        vision_tower = model.get_vision_tower()
        if not vision_tower.is_loaded:
            vision_tower.load_model()

        vision_tower.to(device=device, dtype=torch_dtype)
        model.model.mm_projector.to(device=device, dtype=torch_dtype)

        if not (load_in_4bit or load_in_8bit):
            model.to(device=device, dtype=torch_dtype)

        image_processor = vision_tower.image_processor

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
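
For context, a sketch of calling load_pretrained_model with its defaults (4-bit load on CUDA); the checkpoint id below is an assumption, not something this commit pins down:

# Sketch: load a LLaVA-Mistral style checkpoint and get back everything inference needs.
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.mm_utils import get_model_name_from_path

model_path = "microsoft/llava-med-v1.5-mistral-7b"  # assumed checkpoint id
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base=None, model_name=get_model_name_from_path(model_path)
)
print(type(model).__name__, context_len)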
medrax/llava/model/language_model/llava_mistral.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
from typing import List, Optional, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from transformers import (
|
7 |
+
AutoConfig,
|
8 |
+
AutoModelForCausalLM,
|
9 |
+
MistralConfig,
|
10 |
+
MistralModel,
|
11 |
+
MistralForCausalLM,
|
12 |
+
)
|
13 |
+
|
14 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
15 |
+
from transformers.generation.utils import GenerateOutput
|
16 |
+
|
17 |
+
from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
|
18 |
+
|
19 |
+
|
20 |
+
class LlavaMistralConfig(MistralConfig):
|
21 |
+
model_type = "llava_mistral"
|
22 |
+
|
23 |
+
|
24 |
+
class LlavaMistralModel(LlavaMetaModel, MistralModel):
|
25 |
+
config_class = LlavaMistralConfig
|
26 |
+
|
27 |
+
def __init__(self, config: MistralConfig):
|
28 |
+
super(LlavaMistralModel, self).__init__(config)
|
29 |
+
|
30 |
+
|
31 |
+
class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
|
32 |
+
config_class = LlavaMistralConfig
|
33 |
+
|
34 |
+
def __init__(self, config):
|
35 |
+
super(MistralForCausalLM, self).__init__(config)
|
36 |
+
self.model = LlavaMistralModel(config)
|
37 |
+
|
38 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
39 |
+
|
40 |
+
# Initialize weights and apply final processing
|
41 |
+
self.post_init()
|
42 |
+
|
43 |
+
def get_model(self):
|
44 |
+
return self.model
|
45 |
+
|
46 |
+
def forward(
|
47 |
+
self,
|
48 |
+
input_ids: torch.LongTensor = None,
|
49 |
+
attention_mask: Optional[torch.Tensor] = None,
|
50 |
+
position_ids: Optional[torch.LongTensor] = None,
|
51 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
52 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
53 |
+
labels: Optional[torch.LongTensor] = None,
|
54 |
+
use_cache: Optional[bool] = None,
|
55 |
+
output_attentions: Optional[bool] = None,
|
56 |
+
output_hidden_states: Optional[bool] = None,
|
57 |
+
images: Optional[torch.FloatTensor] = None,
|
58 |
+
image_sizes: Optional[List[List[int]]] = None,
|
59 |
+
return_dict: Optional[bool] = None,
|
60 |
+
cache_position: Optional[torch.LongTensor] = None,
|
61 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
62 |
+
|
63 |
+
if inputs_embeds is None:
|
64 |
+
(
|
65 |
+
input_ids,
|
66 |
+
position_ids,
|
67 |
+
attention_mask,
|
68 |
+
past_key_values,
|
69 |
+
inputs_embeds,
|
70 |
+
labels,
|
71 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
72 |
+
input_ids,
|
73 |
+
position_ids,
|
74 |
+
attention_mask,
|
75 |
+
past_key_values,
|
76 |
+
labels,
|
77 |
+
images,
|
78 |
+
image_sizes,
|
79 |
+
)
|
80 |
+
|
81 |
+
return super().forward(
|
82 |
+
input_ids=input_ids,
|
83 |
+
attention_mask=attention_mask,
|
84 |
+
position_ids=position_ids,
|
85 |
+
past_key_values=past_key_values,
|
86 |
+
inputs_embeds=inputs_embeds,
|
87 |
+
labels=labels,
|
88 |
+
use_cache=use_cache,
|
89 |
+
output_attentions=output_attentions,
|
90 |
+
output_hidden_states=output_hidden_states,
|
91 |
+
return_dict=return_dict,
|
92 |
+
)
|
93 |
+
|
94 |
+
@torch.no_grad()
|
95 |
+
def generate(
|
96 |
+
self,
|
97 |
+
inputs: Optional[torch.Tensor] = None,
|
98 |
+
images: Optional[torch.Tensor] = None,
|
99 |
+
image_sizes: Optional[torch.Tensor] = None,
|
100 |
+
**kwargs,
|
101 |
+
) -> Union[GenerateOutput, torch.LongTensor]:
|
102 |
+
position_ids = kwargs.pop("position_ids", None)
|
103 |
+
attention_mask = kwargs.pop("attention_mask", None)
|
104 |
+
if "inputs_embeds" in kwargs:
|
105 |
+
raise NotImplementedError("`inputs_embeds` is not supported")
|
106 |
+
|
107 |
+
if images is not None:
|
108 |
+
(
|
109 |
+
inputs,
|
110 |
+
position_ids,
|
111 |
+
attention_mask,
|
112 |
+
_,
|
113 |
+
inputs_embeds,
|
114 |
+
_,
|
115 |
+
) = self.prepare_inputs_labels_for_multimodal(
|
116 |
+
inputs, position_ids, attention_mask, None, None, images, image_sizes=image_sizes
|
117 |
+
)
|
118 |
+
else:
|
119 |
+
inputs_embeds = self.get_model().embed_tokens(inputs)
|
120 |
+
|
121 |
+
return super().generate(
|
122 |
+
position_ids=position_ids,
|
123 |
+
attention_mask=attention_mask,
|
124 |
+
inputs_embeds=inputs_embeds,
|
125 |
+
**kwargs,
|
126 |
+
)
|
127 |
+
|
128 |
+
def prepare_inputs_for_generation(
|
129 |
+
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
|
130 |
+
):
|
131 |
+
images = kwargs.pop("images", None)
|
132 |
+
image_sizes = kwargs.pop("image_sizes", None)
|
133 |
+
inputs = super().prepare_inputs_for_generation(
|
134 |
+
input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
|
135 |
+
)
|
136 |
+
if images is not None:
|
137 |
+
inputs["images"] = images
|
138 |
+
if image_sizes is not None:
|
139 |
+
inputs["image_sizes"] = image_sizes
|
140 |
+
return inputs
|
141 |
+
|
142 |
+
|
143 |
+
AutoConfig.register("llava_mistral", LlavaMistralConfig)
|
144 |
+
AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
|
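The two register() calls above are what make the custom "llava_mistral" model type resolvable through the transformers Auto classes. A small sketch of that wiring, using a deliberately tiny (and entirely made-up) configuration rather than a real checkpoint:

from transformers import AutoModelForCausalLM
from medrax.llava.model.language_model.llava_mistral import LlavaMistralConfig  # import runs the registration

config = LlavaMistralConfig(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
    vocab_size=1000,
)
model = AutoModelForCausalLM.from_config(config)       # resolved via AutoModelForCausalLM.register
print(type(model).__name__, model.config.model_type)   # LlavaMistralForCausalLM llava_mistral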
medrax/llava/model/llava_arch.py
ADDED
@@ -0,0 +1,396 @@
1 |
+
# Copyright 2023 Haotian Liu
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
from abc import ABC, abstractmethod
|
17 |
+
import os
|
18 |
+
from glob import glob
|
19 |
+
|
20 |
+
import torch
|
21 |
+
|
22 |
+
from .multimodal_encoder.builder import build_vision_tower
|
23 |
+
from .multimodal_projector.builder import build_vision_projector
|
24 |
+
|
25 |
+
from medrax.llava.constants import (
|
26 |
+
IGNORE_INDEX,
|
27 |
+
IMAGE_TOKEN_INDEX,
|
28 |
+
DEFAULT_IMAGE_PATCH_TOKEN,
|
29 |
+
DEFAULT_IM_START_TOKEN,
|
30 |
+
DEFAULT_IM_END_TOKEN,
|
31 |
+
)
|
32 |
+
|
33 |
+
|
34 |
+
class LlavaMetaModel:
|
35 |
+
def __init__(self, config):
|
36 |
+
super(LlavaMetaModel, self).__init__(config)
|
37 |
+
|
38 |
+
if hasattr(config, "mm_vision_tower"):
|
39 |
+
self.vision_tower = build_vision_tower(config, delay_load=True)
|
40 |
+
self.mm_projector = build_vision_projector(config)
|
41 |
+
|
42 |
+
def get_vision_tower(self):
|
43 |
+
vision_tower = getattr(self, "vision_tower", None)
|
44 |
+
if type(vision_tower) is list:
|
45 |
+
vision_tower = vision_tower[0]
|
46 |
+
return vision_tower
|
47 |
+
|
48 |
+
def initialize_vision_modules(self, model_args, fsdp=None, embed_tokens=None):
|
49 |
+
vision_tower = model_args.vision_tower
|
50 |
+
mm_vision_select_layer = model_args.mm_vision_select_layer
|
51 |
+
mm_vision_select_feature = model_args.mm_vision_select_feature
|
52 |
+
pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
|
53 |
+
|
54 |
+
self.config.mm_vision_tower = vision_tower
|
55 |
+
|
56 |
+
if self.get_vision_tower() is None:
|
57 |
+
vision_tower = build_vision_tower(model_args)
|
58 |
+
|
59 |
+
if fsdp is not None and len(fsdp) > 0:
|
60 |
+
self.vision_tower = [vision_tower]
|
61 |
+
else:
|
62 |
+
self.vision_tower = vision_tower
|
63 |
+
else:
|
64 |
+
if fsdp is not None and len(fsdp) > 0:
|
65 |
+
vision_tower = self.vision_tower[0]
|
66 |
+
else:
|
67 |
+
vision_tower = self.vision_tower
|
68 |
+
vision_tower.load_model()
|
69 |
+
|
70 |
+
self.config.use_mm_proj = True
|
71 |
+
self.config.mm_projector_type = getattr(model_args, "mm_projector_type", "linear")
|
72 |
+
self.config.mm_hidden_size = vision_tower.hidden_size
|
73 |
+
self.config.mm_vision_select_layer = mm_vision_select_layer
|
74 |
+
self.config.mm_vision_select_feature = mm_vision_select_feature
|
75 |
+
|
76 |
+
# add additional configs for segtok
|
77 |
+
self.config.feature_outs = model_args.feature_outs
|
78 |
+
self.config.img_size = model_args.img_size
|
79 |
+
self.config.vision_backbone = model_args.vision_backbone
|
80 |
+
self.config.segtok_posembed = model_args.segtok_posembed
|
81 |
+
|
82 |
+
if getattr(self, "mm_projector", None) is None:
|
83 |
+
self.mm_projector = build_vision_projector(self.config)
|
84 |
+
else:
|
85 |
+
# In case it is frozen by LoRA
|
86 |
+
for p in self.mm_projector.parameters():
|
87 |
+
p.requires_grad = True
|
88 |
+
|
89 |
+
# Initialize last layer in mm_projector with weight=0 and bias=mean(embed_tokens)
|
90 |
+
if embed_tokens is not None:
|
91 |
+
embed_tokens_weight = embed_tokens.weight.data
|
92 |
+
self.mm_projector[-1].weight.data.zero_()
|
93 |
+
self.mm_projector[-1].bias.data.copy_(embed_tokens_weight.mean(dim=0))
|
94 |
+
|
95 |
+
if pretrain_mm_mlp_adapter is not None:
|
96 |
+
|
97 |
+
def get_w(weights, keyword):
|
98 |
+
return {k.split(keyword + ".")[1]: v for k, v in weights.items() if keyword in k}
|
99 |
+
|
100 |
+
mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location="cpu")
|
101 |
+
self.mm_projector.load_state_dict(get_w(mm_projector_weights, "mm_projector"))
|
102 |
+
|
103 |
+
# also load additional learnable parameters during feature alignment
|
104 |
+
checkpoint_folder = os.path.dirname(pretrain_mm_mlp_adapter)
|
105 |
+
ckpts = glob(f"{checkpoint_folder}/checkpoint-*", recursive=False)
|
106 |
+
if len(ckpts) > 0:
|
107 |
+
vision_module_weights = torch.load(
|
108 |
+
f"{ckpts[-1]}/mm_projector.bin", map_location="cpu"
|
109 |
+
)
|
110 |
+
model_dict = get_w(vision_module_weights, "vision_tower")
|
111 |
+
print(f"Loading vision module weights from {ckpts[-1]}/mm_projector.bin")
|
112 |
+
# print keys in model_dict
|
113 |
+
print(f"Loaded keys: {model_dict.keys()}")
|
114 |
+
self.vision_tower.load_state_dict(model_dict, strict=False)
|
115 |
+
|
116 |
+
|
117 |
+
class LlavaMetaForCausalLM(ABC):
|
118 |
+
@abstractmethod
|
119 |
+
def get_model(self):
|
120 |
+
pass
|
121 |
+
|
122 |
+
def get_vision_tower(self):
|
123 |
+
return self.get_model().get_vision_tower()
|
124 |
+
|
125 |
+
def encode_images(self, images):
|
126 |
+
image_features = self.get_model().get_vision_tower()(images)
|
127 |
+
image_features = self.get_model().mm_projector(image_features)
|
128 |
+
return image_features
|
129 |
+
|
130 |
+
def prepare_inputs_labels_for_multimodal(
|
131 |
+
self,
|
132 |
+
input_ids,
|
133 |
+
position_ids,
|
134 |
+
attention_mask,
|
135 |
+
past_key_values,
|
136 |
+
labels,
|
137 |
+
images,
|
138 |
+
image_sizes=None,
|
139 |
+
):
|
140 |
+
vision_tower = self.get_vision_tower()
|
141 |
+
if vision_tower is None or images is None or input_ids.shape[1] == 1:
|
142 |
+
if (
|
143 |
+
past_key_values is not None
|
144 |
+
and vision_tower is not None
|
145 |
+
and images is not None
|
146 |
+
and input_ids.shape[1] == 1
|
147 |
+
):
|
148 |
+
target_shape = past_key_values[-1][-1].shape[-2] + 1
|
149 |
+
attention_mask = torch.cat(
|
150 |
+
(
|
151 |
+
attention_mask,
|
152 |
+
torch.ones(
|
153 |
+
(attention_mask.shape[0], target_shape - attention_mask.shape[1]),
|
154 |
+
dtype=attention_mask.dtype,
|
155 |
+
device=attention_mask.device,
|
156 |
+
),
|
157 |
+
),
|
158 |
+
dim=1,
|
159 |
+
)
|
160 |
+
position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
|
161 |
+
return input_ids, position_ids, attention_mask, past_key_values, None, labels
|
162 |
+
|
163 |
+
if type(images) is list or images.ndim == 5:
|
164 |
+
concat_images = torch.cat([image for image in images], dim=0)
|
165 |
+
image_features = self.encode_images(concat_images)
|
166 |
+
split_sizes = [image.shape[0] for image in images]
|
167 |
+
image_features = torch.split(image_features, split_sizes, dim=0)
|
168 |
+
image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
|
169 |
+
else:
|
170 |
+
image_features = self.encode_images(images).to(self.device)
|
171 |
+
|
172 |
+
# TODO: image start / end is not implemented here to support pretraining.
|
173 |
+
if getattr(self.config, "tune_mm_mlp_adapter", False) and getattr(
|
174 |
+
self.config, "mm_use_im_start_end", False
|
175 |
+
):
|
176 |
+
raise NotImplementedError
|
177 |
+
|
178 |
+
# Let's just add dummy tensors if they do not exist,
|
179 |
+
# it is a headache to deal with None all the time.
|
180 |
+
# But it is not ideal, and if you have a better idea,
|
181 |
+
# please open an issue / submit a PR, thanks.
|
182 |
+
_labels = labels
|
183 |
+
_position_ids = position_ids
|
184 |
+
_attention_mask = attention_mask
|
185 |
+
|
186 |
+
if attention_mask is None:
|
187 |
+
attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
|
188 |
+
else:
|
189 |
+
attention_mask = attention_mask.bool()
|
190 |
+
if position_ids is None:
|
191 |
+
position_ids = torch.arange(
|
192 |
+
0, input_ids.shape[1], dtype=torch.long, device=input_ids.device
|
193 |
+
)
|
194 |
+
|
195 |
+
if labels is None:
|
196 |
+
labels = torch.full_like(input_ids, IGNORE_INDEX)
|
197 |
+
|
198 |
+
input_ids = [
|
199 |
+
cur_input_ids[cur_attention_mask]
|
200 |
+
for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
|
201 |
+
]
|
202 |
+
labels = [
|
203 |
+
cur_labels[cur_attention_mask]
|
204 |
+
for cur_labels, cur_attention_mask in zip(labels, attention_mask)
|
205 |
+
]
|
206 |
+
|
207 |
+
new_input_embeds = []
|
208 |
+
new_labels = []
|
209 |
+
cur_image_idx = 0
|
210 |
+
for batch_idx, cur_input_ids in enumerate(input_ids):
|
211 |
+
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
|
212 |
+
if num_images == 0:
|
213 |
+
cur_image_features = image_features[cur_image_idx]
|
214 |
+
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
|
215 |
+
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
|
216 |
+
new_input_embeds.append(cur_input_embeds)
|
217 |
+
new_labels.append(labels[batch_idx])
|
218 |
+
cur_image_idx += 1
|
219 |
+
continue
|
220 |
+
|
221 |
+
image_token_indices = (
|
222 |
+
[-1]
|
223 |
+
+ torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist()
|
224 |
+
+ [cur_input_ids.shape[0]]
|
225 |
+
)
|
226 |
+
cur_input_ids_noim = []
|
227 |
+
cur_labels = labels[batch_idx]
|
228 |
+
cur_labels_noim = []
|
229 |
+
for i in range(len(image_token_indices) - 1):
|
230 |
+
cur_input_ids_noim.append(
|
231 |
+
cur_input_ids[image_token_indices[i] + 1 : image_token_indices[i + 1]]
|
232 |
+
)
|
233 |
+
cur_labels_noim.append(
|
234 |
+
cur_labels[image_token_indices[i] + 1 : image_token_indices[i + 1]]
|
235 |
+
)
|
236 |
+
|
237 |
+
split_sizes = [x.shape[0] for x in cur_labels_noim]
|
238 |
+
cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
|
239 |
+
cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
|
240 |
+
cur_new_input_embeds = []
|
241 |
+
cur_new_labels = []
|
242 |
+
|
243 |
+
for i in range(num_images + 1):
|
244 |
+
cur_new_input_embeds.append(cur_input_embeds_no_im[i])
|
245 |
+
cur_new_labels.append(cur_labels_noim[i])
|
246 |
+
if i < num_images:
|
247 |
+
cur_image_features = image_features[cur_image_idx]
|
248 |
+
cur_image_idx += 1
|
249 |
+
cur_new_input_embeds.append(cur_image_features)
|
250 |
+
cur_new_labels.append(
|
251 |
+
torch.full(
|
252 |
+
(cur_image_features.shape[0],),
|
253 |
+
IGNORE_INDEX,
|
254 |
+
device=cur_labels.device,
|
255 |
+
dtype=cur_labels.dtype,
|
256 |
+
)
|
257 |
+
)
|
258 |
+
|
259 |
+
cur_new_input_embeds = torch.cat(cur_new_input_embeds)
|
260 |
+
cur_new_labels = torch.cat(cur_new_labels)
|
261 |
+
|
262 |
+
new_input_embeds.append(cur_new_input_embeds)
|
263 |
+
new_labels.append(cur_new_labels)
|
264 |
+
|
265 |
+
# Truncate sequences to max length as image embeddings can make the sequence longer
|
266 |
+
tokenizer_model_max_length = getattr(self.config, "tokenizer_model_max_length", None)
|
267 |
+
if tokenizer_model_max_length is not None:
|
268 |
+
new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
|
269 |
+
new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
|
270 |
+
|
271 |
+
# Combine them
|
272 |
+
max_len = max(x.shape[0] for x in new_input_embeds)
|
273 |
+
batch_size = len(new_input_embeds)
|
274 |
+
|
275 |
+
new_input_embeds_padded = []
|
276 |
+
new_labels_padded = torch.full(
|
277 |
+
(batch_size, max_len),
|
278 |
+
IGNORE_INDEX,
|
279 |
+
dtype=new_labels[0].dtype,
|
280 |
+
device=new_labels[0].device,
|
281 |
+
)
|
282 |
+
attention_mask = torch.zeros(
|
283 |
+
(batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device
|
284 |
+
)
|
285 |
+
position_ids = torch.zeros(
|
286 |
+
(batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device
|
287 |
+
)
|
288 |
+
|
289 |
+
for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
|
290 |
+
cur_len = cur_new_embed.shape[0]
|
291 |
+
if getattr(self.config, "tokenizer_padding_side", "right") == "left":
|
292 |
+
new_input_embeds_padded.append(
|
293 |
+
torch.cat(
|
294 |
+
(
|
295 |
+
torch.zeros(
|
296 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
297 |
+
dtype=cur_new_embed.dtype,
|
298 |
+
device=cur_new_embed.device,
|
299 |
+
),
|
300 |
+
cur_new_embed,
|
301 |
+
),
|
302 |
+
dim=0,
|
303 |
+
)
|
304 |
+
)
|
305 |
+
if cur_len > 0:
|
306 |
+
new_labels_padded[i, -cur_len:] = cur_new_labels
|
307 |
+
attention_mask[i, -cur_len:] = True
|
308 |
+
position_ids[i, -cur_len:] = torch.arange(
|
309 |
+
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
310 |
+
)
|
311 |
+
else:
|
312 |
+
new_input_embeds_padded.append(
|
313 |
+
torch.cat(
|
314 |
+
(
|
315 |
+
cur_new_embed,
|
316 |
+
torch.zeros(
|
317 |
+
(max_len - cur_len, cur_new_embed.shape[1]),
|
318 |
+
dtype=cur_new_embed.dtype,
|
319 |
+
device=cur_new_embed.device,
|
320 |
+
),
|
321 |
+
),
|
322 |
+
dim=0,
|
323 |
+
)
|
324 |
+
)
|
325 |
+
if cur_len > 0:
|
326 |
+
new_labels_padded[i, :cur_len] = cur_new_labels
|
327 |
+
attention_mask[i, :cur_len] = True
|
328 |
+
position_ids[i, :cur_len] = torch.arange(
|
329 |
+
0, cur_len, dtype=position_ids.dtype, device=position_ids.device
|
330 |
+
)
|
331 |
+
|
332 |
+
new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
|
333 |
+
|
334 |
+
if _labels is None:
|
335 |
+
new_labels = None
|
336 |
+
else:
|
337 |
+
new_labels = new_labels_padded
|
338 |
+
|
339 |
+
if _attention_mask is None:
|
340 |
+
attention_mask = None
|
341 |
+
else:
|
342 |
+
attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
|
343 |
+
|
344 |
+
if _position_ids is None:
|
345 |
+
position_ids = None
|
346 |
+
return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
|
347 |
+
|
348 |
+
def initialize_vision_tokenizer(self, model_args, tokenizer):
|
349 |
+
if model_args.mm_use_im_patch_token:
|
350 |
+
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
|
351 |
+
self.resize_token_embeddings(len(tokenizer))
|
352 |
+
|
353 |
+
if model_args.mm_use_im_start_end:
|
354 |
+
num_new_tokens = tokenizer.add_tokens(
|
355 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
356 |
+
)
|
357 |
+
self.resize_token_embeddings(len(tokenizer))
|
358 |
+
|
359 |
+
if num_new_tokens > 0:
|
360 |
+
input_embeddings = self.get_input_embeddings().weight.data
|
361 |
+
output_embeddings = self.get_output_embeddings().weight.data
|
362 |
+
|
363 |
+
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
364 |
+
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
|
365 |
+
dim=0, keepdim=True
|
366 |
+
)
|
367 |
+
|
368 |
+
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
369 |
+
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
370 |
+
|
371 |
+
if model_args.tune_mm_mlp_adapter:
|
372 |
+
for p in self.get_input_embeddings().parameters():
|
373 |
+
p.requires_grad = True
|
374 |
+
for p in self.get_output_embeddings().parameters():
|
375 |
+
p.requires_grad = False
|
376 |
+
|
377 |
+
if model_args.pretrain_mm_mlp_adapter:
|
378 |
+
mm_projector_weights = torch.load(
|
379 |
+
model_args.pretrain_mm_mlp_adapter, map_location="cpu"
|
380 |
+
)
|
381 |
+
embed_tokens_weight = mm_projector_weights["model.embed_tokens.weight"]
|
382 |
+
assert num_new_tokens == 2
|
383 |
+
if input_embeddings.shape == embed_tokens_weight.shape:
|
384 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:]
|
385 |
+
elif embed_tokens_weight.shape[0] == num_new_tokens:
|
386 |
+
input_embeddings[-num_new_tokens:] = embed_tokens_weight
|
387 |
+
else:
|
388 |
+
raise ValueError(
|
389 |
+
f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. Numer of new tokens: {num_new_tokens}."
|
390 |
+
)
|
391 |
+
elif model_args.mm_use_im_patch_token:
|
392 |
+
if model_args.tune_mm_mlp_adapter:
|
393 |
+
for p in self.get_input_embeddings().parameters():
|
394 |
+
p.requires_grad = False
|
395 |
+
for p in self.get_output_embeddings().parameters():
|
396 |
+
p.requires_grad = False
|
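The core move in prepare_inputs_labels_for_multimodal is splitting each prompt around the image placeholder ids and splicing the projected image features between the text-embedding chunks. A standalone sketch of that splice, with made-up shapes and the standard LLaVA placeholder id assumed to match constants.py:

import torch

IMAGE_TOKEN_INDEX = -200                  # assumption: the usual LLaVA placeholder id
input_ids = torch.tensor([1, 15, 42, IMAGE_TOKEN_INDEX, 7, 9])
text_embeds = torch.randn(5, 8)           # embeddings of the 5 non-image tokens
image_features = torch.randn(3, 8)        # e.g. 3 projected image-patch embeddings

# boundaries bracket every run of text tokens, exactly like image_token_indices above
boundaries = [-1] + torch.where(input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [input_ids.shape[0]]
chunks, consumed = [], 0
for i in range(len(boundaries) - 1):
    n = boundaries[i + 1] - (boundaries[i] + 1)
    chunks.append(text_embeds[consumed:consumed + n])
    consumed += n
    if i < len(boundaries) - 2:           # splice image features between the text chunks
        chunks.append(image_features)
spliced = torch.cat(chunks, dim=0)
print(spliced.shape)                      # torch.Size([8, 8]): 5 text rows + 3 image rows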
medrax/llava/model/multimodal_encoder/builder.py
ADDED
@@ -0,0 +1,15 @@
1 |
+
import os
|
2 |
+
from .clip_encoder import CLIPVisionTower
|
3 |
+
|
4 |
+
|
5 |
+
def build_vision_tower(vision_tower_cfg, **kwargs):
|
6 |
+
vision_tower = getattr(
|
7 |
+
vision_tower_cfg, "mm_vision_tower", getattr(vision_tower_cfg, "vision_tower", None)
|
8 |
+
)
|
9 |
+
is_absolute_path_exists = os.path.exists(vision_tower)
|
10 |
+
if (
|
11 |
+
is_absolute_path_exists
|
12 |
+
or vision_tower.startswith("openai")
|
13 |
+
or vision_tower.startswith("laion")
|
14 |
+
):
|
15 |
+
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
|
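build_vision_tower only needs an object exposing the mm_vision_tower / mm_vision_select_* attributes it reads, so a bare namespace works for experimentation. The checkpoint name below is the usual LLaVA default and is an assumption, not something this commit pins down; delay_load=True keeps the sketch from pulling the full weights:

from types import SimpleNamespace
from medrax.llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",  # assumed default CLIP tower
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)  # only fetches the CLIP config, not the weights
print(tower.vision_tower_name, tower.is_loaded)   # openai/clip-vit-large-patch14-336 False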
medrax/llava/model/multimodal_encoder/clip_encoder.py
ADDED
@@ -0,0 +1,83 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPVisionTower(nn.Module):
|
8 |
+
def __init__(self, vision_tower, args, delay_load=False):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
self.is_loaded = False
|
12 |
+
|
13 |
+
self.vision_tower_name = vision_tower
|
14 |
+
self.select_layer = args.mm_vision_select_layer
|
15 |
+
self.select_feature = getattr(args, "mm_vision_select_feature", "patch")
|
16 |
+
|
17 |
+
if not delay_load:
|
18 |
+
self.load_model()
|
19 |
+
else:
|
20 |
+
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
|
21 |
+
|
22 |
+
def load_model(self):
|
23 |
+
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
|
24 |
+
self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
|
25 |
+
self.vision_tower.requires_grad_(False)
|
26 |
+
|
27 |
+
self.is_loaded = True
|
28 |
+
|
29 |
+
def feature_select(self, image_forward_outs):
|
30 |
+
image_features = image_forward_outs.hidden_states[self.select_layer]
|
31 |
+
if self.select_feature == "patch":
|
32 |
+
image_features = image_features[:, 1:]
|
33 |
+
elif self.select_feature == "cls_patch":
|
34 |
+
image_features = image_features
|
35 |
+
else:
|
36 |
+
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
37 |
+
return image_features
|
38 |
+
|
39 |
+
@torch.no_grad()
|
40 |
+
def forward(self, images):
|
41 |
+
if type(images) is list:
|
42 |
+
image_features = []
|
43 |
+
for image in images:
|
44 |
+
image_forward_out = self.vision_tower(
|
45 |
+
image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
|
46 |
+
output_hidden_states=True,
|
47 |
+
)
|
48 |
+
image_feature = self.feature_select(image_forward_out).to(image.dtype)
|
49 |
+
image_features.append(image_feature)
|
50 |
+
else:
|
51 |
+
image_forward_outs = self.vision_tower(
|
52 |
+
images.to(device=self.device, dtype=self.dtype), output_hidden_states=True
|
53 |
+
)
|
54 |
+
image_features = self.feature_select(image_forward_outs).to(images.dtype)
|
55 |
+
|
56 |
+
return image_features
|
57 |
+
|
58 |
+
@property
|
59 |
+
def dummy_feature(self):
|
60 |
+
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
|
61 |
+
|
62 |
+
@property
|
63 |
+
def dtype(self):
|
64 |
+
return self.vision_tower.dtype
|
65 |
+
|
66 |
+
@property
|
67 |
+
def device(self):
|
68 |
+
return self.vision_tower.device
|
69 |
+
|
70 |
+
@property
|
71 |
+
def config(self):
|
72 |
+
if self.is_loaded:
|
73 |
+
return self.vision_tower.config
|
74 |
+
else:
|
75 |
+
return self.cfg_only
|
76 |
+
|
77 |
+
@property
|
78 |
+
def hidden_size(self):
|
79 |
+
return self.config.hidden_size
|
80 |
+
|
81 |
+
@property
|
82 |
+
def num_patches(self):
|
83 |
+
return (self.config.image_size // self.config.patch_size) ** 2
|
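A rough shape check for the tower (assumes network access to fetch the CLIP checkpoint): with the 336-pixel ViT-L/14 encoder and the default "patch" feature selection, the CLS token is dropped, so a batch of B images maps to (B, 576, 1024) features, which the mm_projector then lifts to the LM hidden size:

import torch
from types import SimpleNamespace
from medrax.llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=args)  # downloads weights
feats = tower(torch.zeros(2, 3, 336, 336))
print(feats.shape, tower.num_patches)  # torch.Size([2, 576, 1024]) 576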
medrax/llava/model/multimodal_projector/builder.py
ADDED
@@ -0,0 +1,49 @@
1 |
+
import torch.nn as nn
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
class IdentityMap(nn.Module):
|
6 |
+
def __init__(self):
|
7 |
+
super().__init__()
|
8 |
+
|
9 |
+
def forward(self, x, *args, **kwargs):
|
10 |
+
return x
|
11 |
+
|
12 |
+
@property
|
13 |
+
def config(self):
|
14 |
+
return {"mm_projector_type": "identity"}
|
15 |
+
|
16 |
+
|
17 |
+
class SimpleResBlock(nn.Module):
|
18 |
+
def __init__(self, channels):
|
19 |
+
super().__init__()
|
20 |
+
self.pre_norm = nn.LayerNorm(channels)
|
21 |
+
|
22 |
+
self.proj = nn.Sequential(
|
23 |
+
nn.Linear(channels, channels), nn.GELU(), nn.Linear(channels, channels)
|
24 |
+
)
|
25 |
+
|
26 |
+
def forward(self, x):
|
27 |
+
x = self.pre_norm(x)
|
28 |
+
return x + self.proj(x)
|
29 |
+
|
30 |
+
|
31 |
+
def build_vision_projector(config, delay_load=False, **kwargs):
|
32 |
+
projector_type = getattr(config, "mm_projector_type", "linear")
|
33 |
+
|
34 |
+
if projector_type == "linear":
|
35 |
+
return nn.Linear(config.mm_hidden_size, config.hidden_size)
|
36 |
+
|
37 |
+
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
|
38 |
+
if mlp_gelu_match:
|
39 |
+
mlp_depth = int(mlp_gelu_match.group(1))
|
40 |
+
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
|
41 |
+
for _ in range(1, mlp_depth):
|
42 |
+
modules.append(nn.GELU())
|
43 |
+
modules.append(nn.Linear(config.hidden_size, config.hidden_size))
|
44 |
+
return nn.Sequential(*modules)
|
45 |
+
|
46 |
+
if projector_type == "identity":
|
47 |
+
return IdentityMap()
|
48 |
+
|
49 |
+
raise ValueError(f"Unknown projector type: {projector_type}")
|
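The projector factory is driven entirely by mm_projector_type: "linear" gives a single Linear layer, "mlpNx_gelu" gives N Linear layers with GELUs in between, and "identity" is a no-op. A sketch with illustrative sizes (1024 is the CLIP ViT-L hidden size, 4096 a typical 7B LM hidden size):

import torch
from types import SimpleNamespace
from medrax.llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)
print(projector)                                    # Sequential(Linear, GELU, Linear)
print(projector(torch.randn(2, 576, 1024)).shape)   # torch.Size([2, 576, 4096])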
medrax/llava/serve/__init__.py
ADDED
File without changes
|
medrax/llava/serve/cli.py
ADDED
@@ -0,0 +1,152 @@
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from medrax.llava.constants import (
|
5 |
+
IMAGE_TOKEN_INDEX,
|
6 |
+
DEFAULT_IMAGE_TOKEN,
|
7 |
+
DEFAULT_IM_START_TOKEN,
|
8 |
+
DEFAULT_IM_END_TOKEN,
|
9 |
+
)
|
10 |
+
from medrax.llava.conversation import conv_templates, SeparatorStyle
|
11 |
+
from medrax.llava.model.builder import load_pretrained_model
|
12 |
+
from medrax.llava.utils import disable_torch_init
|
13 |
+
from medrax.llava.mm_utils import (
|
14 |
+
process_images,
|
15 |
+
tokenizer_image_token,
|
16 |
+
get_model_name_from_path,
|
17 |
+
KeywordsStoppingCriteria,
|
18 |
+
)
|
19 |
+
|
20 |
+
from PIL import Image
|
21 |
+
|
22 |
+
import requests
|
23 |
+
from io import BytesIO
|
24 |
+
from transformers import TextStreamer
|
25 |
+
|
26 |
+
|
27 |
+
def load_image(image_file):
|
28 |
+
if image_file.startswith("http://") or image_file.startswith("https://"):
|
29 |
+
response = requests.get(image_file)
|
30 |
+
image = Image.open(BytesIO(response.content)).convert("RGB")
|
31 |
+
else:
|
32 |
+
image = Image.open(image_file).convert("RGB")
|
33 |
+
return image
|
34 |
+
|
35 |
+
|
36 |
+
def main(args):
|
37 |
+
# Model
|
38 |
+
disable_torch_init()
|
39 |
+
|
40 |
+
model_name = get_model_name_from_path(args.model_path)
|
41 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(
|
42 |
+
args.model_path,
|
43 |
+
args.model_base,
|
44 |
+
model_name,
|
45 |
+
args.load_8bit,
|
46 |
+
args.load_4bit,
|
47 |
+
device=args.device,
|
48 |
+
)
|
49 |
+
|
50 |
+
if "llama-2" in model_name.lower():
|
51 |
+
conv_mode = "llava_llama_2"
|
52 |
+
elif "v1" in model_name.lower():
|
53 |
+
conv_mode = "llava_v1"
|
54 |
+
elif "mpt" in model_name.lower():
|
55 |
+
conv_mode = "mpt"
|
56 |
+
else:
|
57 |
+
conv_mode = "llava_v0"
|
58 |
+
conv_mode = "mistral_instruct"
|
59 |
+
|
60 |
+
if args.conv_mode is not None and conv_mode != args.conv_mode:
|
61 |
+
print(
|
62 |
+
"[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
|
63 |
+
conv_mode, args.conv_mode, args.conv_mode
|
64 |
+
)
|
65 |
+
)
|
66 |
+
else:
|
67 |
+
args.conv_mode = conv_mode
|
68 |
+
|
69 |
+
conv = conv_templates[args.conv_mode].copy()
|
70 |
+
if "mpt" in model_name.lower():
|
71 |
+
roles = ("user", "assistant")
|
72 |
+
else:
|
73 |
+
roles = conv.roles
|
74 |
+
|
75 |
+
image = load_image(args.image_file)
|
76 |
+
# Similar operation in model_worker.py
|
77 |
+
image_tensor = process_images([image], image_processor, model.config)
|
78 |
+
if type(image_tensor) is list:
|
79 |
+
image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
|
80 |
+
else:
|
81 |
+
image_tensor = image_tensor.to(model.device, dtype=torch.float16)
|
82 |
+
|
83 |
+
while True:
|
84 |
+
try:
|
85 |
+
inp = input(f"{roles[0]}: ")
|
86 |
+
except EOFError:
|
87 |
+
inp = ""
|
88 |
+
if not inp:
|
89 |
+
print("exit...")
|
90 |
+
break
|
91 |
+
|
92 |
+
print(f"{roles[1]}: ", end="")
|
93 |
+
|
94 |
+
if image is not None:
|
95 |
+
# first message
|
96 |
+
if model.config.mm_use_im_start_end:
|
97 |
+
inp = (
|
98 |
+
DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
|
99 |
+
)
|
100 |
+
else:
|
101 |
+
inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
|
102 |
+
conv.append_message(conv.roles[0], inp)
|
103 |
+
image = None
|
104 |
+
else:
|
105 |
+
# later messages
|
106 |
+
conv.append_message(conv.roles[0], inp)
|
107 |
+
conv.append_message(conv.roles[1], None)
|
108 |
+
prompt = conv.get_prompt()
|
109 |
+
|
110 |
+
input_ids = (
|
111 |
+
tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
|
112 |
+
.unsqueeze(0)
|
113 |
+
.to(model.device)
|
114 |
+
)
|
115 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
116 |
+
keywords = [stop_str]
|
117 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
118 |
+
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
119 |
+
|
120 |
+
with torch.inference_mode():
|
121 |
+
output_ids = model.generate(
|
122 |
+
input_ids,
|
123 |
+
images=image_tensor,
|
124 |
+
do_sample=True if args.temperature > 0 else False,
|
125 |
+
temperature=args.temperature,
|
126 |
+
max_new_tokens=args.max_new_tokens,
|
127 |
+
streamer=streamer,
|
128 |
+
use_cache=True,
|
129 |
+
stopping_criteria=[stopping_criteria],
|
130 |
+
)
|
131 |
+
|
132 |
+
outputs = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
|
133 |
+
conv.messages[-1][-1] = outputs
|
134 |
+
|
135 |
+
if args.debug:
|
136 |
+
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
|
137 |
+
|
138 |
+
|
139 |
+
if __name__ == "__main__":
|
140 |
+
parser = argparse.ArgumentParser()
|
141 |
+
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
|
142 |
+
parser.add_argument("--model-base", type=str, default=None)
|
143 |
+
parser.add_argument("--image-file", type=str, required=True)
|
144 |
+
parser.add_argument("--device", type=str, default="cuda")
|
145 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
146 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
147 |
+
parser.add_argument("--max-new-tokens", type=int, default=512)
|
148 |
+
parser.add_argument("--load-8bit", action="store_true")
|
149 |
+
parser.add_argument("--load-4bit", action="store_true")
|
150 |
+
parser.add_argument("--debug", action="store_true")
|
151 |
+
args = parser.parse_args()
|
152 |
+
main(args)
|
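The same pipeline as the interactive loop above, condensed to a single non-interactive turn. Paths and the checkpoint name are placeholders; the template name follows the mistral_instruct override used in main():

import torch
from PIL import Image
from medrax.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from medrax.llava.conversation import conv_templates
from medrax.llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token
from medrax.llava.model.builder import load_pretrained_model

model_path = "path/to/llava-mistral-checkpoint"  # placeholder
tokenizer, model, image_processor, _ = load_pretrained_model(
    model_path, None, get_model_name_from_path(model_path), False, False, device="cuda"
)

image = Image.open("chest_xray.png").convert("RGB")  # placeholder image file
# cli.py above also handles the case where process_images returns a list of tensors
image_tensor = process_images([image], image_processor, model.config).to(
    model.device, dtype=torch.float16
)

conv = conv_templates["mistral_instruct"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\nDescribe the findings.")
conv.append_message(conv.roles[1], None)
input_ids = (
    tokenizer_image_token(conv.get_prompt(), tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids, images=image_tensor, do_sample=False, max_new_tokens=256, use_cache=True
    )
# The Mistral path generates from inputs_embeds, so output_ids holds only the new tokens.
print(tokenizer.decode(output_ids[0], skip_special_tokens=True).strip())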
medrax/llava/serve/controller.py
ADDED
@@ -0,0 +1,299 @@
1 |
+
"""
|
2 |
+
A controller manages distributed workers.
|
3 |
+
It sends worker addresses to clients.
|
4 |
+
"""
|
5 |
+
import argparse
|
6 |
+
import dataclasses
|
7 |
+
from enum import Enum, auto
|
8 |
+
import json
|
9 |
+
import time
|
10 |
+
from typing import List
|
11 |
+
import threading
|
12 |
+
|
13 |
+
from fastapi import FastAPI, Request
|
14 |
+
from fastapi.responses import StreamingResponse
|
15 |
+
import numpy as np
|
16 |
+
import requests
|
17 |
+
import uvicorn
|
18 |
+
|
19 |
+
from medrax.llava.constants import CONTROLLER_HEART_BEAT_EXPIRATION
|
20 |
+
from medrax.llava.utils import build_logger, server_error_msg
|
21 |
+
|
22 |
+
|
23 |
+
logger = build_logger("controller", "controller.log")
|
24 |
+
|
25 |
+
|
26 |
+
class DispatchMethod(Enum):
|
27 |
+
LOTTERY = auto()
|
28 |
+
SHORTEST_QUEUE = auto()
|
29 |
+
|
30 |
+
@classmethod
|
31 |
+
def from_str(cls, name):
|
32 |
+
if name == "lottery":
|
33 |
+
return cls.LOTTERY
|
34 |
+
elif name == "shortest_queue":
|
35 |
+
return cls.SHORTEST_QUEUE
|
36 |
+
else:
|
37 |
+
raise ValueError("Invalid dispatch method")
|
38 |
+
|
39 |
+
|
40 |
+
@dataclasses.dataclass
|
41 |
+
class WorkerInfo:
|
42 |
+
model_names: List[str]
|
43 |
+
speed: int
|
44 |
+
queue_length: int
|
45 |
+
check_heart_beat: bool
|
46 |
+
last_heart_beat: str
|
47 |
+
|
48 |
+
|
49 |
+
def heart_beat_controller(controller):
|
50 |
+
while True:
|
51 |
+
time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
|
52 |
+
controller.remove_stable_workers_by_expiration()
|
53 |
+
|
54 |
+
|
55 |
+
class Controller:
|
56 |
+
def __init__(self, dispatch_method: str):
|
57 |
+
# Dict[str -> WorkerInfo]
|
58 |
+
self.worker_info = {}
|
59 |
+
self.dispatch_method = DispatchMethod.from_str(dispatch_method)
|
60 |
+
|
61 |
+
self.heart_beat_thread = threading.Thread(target=heart_beat_controller, args=(self,))
|
62 |
+
self.heart_beat_thread.start()
|
63 |
+
|
64 |
+
logger.info("Init controller")
|
65 |
+
|
66 |
+
def register_worker(self, worker_name: str, check_heart_beat: bool, worker_status: dict):
|
67 |
+
if worker_name not in self.worker_info:
|
68 |
+
logger.info(f"Register a new worker: {worker_name}")
|
69 |
+
else:
|
70 |
+
logger.info(f"Register an existing worker: {worker_name}")
|
71 |
+
|
72 |
+
if not worker_status:
|
73 |
+
worker_status = self.get_worker_status(worker_name)
|
74 |
+
if not worker_status:
|
75 |
+
return False
|
76 |
+
|
77 |
+
self.worker_info[worker_name] = WorkerInfo(
|
78 |
+
worker_status["model_names"],
|
79 |
+
worker_status["speed"],
|
80 |
+
worker_status["queue_length"],
|
81 |
+
check_heart_beat,
|
82 |
+
time.time(),
|
83 |
+
)
|
84 |
+
|
85 |
+
logger.info(f"Register done: {worker_name}, {worker_status}")
|
86 |
+
return True
|
87 |
+
|
88 |
+
def get_worker_status(self, worker_name: str):
|
89 |
+
try:
|
90 |
+
r = requests.post(worker_name + "/worker_get_status", timeout=5)
|
91 |
+
except requests.exceptions.RequestException as e:
|
92 |
+
logger.error(f"Get status fails: {worker_name}, {e}")
|
93 |
+
return None
|
94 |
+
|
95 |
+
if r.status_code != 200:
|
96 |
+
logger.error(f"Get status fails: {worker_name}, {r}")
|
97 |
+
return None
|
98 |
+
|
99 |
+
return r.json()
|
100 |
+
|
101 |
+
def remove_worker(self, worker_name: str):
|
102 |
+
del self.worker_info[worker_name]
|
103 |
+
|
104 |
+
def refresh_all_workers(self):
|
105 |
+
old_info = dict(self.worker_info)
|
106 |
+
self.worker_info = {}
|
107 |
+
|
108 |
+
for w_name, w_info in old_info.items():
|
109 |
+
if not self.register_worker(w_name, w_info.check_heart_beat, None):
|
110 |
+
logger.info(f"Remove stale worker: {w_name}")
|
111 |
+
|
112 |
+
def list_models(self):
|
113 |
+
model_names = set()
|
114 |
+
|
115 |
+
for w_name, w_info in self.worker_info.items():
|
116 |
+
model_names.update(w_info.model_names)
|
117 |
+
|
118 |
+
return list(model_names)
|
119 |
+
|
120 |
+
def get_worker_address(self, model_name: str):
|
121 |
+
if self.dispatch_method == DispatchMethod.LOTTERY:
|
122 |
+
worker_names = []
|
123 |
+
worker_speeds = []
|
124 |
+
for w_name, w_info in self.worker_info.items():
|
125 |
+
if model_name in w_info.model_names:
|
126 |
+
worker_names.append(w_name)
|
127 |
+
worker_speeds.append(w_info.speed)
|
128 |
+
worker_speeds = np.array(worker_speeds, dtype=np.float32)
|
129 |
+
norm = np.sum(worker_speeds)
|
130 |
+
if norm < 1e-4:
|
131 |
+
return ""
|
132 |
+
worker_speeds = worker_speeds / norm
|
133 |
+
if True: # Directly return address
|
134 |
+
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
|
135 |
+
worker_name = worker_names[pt]
|
136 |
+
return worker_name
|
137 |
+
|
138 |
+
# Check status before returning
|
139 |
+
while True:
|
140 |
+
pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
|
141 |
+
worker_name = worker_names[pt]
|
142 |
+
|
143 |
+
if self.get_worker_status(worker_name):
|
144 |
+
break
|
145 |
+
else:
|
146 |
+
self.remove_worker(worker_name)
|
147 |
+
worker_speeds[pt] = 0
|
148 |
+
norm = np.sum(worker_speeds)
|
149 |
+
if norm < 1e-4:
|
150 |
+
return ""
|
151 |
+
worker_speeds = worker_speeds / norm
|
152 |
+
continue
|
153 |
+
return worker_name
|
154 |
+
elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
|
155 |
+
worker_names = []
|
156 |
+
worker_qlen = []
|
157 |
+
for w_name, w_info in self.worker_info.items():
|
158 |
+
if model_name in w_info.model_names:
|
159 |
+
worker_names.append(w_name)
|
160 |
+
worker_qlen.append(w_info.queue_length / w_info.speed)
|
161 |
+
if len(worker_names) == 0:
|
162 |
+
return ""
|
163 |
+
min_index = np.argmin(worker_qlen)
|
164 |
+
w_name = worker_names[min_index]
|
165 |
+
self.worker_info[w_name].queue_length += 1
|
166 |
+
logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
|
167 |
+
return w_name
|
168 |
+
else:
|
169 |
+
raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
|
170 |
+
|
171 |
+
def receive_heart_beat(self, worker_name: str, queue_length: int):
|
172 |
+
if worker_name not in self.worker_info:
|
173 |
+
logger.info(f"Receive unknown heart beat. {worker_name}")
|
174 |
+
return False
|
175 |
+
|
176 |
+
self.worker_info[worker_name].queue_length = queue_length
|
177 |
+
self.worker_info[worker_name].last_heart_beat = time.time()
|
178 |
+
logger.info(f"Receive heart beat. {worker_name}")
|
179 |
+
return True
|
180 |
+
|
181 |
+
def remove_stable_workers_by_expiration(self):
|
182 |
+
expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
|
183 |
+
to_delete = []
|
184 |
+
for worker_name, w_info in self.worker_info.items():
|
185 |
+
if w_info.check_heart_beat and w_info.last_heart_beat < expire:
|
186 |
+
to_delete.append(worker_name)
|
187 |
+
|
188 |
+
for worker_name in to_delete:
|
189 |
+
self.remove_worker(worker_name)
|
190 |
+
|
191 |
+
def worker_api_generate_stream(self, params):
|
192 |
+
worker_addr = self.get_worker_address(params["model"])
|
193 |
+
if not worker_addr:
|
194 |
+
logger.info(f"no worker: {params['model']}")
|
195 |
+
ret = {
|
196 |
+
"text": server_error_msg,
|
197 |
+
"error_code": 2,
|
198 |
+
}
|
199 |
+
yield json.dumps(ret).encode() + b"\0"
|
200 |
+
|
201 |
+
try:
|
202 |
+
response = requests.post(
|
203 |
+
worker_addr + "/worker_generate_stream", json=params, stream=True, timeout=5
|
204 |
+
)
|
205 |
+
for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
|
206 |
+
if chunk:
|
207 |
+
yield chunk + b"\0"
|
208 |
+
except requests.exceptions.RequestException:
|
209 |
+
logger.info(f"worker timeout: {worker_addr}")
|
210 |
+
ret = {
|
211 |
+
"text": server_error_msg,
|
212 |
+
"error_code": 3,
|
213 |
+
}
|
214 |
+
yield json.dumps(ret).encode() + b"\0"
|
215 |
+
|
216 |
+
# Let the controller act as a worker to achieve hierarchical
|
217 |
+
# management. This can be used to connect isolated sub-networks.
|
218 |
+
def worker_api_get_status(self):
|
219 |
+
model_names = set()
|
220 |
+
speed = 0
|
221 |
+
queue_length = 0
|
222 |
+
|
223 |
+
for w_name in self.worker_info:
|
224 |
+
worker_status = self.get_worker_status(w_name)
|
225 |
+
if worker_status is not None:
|
226 |
+
model_names.update(worker_status["model_names"])
|
227 |
+
speed += worker_status["speed"]
|
228 |
+
queue_length += worker_status["queue_length"]
|
229 |
+
|
230 |
+
return {
|
231 |
+
"model_names": list(model_names),
|
232 |
+
"speed": speed,
|
233 |
+
"queue_length": queue_length,
|
234 |
+
}
|
235 |
+
|
236 |
+
|
237 |
+
app = FastAPI()
|
238 |
+
|
239 |
+
|
240 |
+
@app.post("/register_worker")
|
241 |
+
async def register_worker(request: Request):
|
242 |
+
data = await request.json()
|
243 |
+
controller.register_worker(
|
244 |
+
data["worker_name"], data["check_heart_beat"], data.get("worker_status", None)
|
245 |
+
)
|
246 |
+
|
247 |
+
|
248 |
+
@app.post("/refresh_all_workers")
|
249 |
+
async def refresh_all_workers():
|
250 |
+
models = controller.refresh_all_workers()
|
251 |
+
|
252 |
+
|
253 |
+
@app.post("/list_models")
|
254 |
+
async def list_models():
|
255 |
+
models = controller.list_models()
|
256 |
+
return {"models": models}
|
257 |
+
|
258 |
+
|
259 |
+
@app.post("/get_worker_address")
|
260 |
+
async def get_worker_address(request: Request):
|
261 |
+
data = await request.json()
|
262 |
+
addr = controller.get_worker_address(data["model"])
|
263 |
+
return {"address": addr}
|
264 |
+
|
265 |
+
|
266 |
+
@app.post("/receive_heart_beat")
|
267 |
+
async def receive_heart_beat(request: Request):
|
268 |
+
data = await request.json()
|
269 |
+
exist = controller.receive_heart_beat(data["worker_name"], data["queue_length"])
|
270 |
+
return {"exist": exist}
|
271 |
+
|
272 |
+
|
273 |
+
@app.post("/worker_generate_stream")
|
274 |
+
async def worker_api_generate_stream(request: Request):
|
275 |
+
params = await request.json()
|
276 |
+
generator = controller.worker_api_generate_stream(params)
|
277 |
+
return StreamingResponse(generator)
|
278 |
+
|
279 |
+
|
280 |
+
@app.post("/worker_get_status")
|
281 |
+
async def worker_api_get_status(request: Request):
|
282 |
+
return controller.worker_api_get_status()
|
283 |
+
|
284 |
+
|
285 |
+
if __name__ == "__main__":
|
286 |
+
parser = argparse.ArgumentParser()
|
287 |
+
parser.add_argument("--host", type=str, default="localhost")
|
288 |
+
parser.add_argument("--port", type=int, default=21001)
|
289 |
+
parser.add_argument(
|
290 |
+
"--dispatch-method",
|
291 |
+
type=str,
|
292 |
+
choices=["lottery", "shortest_queue"],
|
293 |
+
default="shortest_queue",
|
294 |
+
)
|
295 |
+
args = parser.parse_args()
|
296 |
+
logger.info(f"args: {args}")
|
297 |
+
|
298 |
+
controller = Controller(args.dispatch_method)
|
299 |
+
uvicorn.run(app, host=args.host, port=args.port, log_level="info")
|
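A sketch of a client round-trip against the controller's HTTP endpoints, assuming a controller is already running on the default localhost:21001 and at least one model worker has registered with it:

import requests

controller = "http://localhost:21001"
requests.post(controller + "/refresh_all_workers")
models = requests.post(controller + "/list_models").json()["models"]
print("available models:", models)

if models:
    ret = requests.post(controller + "/get_worker_address", json={"model": models[0]})
    print("worker serving", models[0], "->", ret.json()["address"])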
medrax/llava/serve/gradio_web_server.py
ADDED
@@ -0,0 +1,532 @@
1 |
+
import argparse
|
2 |
+
import datetime
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import time
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import requests
|
9 |
+
|
10 |
+
from medrax.llava.conversation import default_conversation, conv_templates, SeparatorStyle
|
11 |
+
from medrax.llava.constants import LOGDIR
|
12 |
+
from medrax.llava.utils import build_logger, server_error_msg, violates_moderation, moderation_msg
|
13 |
+
import hashlib
|
14 |
+
|
15 |
+
|
16 |
+
logger = build_logger("gradio_web_server", "gradio_web_server.log")
|
17 |
+
|
18 |
+
headers = {"User-Agent": "LLaVA-Med Client"}
|
19 |
+
|
20 |
+
no_change_btn = gr.Button.update()
|
21 |
+
enable_btn = gr.Button.update(interactive=True)
|
22 |
+
disable_btn = gr.Button.update(interactive=False)
|
23 |
+
|
24 |
+
priority = {
|
25 |
+
"vicuna-13b": "aaaaaaa",
|
26 |
+
"koala-13b": "aaaaaab",
|
27 |
+
}
|
28 |
+
|
29 |
+
|
30 |
+
def get_conv_log_filename():
|
31 |
+
t = datetime.datetime.now()
|
32 |
+
name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
|
33 |
+
return name
|
34 |
+
|
35 |
+
|
36 |
+
def get_model_list():
|
37 |
+
ret = requests.post(args.controller_url + "/refresh_all_workers")
|
38 |
+
assert ret.status_code == 200
|
39 |
+
ret = requests.post(args.controller_url + "/list_models")
|
40 |
+
models = ret.json()["models"]
|
41 |
+
models.sort(key=lambda x: priority.get(x, x))
|
42 |
+
logger.info(f"Models: {models}")
|
43 |
+
return models
|
44 |
+
|
45 |
+
|
46 |
+
get_window_url_params = """
|
47 |
+
function() {
|
48 |
+
const params = new URLSearchParams(window.location.search);
|
49 |
+
url_params = Object.fromEntries(params);
|
50 |
+
console.log(url_params);
|
51 |
+
return url_params;
|
52 |
+
}
|
53 |
+
"""
|
54 |
+
|
55 |
+
|
56 |
+
def load_demo(url_params, request: gr.Request):
|
57 |
+
logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
|
58 |
+
|
59 |
+
dropdown_update = gr.Dropdown.update(visible=True)
|
60 |
+
if "model" in url_params:
|
61 |
+
model = url_params["model"]
|
62 |
+
if model in models:
|
63 |
+
dropdown_update = gr.Dropdown.update(value=model, visible=True)
|
64 |
+
|
65 |
+
state = default_conversation.copy()
|
66 |
+
return state, dropdown_update
|
67 |
+
|
68 |
+
|
69 |
+
def load_demo_refresh_model_list(request: gr.Request):
|
70 |
+
logger.info(f"load_demo. ip: {request.client.host}")
|
71 |
+
models = get_model_list()
|
72 |
+
state = default_conversation.copy()
|
73 |
+
dropdown_update = gr.Dropdown.update(choices=models, value=models[0] if len(models) > 0 else "")
|
74 |
+
return state, dropdown_update
|
75 |
+
|
76 |
+
|
77 |
+
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
|
78 |
+
with open(get_conv_log_filename(), "a") as fout:
|
79 |
+
data = {
|
80 |
+
"tstamp": round(time.time(), 4),
|
81 |
+
"type": vote_type,
|
82 |
+
"model": model_selector,
|
83 |
+
"state": state.dict(),
|
84 |
+
"ip": request.client.host,
|
85 |
+
}
|
86 |
+
fout.write(json.dumps(data) + "\n")
|
87 |
+
|
88 |
+
|
89 |
+
def upvote_last_response(state, model_selector, request: gr.Request):
|
90 |
+
logger.info(f"upvote. ip: {request.client.host}")
|
91 |
+
vote_last_response(state, "upvote", model_selector, request)
|
92 |
+
return ("",) + (disable_btn,) * 3
|
93 |
+
|
94 |
+
|
95 |
+
def downvote_last_response(state, model_selector, request: gr.Request):
|
96 |
+
logger.info(f"downvote. ip: {request.client.host}")
|
97 |
+
vote_last_response(state, "downvote", model_selector, request)
|
98 |
+
return ("",) + (disable_btn,) * 3
|
99 |
+
|
100 |
+
|
101 |
+
def flag_last_response(state, model_selector, request: gr.Request):
|
102 |
+
logger.info(f"flag. ip: {request.client.host}")
|
103 |
+
vote_last_response(state, "flag", model_selector, request)
|
104 |
+
return ("",) + (disable_btn,) * 3
|
105 |
+
|
106 |
+
|
107 |
+
def regenerate(state, image_process_mode, request: gr.Request):
|
108 |
+
logger.info(f"regenerate. ip: {request.client.host}")
|
109 |
+
state.messages[-1][-1] = None
|
110 |
+
prev_human_msg = state.messages[-2]
|
111 |
+
if type(prev_human_msg[1]) in (tuple, list):
|
112 |
+
prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
|
113 |
+
state.skip_next = False
|
114 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
115 |
+
|
116 |
+
|
117 |
+
def clear_history(request: gr.Request):
|
118 |
+
logger.info(f"clear_history. ip: {request.client.host}")
|
119 |
+
state = default_conversation.copy()
|
120 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
121 |
+
|
122 |
+
|
123 |
+
def add_text(state, text, image, image_process_mode, request: gr.Request):
|
124 |
+
logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
|
125 |
+
if len(text) <= 0 and image is None:
|
126 |
+
state.skip_next = True
|
127 |
+
return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
|
128 |
+
if args.moderate:
|
129 |
+
flagged = violates_moderation(text)
|
130 |
+
if flagged:
|
131 |
+
state.skip_next = True
|
132 |
+
return (state, state.to_gradio_chatbot(), moderation_msg, None) + (no_change_btn,) * 5
|
133 |
+
|
134 |
+
text = text[:1536] # Hard cut-off
|
135 |
+
if image is not None:
|
136 |
+
text = text[:1200] # Hard cut-off for images
|
137 |
+
if "<image>" not in text:
|
138 |
+
# text = '<Image><image></Image>' + text
|
139 |
+
text = text + "\n<image>"
|
140 |
+
text = (text, image, image_process_mode)
|
141 |
+
if len(state.get_images(return_pil=True)) > 0:
|
142 |
+
state = default_conversation.copy()
|
143 |
+
state.append_message(state.roles[0], text)
|
144 |
+
state.append_message(state.roles[1], None)
|
145 |
+
state.skip_next = False
|
146 |
+
return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
|
147 |
+
|
148 |
+
|
149 |
+
def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        if "llava" in model_name.lower():
            if "llama-2" in model_name.lower():
                template_name = "llava_llama_2"
            elif "v1" in model_name.lower():
                if "mmtag" in model_name.lower():
                    template_name = "v1_mmtag"
                elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
                    template_name = "v1_mmtag"
                else:
                    template_name = "llava_v1"
            elif "mpt" in model_name.lower():
                template_name = "mpt"
            else:
                if "mmtag" in model_name.lower():
                    template_name = "v0_mmtag"
                elif "plain" in model_name.lower() and "finetune" not in model_name.lower():
                    template_name = "v0_mmtag"
                else:
                    template_name = "llava_v0"
        elif "mpt" in model_name:
            template_name = "mpt_text"
        elif "llama-2" in model_name:
            template_name = "llama_2"
        else:
            template_name = "vicuna_v1"
        template_name = "mistral_instruct"  # FIXME: overwrite
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address", json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (
            state,
            state.to_gradio_chatbot(),
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    # Construct prompt
    prompt = state.get_prompt()

    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    for image, hash in zip(all_images, all_image_hash):
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg"
        )
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep
        if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT]
        else state.sep2,
        "images": f"List of {len(state.get_images())} images: {all_image_hash}",
    }
    logger.info(f"==== request ====\n{pload}")

    pload["images"] = state.get_images()

    state.messages[-1][-1] = "▌"
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

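    # The worker answers with b"\0"-delimited JSON chunks; each chunk carries the
    # full text decoded so far (prompt included) under "text" plus an "error_code"
    # (0 means success), so every chunk re-renders the chatbot and produces the
    # incremental typing effect.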
    try:
        # Stream output
        response = requests.post(
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=pload,
            stream=True,
            timeout=10,
        )
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt) :].strip()
                    state.messages[-1][-1] = output + "▌"
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, state.to_gradio_chatbot()) + (
                        disable_btn,
                        disable_btn,
                        disable_btn,
                        enable_btn,
                        enable_btn,
                    )
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (
            disable_btn,
            disable_btn,
            disable_btn,
            enable_btn,
            enable_btn,
        )
        return

    state.messages[-1][-1] = state.messages[-1][-1][:-1]
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


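# Static UI text (header, terms of use, license note) and CSS consumed by build_demo() below.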
title_markdown = """
# 🌋 LLaVA-Med: Large Language and Vision Assistant for Medical Research
[[Project Page]](https://llava-vl.github.io) [[Paper]](https://arxiv.org/abs/2304.08485) [[Code]](https://github.com/haotian-liu/LLaVA) [[Model]](https://huggingface.co/liuhaotian/LLaVA-13b-delta-v0)
"""

tos_markdown = """
### Terms of use
By using this service, users are required to agree to the following terms:
The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
Please click the "Flag" button if you get any inappropriate answer! We will collect those to keep improving our moderator.
For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.
"""


learn_more_markdown = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""


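# build_demo lays out the Gradio Blocks UI (model dropdown, image upload, example
# prompts, sampling sliders, chatbot, and feedback buttons) and wires the buttons,
# textbox submit, and page-load events to the handlers defined above.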
def build_demo(embed_mode):
    textbox = gr.Textbox(
        show_label=False, placeholder="Enter text and press ENTER", container=False
    )
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if len(models) > 0 else "",
                        interactive=True,
                        show_label=False,
                        container=False,
                    )

                imagebox = gr.Image(type="pil")
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image",
                    visible=False,
                )

                cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(
                    examples=[
                        [f"{cur_dir}/examples/bio_patch.png", "What is this image about?"],
                        [
                            f"{cur_dir}/examples/med_img_1.png",
                            "Can you describe the image in details?",
                        ],
                        [
                            f"{cur_dir}/examples/xy_chromosome.jpg",
                            "Can you describe the image in details?",
                        ],
                        [
                            f"{cur_dir}/examples/synpic42202.jpg",
                            "Is there evidence of an aortic aneurysm? Please choose from the following two options: [yes, no]?",
                        ],  # answer: yes
                        [
                            f"{cur_dir}/examples/synpic32933.jpg",
                            "What is the abnormality by the right hemidiaphragm?",
                        ],  # answer: free air
                        [
                            f"{cur_dir}/examples/extreme_ironing.jpg",
                            "What is unusual about this image?",
                        ],
                        [
                            f"{cur_dir}/examples/waterview.jpg",
                            "What are the things I should be cautious about when I visit here?",
                        ],
                    ],
                    inputs=[imagebox, textbox],
                )

                with gr.Accordion("Parameters", open=False) as parameter_row:
                    temperature = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.2,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    top_p = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.7,
                        step=0.1,
                        interactive=True,
                        label="Top P",
                    )
                    max_output_tokens = gr.Slider(
                        minimum=0,
                        maximum=1024,
                        value=512,
                        step=64,
                        interactive=True,
                        label="Max output tokens",
                    )

            with gr.Column(scale=8):
                chatbot = gr.Chatbot(elem_id="chatbot", label="LLaVA-Med Chatbot", height=550)
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons") as button_row:
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        if not embed_mode:
            gr.Markdown(tos_markdown)
            gr.Markdown(learn_more_markdown)
        url_params = gr.JSON(visible=False)

        # Register listeners
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False,
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False,
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
            queue=False,
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )

        clear_btn.click(
            clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False,
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                _js=get_window_url_params,
                queue=False,
            )
        elif args.model_list_mode == "reload":
            demo.load(load_demo_refresh_model_list, None, [state, model_selector], queue=False)
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=10)
    parser.add_argument("--model-list-mode", type=str, default="once", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()

    logger.info(args)
    demo = build_demo(args.embed)
    demo.queue(concurrency_count=args.concurrency_count, api_open=False).launch(
        server_name=args.host, server_port=args.port, share=args.share
    )
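Together with the controller and model worker, this script forms a three-process serving stack. A typical launch order, illustrative only and based on the argument parsers in this commit, is:

    python -m medrax.llava.serve.controller --host 0.0.0.0 --port 21001
    python -m medrax.llava.serve.model_worker --controller-address http://localhost:21001 --model-path <path-to-llava-med-weights>
    python -m medrax.llava.serve.gradio_web_server --controller-url http://localhost:21001 --model-list-mode reload

The controller's flags are not shown in this excerpt, so treat the first command as an assumption; the other two use only arguments defined above, and the model path is a placeholder.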
medrax/llava/serve/model_worker.py
ADDED
@@ -0,0 +1,337 @@
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial

from medrax.llava.constants import WORKER_HEART_BEAT_INTERVAL
from medrax.llava.utils import build_logger, server_error_msg, pretty_print_semaphore
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.mm_utils import (
    process_images,
    load_image_from_base64,
    tokenizer_image_token,
    KeywordsStoppingCriteria,
)
from medrax.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

model_semaphore = None


def heart_beat_worker(controller):

    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


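# ModelWorker wraps a single loaded LLaVA-Med checkpoint: it registers itself with
# the controller, reports its queue length over a heartbeat thread, and turns
# generation requests into a token stream.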
class ModelWorker:
    def __init__(
        self,
        controller_addr,
        worker_addr,
        worker_id,
        no_register,
        model_path,
        model_base,
        model_name,
        load_8bit,
        load_4bit,
        device,
    ):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        if model_name is None:
            model_paths = model_path.split("/")
            if model_paths[-1].startswith("checkpoint-"):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = model_name

        self.device = device
        logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device
        )
        self.is_multimodal = "llava" in self.model_name.lower()

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status(),
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        logger.info(
            f"Send heart beat. Models: {[self.model_name]}. "
            f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
            f"global_counter: {global_counter}"
        )

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url,
                    json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()},
                    timeout=5,
                )
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return (
                args.limit_model_concurrency
                - model_semaphore._value
                + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
            )

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

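    # generate_stream decodes any base64 images, rewrites the prompt with image
    # tokens, clamps max_new_tokens to the remaining context window, runs
    # model.generate on a background thread, and yields b"\0"-delimited JSON
    # chunks from a TextIteratorStreamer.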
    @torch.inference_mode()
    def generate_stream(self, params):
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor

        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)
        num_image_tokens = 0
        if images is not None and len(images) > 0 and self.is_multimodal:
            if len(images) > 0:
                if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                    raise ValueError(
                        "Number of images does not match number of <image> tokens in prompt"
                    )

                images = [load_image_from_base64(image) for image in images]
                images = process_images(images, image_processor, model.config)

                if type(images) is list:
                    images = [image.to(self.model.device, dtype=torch.float16) for image in images]
                else:
                    images = images.to(self.model.device, dtype=torch.float16)

                replace_token = DEFAULT_IMAGE_TOKEN
                if getattr(self.model.config, "mm_use_im_start_end", False):
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

                num_image_tokens = (
                    prompt.count(replace_token) * model.get_vision_tower().num_patches
                )
            else:
                images = None
            image_args = {"images": images}
        else:
            images = None
            image_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, "max_position_embeddings", 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = (
            tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15
        )

        max_new_tokens = min(
            max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens
        )

        if max_new_tokens < 1:
            yield json.dumps(
                {
                    "text": ori_prompt
                    + "Exceeds max token length. Please start a new conversation, thanks.",
                    "error_code": 0,
                }
            ).encode() + b"\0"
            return

        thread = Thread(
            target=model.generate,
            kwargs=dict(
                inputs=input_ids,
                do_sample=do_sample,
                temperature=temperature,
                top_p=top_p,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                stopping_criteria=[stopping_criteria],
                use_cache=True,
                **image_args,
            ),
        )
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            if generated_text.endswith(stop_str):
                generated_text = generated_text[: -len(stop_str)]
            yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0"

    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


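# FastAPI endpoints: /worker_generate_stream gates each request behind an asyncio
# semaphore sized by --limit-model-concurrency and releases it (plus a heartbeat)
# in a background task when the stream finishes; /worker_get_status reports the
# worker's model names and queue length to the controller.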
app = FastAPI()


def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str, default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument(
        "--multi-modal",
        action="store_true",
        help="Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path.",
    )
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    if args.multi_modal:
        logger.warning(
            "Multimodal mode is automatically detected with model name, please make sure `llava` is included in the model path."
        )

    worker = ModelWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.no_register,
        args.model_path,
        args.model_base,
        args.model_name,
        args.load_8bit,
        args.load_4bit,
        args.device,
    )
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
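For reference, a minimal client for the worker's streaming endpoint might look like the sketch below. It is illustrative only: the worker address, prompt, and sampling values are placeholders (in the real stack the prompt comes from the conversation template and the request is issued by gradio_web_server.py), and it assumes only the request/response format visible in generate_stream above.

    import json
    import requests

    worker_addr = "http://localhost:21002"  # assumed default of --worker-address
    payload = {
        "prompt": "USER: What are common causes of cardiomegaly? ASSISTANT:",  # placeholder prompt
        "temperature": 0.2,
        "top_p": 0.7,
        "max_new_tokens": 256,
        "stop": "</s>",  # placeholder stop string
        "images": [],  # base64-encoded images, one per <image> token in the prompt
    }
    response = requests.post(
        worker_addr + "/worker_generate_stream", json=payload, stream=True, timeout=10
    )
    for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] != 0:
                raise RuntimeError(data["text"])
            print(data["text"])  # full text so far, prompt included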