|
|
|
""" |
|
Medical X-ray Question Generation Benchmark (aka ChestAgentBench)

This script generates clinical questions from chest X-ray case data in the Eurorad
dataset using GPT-4o. It structures questions across different analytical categories
and saves them as JSON.
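
Typical usage: set DATA_DIR below and make sure an OpenAI API key is available
(openai.OpenAI() reads OPENAI_API_KEY), then run this file directly. Generated
questions are written under benchmark/questions/<case_id>/.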
|
""" |
|
|
|
import os |
|
import re |
|
import json |
|
from typing import Any, Dict, List, Optional
|
from pprint import pprint |
|
|
|
import openai |
|
import numpy as np |
|
from scipy import stats |
|
import plotly.graph_objects as go |
|
from tqdm import tqdm |
|
|
|
from benchmark.utils import load_eurorad_dataset |
|
from benchmark.llm import get_llm_response |
|
|
|
|
|
DATA_DIR = "set your data directory here, e.g. /home/MedRAX/data" |
|
DATASET_PATH = os.path.join(DATA_DIR, "eurorad_metadata.json") |
|
|
|
SYSTEM_PROMPT = """ |
|
You are an expert medical benchmark creation assistant. |
|
Your goal is to generate questions that evaluate a multimodal medical AI agent's ability to interpret and reason about chest X-rays. |
|
""".strip() |
|
|
|
CATEGORIES_META = { |
|
"detection": "Identify and locate specific findings in the chest X-ray.", |
|
"classification": "Determine whether specific findings are present or absent in the chest X-ray.", |
|
"enumeration": "Count the number of target findings in the chest X-ray.", |
|
"localization": "Locate a given finding in the chest X-ray.", |
|
"comparison": "Compare the size or position of a specific finding in the chest X-ray.", |
|
"relationship": "Determine the relationship between two or more findings in the chest X-ray.", |
|
"diagnosis": "Make a diagnosis or determine a treatment plan by interpreting the chest X-ray.", |
|
"characterization": "Describe specific attributes (shape, density, margins, etc.) of findings.", |
|
"reasoning": "Explain the medical rationale and thought process behind findings and conclusions.", |
|
} |
|
CATEGORIES = list(CATEGORIES_META.keys()) |
|
|
|
CATEGORY_COMBINATIONS = [ |
|
["detection", "localization", "characterization", "reasoning"], |
|
["detection", "classification", "relationship", "reasoning"], |
|
["localization", "comparison", "relationship", "reasoning"], |
|
["classification", "comparison", "diagnosis", "reasoning"], |
|
["classification", "characterization", "diagnosis", "reasoning"], |
|
] |
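# Each combination pairs three perception-oriented categories with "reasoning";
# generate_questions() creates one question per combination, i.e. five per case.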
|
|
|
DEFAULT_SECTIONS = [ |
|
"history", |
|
"image_finding", |
|
"discussion", |
|
"differential_diagnosis", |
|
"diagnosis", |
|
"figures", |
|
] |
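# These section names match the keys of each Eurorad case dict read by
# Question.select_case_sections().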
|
|
|
|
|
class Question: |
|
"""A class to generate clinical questions from case data. |
|
|
|
This class handles creating structured clinical questions by combining case data with |
|
specified categories and difficulty levels. |
|
|
|
Attributes: |
|
type (str): The type of question (e.g. multiple choice) |
|
difficulty (str): Difficulty level of the question |
|
case_data (Dict[str, Any]): Dictionary containing the clinical case data |
|
case_content (str): Formatted case data from selected sections |
|
case_id (str): Unique identifier for the case |
|
categories (List[str]): List of analytical categories this question tests |
|
sections (List[str]): Case sections to include in question |
|
raw_content (Optional[str]): Raw LLM response to the question prompt |
|
content (Optional[Dict[str, str]]): Extracted content from the raw LLM response |
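
    Example (illustrative; assumes `client` is an OpenAI client and `case_data`
    is one loaded Eurorad case dict):
        question = Question(
            type="multiple choice (A/B/C/D/E/F)",
            difficulty="complex",
            case_data=case_data,
            categories=["detection", "localization", "characterization", "reasoning"],
        )
        question.create_question(client=client)
        question.save("benchmark/questions")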
|
""" |
|
|
|
def __init__( |
|
self, |
|
type: str, |
|
difficulty: str, |
|
case_data: Dict[str, Any], |
|
categories: List[str], |
|
        sections: Optional[List[str]] = None,
|
system_prompt: str = "You are an expert medical benchmark creation assistant.", |
|
) -> None: |
|
self.type = type |
|
self.difficulty = difficulty |
|
self.case_data = case_data |
|
self.case_id = case_data["case_id"] |
|
self.categories = categories |
|
        self.sections = sections if sections is not None else DEFAULT_SECTIONS
|
self.system_prompt = system_prompt |
|
self.case_content = self.select_case_sections() |
|
self.raw_content: Optional[str] = None |
|
self.content: Optional[Dict[str, str]] = None |
|
|
|
def create_question_prompt(self) -> str: |
|
"""Creates a formatted prompt for generating a clinical question. |
|
|
|
Returns: |
|
str: A structured prompt containing the question parameters and clinical data |
|
""" |
|
category_descriptions = "\n".join( |
|
f"{category}: {desc}" |
|
for category, desc in CATEGORIES_META.items() |
|
if category in self.categories |
|
) |
|
|
|
        prompt = f"""
|
You must follow these guidelines: |
|
1. Questions must be answerable using only context and chest X-rays. |
|
- Questions must explicitly mention the referenced figures |
|
- Questions can only reference the chest X-ray figures |
|
|
|
2. Questions must have unambiguous, verifiable answers, and should: |
|
- Challenge the agent's analytical capabilities |
|
- Require multi-step reasoning |
|
- Test ability to make precise observations |
|
- Evaluate capability to derive insights and findings from the chest X-ray |
|
|
|
3. The agent has access to tools like classification, report generation, segmentation, grounding, visual question answering, etc. Your question should be complex enough to require the use of such tools.
|
|
|
|
|
Create a {self.difficulty} {self.type} clinical question that integrates the following: |
|
|
|
{category_descriptions} |
|
|
|
based on the following clinical case: |
|
|
|
{self.case_content} |
|
|
|
Do not use any information derived from the CT or MRI images. Do not provide any information or findings about the chest X-rays.
|
Your question should require the agent to derive insights and findings from the chest X-ray by itself. |
|
Your answer should be verifiable directly in the context of the case. |
|
You can only use the image findings that come from the chest X-ray figures. |
|
|
|
Your response must follow this exact format: |
|
THOUGHTS: [Think about different reasoning steps and tools the agent should use to answer the question] |
|
QUESTION: [complete question with relevant context. Incorrect choices should be very close to the correct answer.] |
|
FIGURES: [list of required figures, e.g. ["Figure 1", "Figure 2a"]] |
|
EXPLANATION: [short explanation of why your answer is verifiable in the case] |
|
ANSWER: [correct answer e.g. "A"]
        """
        # Strip leading/trailing whitespace from each line of the prompt template
        # without collapsing the spaces inside the prompt text itself.
        return "\n".join(line.strip() for line in prompt.splitlines()).strip()
|
|
|
def select_case_sections(self) -> str: |
|
"""Extract and format selected sections from case data into paragraphs. |
|
|
|
Returns: |
|
str: Formatted string with case sections and content |
|
""" |
|
section_mapping = { |
|
"history": ("history", "No history provided."), |
|
"image_finding": ("image_finding", "No findings provided."), |
|
"discussion": ("discussion", "No discussion provided."), |
|
"differential_diagnosis": ( |
|
"differential_diagnosis", |
|
"No differential diagnosis provided.", |
|
), |
|
"diagnosis": ("diagnosis", "No diagnosis provided."), |
|
"figures": ("figures", "No figures provided."), |
|
} |
|
|
|
formatted = [] |
|
for section in self.sections: |
|
if section in section_mapping: |
|
key, default = section_mapping[section] |
|
content = self.case_data.get(key, default) |
|
|
|
if key == "figures": |
|
figures_text = [] |
|
for figure in content: |
|
for subfig in figure["subfigures"]: |
|
figures_text.append(f"{subfig['number']}: {subfig['caption']}") |
|
content = "\n".join(figures_text) |
|
|
|
formatted.append(f"{section}:\n{content}") |
|
|
|
return "\n\n".join(formatted) |
|
|
|
def create_question( |
|
self, |
|
client: openai.OpenAI, |
|
temperature: float = 0.7, |
|
top_p: float = 0.95, |
|
max_tokens: int = 500, |
|
model: str = "gpt-4o", |
|
) -> str: |
|
"""Create a clinical question using LLM. |
|
|
|
Args: |
|
client (openai.OpenAI): OpenAI client instance |
|
temperature (float): Controls randomness in responses. Defaults to 0.7. |
|
top_p (float): Controls diversity via nucleus sampling. Defaults to 0.95. |
|
max_tokens (int): Max tokens in model response. Defaults to 500. |
|
model (str): OpenAI model to use. Defaults to "gpt-4o". |
|
|
|
Returns: |
|
str: LLM response containing formatted question components |
|
""" |
|
self.raw_content = get_llm_response( |
|
client=client, |
|
prompt=self.create_question_prompt(), |
|
system_prompt=self.system_prompt, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=max_tokens, |
|
model=model, |
|
) |
|
self.content = self.extract_content() |
|
|
|
return self.raw_content |
|
|
|
def extract_content(self) -> Dict[str, str]: |
|
"""Extract sections from raw LLM response using regex patterns. |
|
|
|
Returns: |
|
Dict[str, str]: Extracted sections including thoughts, question, figures, explanation, and answer |
|
""" |
|
keywords = ["THOUGHTS", "QUESTION", "FIGURES", "EXPLANATION", "ANSWER"] |
|
|
|
content = {} |
|
for kw in keywords: |
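            # Capture everything after "<KEYWORD>:" up to the next all-caps header or end of text.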
|
pattern = rf"{kw}:\s*(.*?)(?=\n[A-Z]+:|$)" |
|
match = re.search(pattern, self.raw_content, re.DOTALL) |
|
content[kw.lower()] = match.group(1).strip() if match else None |
|
|
|
return content |
|
|
|
def save(self, output_path: str) -> Dict[str, Any]: |
|
"""Save question content and metadata as a JSON file. |
|
|
|
Args: |
|
output_path (str): Directory path where the JSON file will be saved |
|
|
|
Returns: |
|
            Dict[str, Any]: Question data including content (thoughts, question, figures,
                explanation, answer) and metadata (type, difficulty, categories, etc.)
|
""" |
|
question_metadata = self.content.copy() |
|
|
|
|
|
question_metadata["metadata"] = { |
|
"case_id": self.case_id, |
|
"type": self.type, |
|
"difficulty": self.difficulty, |
|
"categories": self.categories, |
|
"sections": self.sections, |
|
} |
|
|
|
|
|
case_dir = os.path.join(output_path, str(self.case_id)) |
|
os.makedirs(case_dir, exist_ok=True) |
|
|
|
|
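        # The object hash in the filename separates the multiple questions saved per case
        # (note: Python's default object hash is not stable across runs).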
|
output_file = os.path.join(case_dir, f"{self.case_id}_{self.__hash__()}.json") |
|
with open(output_file, "w") as f: |
|
json.dump(question_metadata, f, indent=2) |
|
|
|
return question_metadata |
|
|
|
|
|
def generate_questions( |
|
dataset: Dict[str, Any], |
|
client: openai.OpenAI, |
|
output_dir: str, |
|
skip_first: int = 100, |
|
temperature: float = 0.7, |
|
top_p: float = 0.95, |
|
max_tokens: int = 1200, |
|
model: str = "gpt-4o", |
|
) -> None: |
|
"""Generate questions for each case and category combination. |
|
|
|
Args: |
|
dataset: Dictionary of case data |
|
client: OpenAI client instance |
|
output_dir: Directory to save generated questions |
|
skip_first: Number of initial cases to skip |
|
temperature: LLM temperature parameter |
|
top_p: LLM top_p parameter |
|
max_tokens: Maximum tokens for LLM response |
|
model: LLM model name |
|
""" |
|
    # Process cases in ascending numeric order, skipping the first `skip_first` cases.
    target_cases = sorted(dataset.keys(), key=int)[skip_first:]
|
|
|
for case_id in tqdm(target_cases, desc="Processing cases"): |
|
case_data = dataset[case_id] |
|
|
|
        for categories in tqdm(CATEGORY_COMBINATIONS, desc=f"Categories for case {case_id}"):
|
question = Question( |
|
type="multiple choice (A/B/C/D/E/F)", |
|
difficulty="complex", |
|
case_data=case_data, |
|
                categories=categories,
|
sections=DEFAULT_SECTIONS, |
|
system_prompt=SYSTEM_PROMPT, |
|
) |
|
|
|
            question.create_question(
|
client=client, |
|
temperature=temperature, |
|
top_p=top_p, |
|
max_tokens=max_tokens, |
|
model=model, |
|
) |
|
question.save(output_dir) |
|
|
|
|
|
def main(): |
|
"""Main execution function.""" |
|
client = openai.OpenAI() |
|
|
|
|
|
dataset = load_eurorad_dataset( |
|
DATASET_PATH, |
|
section="Chest Imaging", |
|
as_dict=True, |
|
filter_by_caption=[ |
|
"xray", |
|
"x-ray", |
|
"x ray", |
|
"ray", |
|
"xr", |
|
"radiograph", |
|
], |
|
) |
|
print(f"\n---\nFound {len(dataset)} cases with X-ray mentions\n---\n") |
|
|
|
|
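    # Preview a single case (assumed to be present after filtering) as a sanity check.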
|
case_data = dataset["16798"] |
|
pprint(case_data, sort_dicts=False) |
|
|
|
|
|
generate_questions(dataset=dataset, client=client, output_dir="benchmark/questions") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|