Spaces:
Running
Running
Upload 2 files
Browse filesUpload notebook and requirements.txt
- ClinicalTrialsTools_LLM.ipynb +575 -0
- requirements.txt +4 -0
ClinicalTrialsTools_LLM.ipynb
ADDED
@@ -0,0 +1,575 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 84,
|
6 |
+
"id": "1ba926d0-6852-4d81-a2ac-07303d804ead",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": [
|
10 |
+
"# imports\n",
|
11 |
+
"\n",
|
12 |
+
"import os\n",
|
13 |
+
"import json\n",
|
14 |
+
"from dotenv import load_dotenv\n",
|
15 |
+
"from openai import OpenAI\n",
|
16 |
+
"import gradio as gr\n",
|
17 |
+
"import requests\n",
|
18 |
+
"import urllib.parse \n",
|
19 |
+
"from datetime import datetime"
|
20 |
+
]
|
21 |
+
},
|
22 |
+
{
|
23 |
+
"cell_type": "code",
|
24 |
+
"execution_count": 85,
|
25 |
+
"id": "e14ec800-d43d-4f52-8c8a-b31d1a814a1e",
|
26 |
+
"metadata": {},
|
27 |
+
"outputs": [
|
28 |
+
{
|
29 |
+
"name": "stdout",
|
30 |
+
"output_type": "stream",
|
31 |
+
"text": [
|
32 |
+
"OpenAI API Key exists and begins sk-proj-\n"
|
33 |
+
]
|
34 |
+
}
|
35 |
+
],
|
36 |
+
"source": [
|
37 |
+
"# Initialization\n",
|
38 |
+
"\n",
|
39 |
+
"load_dotenv(override=True)\n",
|
40 |
+
"\n",
|
41 |
+
"openai_api_key = os.getenv('OPENAI_API_KEY')\n",
|
42 |
+
"if openai_api_key:\n",
|
43 |
+
" print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n",
|
44 |
+
"else:\n",
|
45 |
+
" print(\"OpenAI API Key not set\")\n",
|
46 |
+
" \n",
|
47 |
+
"MODEL = \"gpt-4o-mini\"\n",
|
48 |
+
"openai = OpenAI()"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 86,
|
54 |
+
"id": "f4e5fc9e-e11c-4796-945d-f707be42fcaf",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [],
|
57 |
+
"source": [
|
58 |
+
"system_message = (\n",
|
59 |
+
" \"You are a clinical trials assistant that uses the ClinicalTrials.gov API to answer questions. \"\n",
|
60 |
+
" \"If the user asks for study data, you must always call the `search_studies` tool. \"\n",
|
61 |
+
" \"Only use your own knowledge for greetings or general questions.\"\n",
|
62 |
+
" \"Always return detailed and structured responses.\"\n",
|
63 |
+
")"
|
64 |
+
]
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"cell_type": "code",
|
68 |
+
"execution_count": 87,
|
69 |
+
"id": "45b18caa-9c96-4708-83ff-c7552e46c438",
|
70 |
+
"metadata": {},
|
71 |
+
"outputs": [],
|
72 |
+
"source": [
|
73 |
+
"def search_studies(query): \n",
|
74 |
+
" import urllib.parse\n",
|
75 |
+
" import requests\n",
|
76 |
+
"\n",
|
77 |
+
" print(f\"[DEBUG] Received query: {query} (type: {type(query)})\")\n",
|
78 |
+
" \n",
|
79 |
+
" if isinstance(query, dict):\n",
|
80 |
+
" filters = []\n",
|
81 |
+
" query_parts = []\n",
|
82 |
+
"\n",
|
83 |
+
" # ✅ Condition query\n",
|
84 |
+
" if \"query\" in query:\n",
|
85 |
+
" query_string = urllib.parse.quote_plus(query[\"query\"])\n",
|
86 |
+
" query_parts.append(f\"query.cond={query_string}\")\n",
|
87 |
+
"\n",
|
88 |
+
" # ✅ Structured filters (excluding date filters)\n",
|
89 |
+
" if \"phase\" in query:\n",
|
90 |
+
" filters.append(f\"AREA[Phase]{query['phase']}\")\n",
|
91 |
+
" if \"country\" in query:\n",
|
92 |
+
" filters.append(f\"AREA[LocationCountry]{query['country']}\")\n",
|
93 |
+
" if \"study_type\" in query:\n",
|
94 |
+
" filters.append(f\"AREA[StudyType]{query['study_type']}\")\n",
|
95 |
+
" if \"sex\" in query:\n",
|
96 |
+
" filters.append(f\"AREA[Sex]{query['sex']}\")\n",
|
97 |
+
" if \"age_group\" in query:\n",
|
98 |
+
" filters.append(f\"AREA[StdAge]{query['age_group']}\")\n",
|
99 |
+
" if \"status\" in query:\n",
|
100 |
+
" filters.append(f\"AREA[OverallStatus]{query['status']}\")\n",
|
101 |
+
" if \"sampling_method\" in query:\n",
|
102 |
+
" filters.append(f\"AREA[SamplingMethod]{query['sampling_method']}\")\n",
|
103 |
+
" if \"ipd_sharing\" in query:\n",
|
104 |
+
" filters.append(f\"AREA[IPDSharing]{query['ipd_sharing']}\")\n",
|
105 |
+
" if \"expanded_access\" in query:\n",
|
106 |
+
" filters.append(f\"AREA[ExpandedAccess]{query['expanded_access']}\")\n",
|
107 |
+
"\n",
|
108 |
+
" page_size = query.get(\"max_results\", 3)\n",
|
109 |
+
"\n",
|
110 |
+
" # ❗ Remove date filters from API query\n",
|
111 |
+
" filter_advanced = \" AND \".join(filters)\n",
|
112 |
+
" if filter_advanced:\n",
|
113 |
+
" filter_advanced = f\"({filter_advanced})\"\n",
|
114 |
+
" encoded_filter = urllib.parse.quote(filter_advanced, safe=\"[]*\")\n",
|
115 |
+
"\n",
|
116 |
+
" url = (\n",
|
117 |
+
" f\"https://clinicaltrials.gov/api/v2/studies?\"\n",
|
118 |
+
" f\"{'&'.join(query_parts)}\"\n",
|
119 |
+
" f\"{f'&filter.advanced={encoded_filter}' if filter_advanced else ''}\"\n",
|
120 |
+
" f\"&pageSize={page_size}\"\n",
|
121 |
+
" )\n",
|
122 |
+
" \n",
|
123 |
+
" else:\n",
|
124 |
+
" encoded_query = urllib.parse.quote_plus(query)\n",
|
125 |
+
" url = f\"https://clinicaltrials.gov/api/v2/studies?query.cond={encoded_query}&pageSize=3\"\n",
|
126 |
+
"\n",
|
127 |
+
" print(\"Requesting:\", url)\n",
|
128 |
+
"\n",
|
129 |
+
" try:\n",
|
130 |
+
" response = requests.get(url)\n",
|
131 |
+
" print(\"Status Code:\", response.status_code)\n",
|
132 |
+
"\n",
|
133 |
+
" # 🔁 Fallback if query.cond fails\n",
|
134 |
+
" if response.status_code == 400 and any(\"query.cond=\" in part for part in query_parts):\n",
|
135 |
+
" print(\"[⚠️ Fallback] Retrying with query.text instead of query.cond...\")\n",
|
136 |
+
" query_parts = [p.replace(\"query.cond=\", \"query.text=\") for p in query_parts]\n",
|
137 |
+
" url = (\n",
|
138 |
+
" f\"https://clinicaltrials.gov/api/v2/studies?\"\n",
|
139 |
+
" f\"{'&'.join(query_parts)}\"\n",
|
140 |
+
" f\"{f'&filter.advanced={encoded_filter}' if filter_advanced else ''}\"\n",
|
141 |
+
" f\"&pageSize={page_size}\"\n",
|
142 |
+
" )\n",
|
143 |
+
" print(\"Retrying:\", url)\n",
|
144 |
+
" response = requests.get(url)\n",
|
145 |
+
" print(\"Retry Status Code:\", response.status_code)\n",
|
146 |
+
"\n",
|
147 |
+
" if response.status_code == 200:\n",
|
148 |
+
" data = response.json()\n",
|
149 |
+
" trials = data.get(\"studies\", [])\n",
|
150 |
+
" if not trials:\n",
|
151 |
+
" return \"No studies found.\"\n",
|
152 |
+
"\n",
|
153 |
+
" # ✅ Post-filter\n",
|
154 |
+
" def matches(study, key, value):\n",
|
155 |
+
" section = study.get(\"protocolSection\", {})\n",
|
156 |
+
" if key == \"sponsor\":\n",
|
157 |
+
" return value.lower() in section.get(\"sponsorCollaboratorsModule\", {}).get(\"leadSponsor\", {}).get(\"name\", \"\").lower()\n",
|
158 |
+
" elif key == \"intervention\":\n",
|
159 |
+
" interventions = section.get(\"armsInterventionsModule\", {}).get(\"interventions\", [])\n",
|
160 |
+
" return any(value.lower() in i.get(\"name\", \"\").lower() for i in interventions)\n",
|
161 |
+
" return True\n",
|
162 |
+
"\n",
|
163 |
+
" if \"sponsor\" in query:\n",
|
164 |
+
" trials = [s for s in trials if matches(s, \"sponsor\", query[\"sponsor\"])]\n",
|
165 |
+
" if \"intervention\" in query:\n",
|
166 |
+
" trials = [s for s in trials if matches(s, \"intervention\", query[\"intervention\"])]\n",
|
167 |
+
"\n",
|
168 |
+
" if not trials:\n",
|
169 |
+
" return \"No studies found after applying filters.\"\n",
|
170 |
+
"\n",
|
171 |
+
" result = []\n",
|
172 |
+
" for study in trials:\n",
|
173 |
+
" ps = study.get(\"protocolSection\", {})\n",
|
174 |
+
" id_module = ps.get(\"identificationModule\", {})\n",
|
175 |
+
" design_module = ps.get(\"designModule\", {})\n",
|
176 |
+
" status_module = ps.get(\"statusModule\", {})\n",
|
177 |
+
" elig_module = ps.get(\"eligibilityModule\", {})\n",
|
178 |
+
" ipd_module = ps.get(\"ipdSharingStatementModule\", {})\n",
|
179 |
+
" desc_module = ps.get(\"descriptionModule\", {})\n",
|
180 |
+
" contact_module = ps.get(\"contactsLocationsModule\", {})\n",
|
181 |
+
" sponsor_module = ps.get(\"sponsorCollaboratorsModule\", {})\n",
|
182 |
+
" outcomes_module = ps.get(\"outcomesModule\", {})\n",
|
183 |
+
" arms_module = ps.get(\"armsInterventionsModule\", {})\n",
|
184 |
+
"\n",
|
185 |
+
" nct_id = id_module.get(\"nctId\", \"N/A\")\n",
|
186 |
+
" title = id_module.get(\"briefTitle\", \"No Title\")\n",
|
187 |
+
" official_title = id_module.get(\"officialTitle\", \"N/A\")\n",
|
188 |
+
" phases = design_module.get(\"phases\", [])\n",
|
189 |
+
" study_type = design_module.get(\"studyType\", \"N/A\")\n",
|
190 |
+
" status = status_module.get(\"overallStatus\", \"N/A\")\n",
|
191 |
+
" ipd_sharing = ipd_module.get(\"ipdSharing\", \"N/A\")\n",
|
192 |
+
" expanded_access = status_module.get(\"expandedAccessInfo\", {}).get(\"hasExpandedAccess\", \"N/A\")\n",
|
193 |
+
" sex = elig_module.get(\"sex\", \"N/A\")\n",
|
194 |
+
" std_ages = elig_module.get(\"stdAges\", [])\n",
|
195 |
+
" age_range = \", \".join(std_ages) if std_ages else \"N/A\"\n",
|
196 |
+
" sampling_method = elig_module.get(\"samplingMethod\", \"N/A\")\n",
|
197 |
+
" criteria = elig_module.get(\"eligibilityCriteria\", \"N/A\")\n",
|
198 |
+
" start_date = status_module.get(\"startDateStruct\", {}).get(\"date\", \"N/A\")\n",
|
199 |
+
" completion_date = status_module.get(\"completionDateStruct\", {}).get(\"date\", \"N/A\")\n",
|
200 |
+
" locations = contact_module.get(\"locations\", [])\n",
|
201 |
+
" countries = sorted({loc.get(\"country\") for loc in locations if loc.get(\"country\")})\n",
|
202 |
+
"\n",
|
203 |
+
" location_lines = []\n",
|
204 |
+
" for loc in locations:\n",
|
205 |
+
" parts = [loc.get(\"facility\"), loc.get(\"city\"), loc.get(\"state\"), loc.get(\"country\")]\n",
|
206 |
+
" clean = [p for p in parts if p]\n",
|
207 |
+
" if clean:\n",
|
208 |
+
" location_lines.append(\", \".join(clean))\n",
|
209 |
+
" locations_text = \"\\n\".join(f\"- {line}\" for line in location_lines) if location_lines else \"N/A\"\n",
|
210 |
+
"\n",
|
211 |
+
" description = desc_module.get(\"detailedDescription\", \"N/A\")\n",
|
212 |
+
" interventions = arms_module.get(\"interventions\", [])\n",
|
213 |
+
" intervention_names = [iv.get(\"name\", \"\") for iv in interventions if iv.get(\"name\")]\n",
|
214 |
+
" intervention_text = \", \".join(intervention_names) if intervention_names else \"N/A\"\n",
|
215 |
+
" sponsor = sponsor_module.get(\"leadSponsor\", {}).get(\"name\", \"N/A\")\n",
|
216 |
+
" collaborators = sponsor_module.get(\"collaborators\", [])\n",
|
217 |
+
" collaborator_names = [c.get(\"name\", \"\") for c in collaborators]\n",
|
218 |
+
"\n",
|
219 |
+
" def format_outcomes(lst, label):\n",
|
220 |
+
" out = []\n",
|
221 |
+
" for o in lst:\n",
|
222 |
+
" out.append(f\"- **{o.get('measure')}** ({o.get('timeFrame', 'N/A')}): {o.get('description', '')}\")\n",
|
223 |
+
" return f\"\\n**{label}:**\\n\" + \"\\n\".join(out) if out else \"\"\n",
|
224 |
+
"\n",
|
225 |
+
" outcomes = (\n",
|
226 |
+
" format_outcomes(outcomes_module.get(\"primaryOutcomes\", []), \"Primary Outcomes\") +\n",
|
227 |
+
" format_outcomes(outcomes_module.get(\"secondaryOutcomes\", []), \"Secondary Outcomes\") +\n",
|
228 |
+
" format_outcomes(outcomes_module.get(\"otherOutcomes\", []), \"Other Outcomes\")\n",
|
229 |
+
" )\n",
|
230 |
+
"\n",
|
231 |
+
" arms = []\n",
|
232 |
+
" for group in arms_module.get(\"armGroups\", []):\n",
|
233 |
+
" label = group.get(\"label\", \"N/A\")\n",
|
234 |
+
" gtype = group.get(\"type\", \"N/A\")\n",
|
235 |
+
" desc = group.get(\"description\")\n",
|
236 |
+
" arms.append(f\"- **{label}** ({gtype}): {desc.strip()}\" if desc else f\"- **{label}** ({gtype})\")\n",
|
237 |
+
"\n",
|
238 |
+
" arms_text = \"**Arms & Groups:**\\n\" + \"\\n\".join(arms) if arms else \"\"\n",
|
239 |
+
" phase_text = ', '.join(phases) if study_type.upper() == \"INTERVENTIONAL\" else \"Not applicable (Observational study)\"\n",
|
240 |
+
" ctgov_link = f\"https://clinicaltrials.gov/study/{nct_id}\"\n",
|
241 |
+
" description_block = f\"**Detailed Description:**\\n{description.strip()}\\n\\n\" if description and description != \"N/A\" else \"\"\n",
|
242 |
+
"\n",
|
243 |
+
" result.append(\n",
|
244 |
+
" f\"### 🧪 {title}\\n\\n\"\n",
|
245 |
+
" f\"**NCT ID:** `{nct_id}`\\n\"\n",
|
246 |
+
" f\"🔗 [View on ClinicalTrials.gov]({ctgov_link})\\n\\n\"\n",
|
247 |
+
" f\"**Start Date (for filtering):** {start_date}\\n\"\n",
|
248 |
+
" f\"**Completion Date (for filtering):** {completion_date}\\n\\n\"\n",
|
249 |
+
" f\"**Official Title:** {official_title}\\n\"\n",
|
250 |
+
" f\"**Type:** {study_type.title()}\\n\"\n",
|
251 |
+
" f\"**Phase:** {phase_text}\\n\"\n",
|
252 |
+
" f\"**Status:** {status}\\n\"\n",
|
253 |
+
" f\"**Country:** {', '.join(countries) if countries else 'N/A'}\\n\"\n",
|
254 |
+
" f\"**Interventions:** {intervention_text}\\n\"\n",
|
255 |
+
" f\"**Sponsor:** {sponsor}\\n\"\n",
|
256 |
+
" f\"**Collaborators:** {', '.join(collaborator_names) if collaborator_names else 'None'}\\n\\n\"\n",
|
257 |
+
" )\n",
|
258 |
+
"\n",
|
259 |
+
" return \"\\n\\n---\\n\\n\".join(result).strip()\n",
|
260 |
+
"\n",
|
261 |
+
" return f\"API returned error: {response.status_code}\"\n",
|
262 |
+
"\n",
|
263 |
+
" except Exception as e:\n",
|
264 |
+
" print(\"Exception occurred:\", e)\n",
|
265 |
+
" return \"Error fetching study data.\""
|
266 |
+
]
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"cell_type": "code",
|
270 |
+
"execution_count": 88,
|
271 |
+
"id": "8d3a9ddb-36cd-437f-87d9-295397d62a53",
|
272 |
+
"metadata": {},
|
273 |
+
"outputs": [],
|
274 |
+
"source": [
|
275 |
+
"# There's a particular dictionary structure that's required to describe our function:\n",
|
276 |
+
"\n",
|
277 |
+
"search_function = {\n",
|
278 |
+
" \"name\": \"search_studies\",\n",
|
279 |
+
" \"description\": \"Search for clinical trials with strict filtering on all key metadata fields such as condition, country, phase, study type, sex, age group, sampling method, sponsor, collaborators, intervention, start dates, completion dates, etc.\",\n",
|
280 |
+
" \"parameters\": {\n",
|
281 |
+
" \"type\": \"object\",\n",
|
282 |
+
" \"properties\": {\n",
|
283 |
+
" \"query\": {\n",
|
284 |
+
" \"type\": \"string\",\n",
|
285 |
+
" \"description\": \"Condition or keyword to search for. (e.g., 'lung cancer', 'IBD')\"\n",
|
286 |
+
" },\n",
|
287 |
+
" \"phase\": {\n",
|
288 |
+
" \"type\": \"string\",\n",
|
289 |
+
" \"description\": \"Clinical trial phase. (e.g., 'Phase 1', 'Phase 2', 'Phase 3')\"\n",
|
290 |
+
" },\n",
|
291 |
+
" \"status\": {\n",
|
292 |
+
" \"type\": \"string\",\n",
|
293 |
+
" \"description\": \"Recruitment status. (e.g., 'RECRUITING', 'COMPLETED')\"\n",
|
294 |
+
" },\n",
|
295 |
+
" \"country\": {\n",
|
296 |
+
" \"type\": \"string\",\n",
|
297 |
+
" \"description\": \"Country where the trial is conducted. (e.g., 'Italy')\"\n",
|
298 |
+
" },\n",
|
299 |
+
" \"study_type\": {\n",
|
300 |
+
" \"type\": \"string\",\n",
|
301 |
+
" \"description\": \"Type of study. (e.g., 'INTERVENTIONAL', 'OBSERVATIONAL')\"\n",
|
302 |
+
" },\n",
|
303 |
+
" \"sex\": {\n",
|
304 |
+
" \"type\": \"string\",\n",
|
305 |
+
" \"description\": \"Sex eligibility. (e.g., 'Male', 'Female', 'All')\"\n",
|
306 |
+
" },\n",
|
307 |
+
" \"age_group\": {\n",
|
308 |
+
" \"type\": \"string\",\n",
|
309 |
+
" \"description\": \"Standard age group. (e.g., 'CHILD', 'ADULT', 'OLDER_ADULT')\"\n",
|
310 |
+
" },\n",
|
311 |
+
" \"sampling_method\": {\n",
|
312 |
+
" \"type\": \"string\",\n",
|
313 |
+
" \"description\": \"Participant sampling method. (e.g., 'PROBABILITY_SAMPLE', 'NON_PROBABILITY_SAMPLE')\"\n",
|
314 |
+
" },\n",
|
315 |
+
" \"intervention\": {\n",
|
316 |
+
" \"type\": \"string\",\n",
|
317 |
+
" \"description\": \"Intervention or treatment keyword. (e.g., 'aspirin', 'TAE')\"\n",
|
318 |
+
" },\n",
|
319 |
+
" \"sponsor\": {\n",
|
320 |
+
" \"type\": \"string\",\n",
|
321 |
+
" \"description\": \"Name of the lead sponsor or organization. (e.g., 'Pfizer', 'NIH')\"\n",
|
322 |
+
" },\n",
|
323 |
+
" \"ipd_sharing\": {\n",
|
324 |
+
" \"type\": \"string\",\n",
|
325 |
+
" \"description\": \"Will individual participant data (IPD) be shared? (e.g., 'YES', 'NO', 'UND')\"\n",
|
326 |
+
" },\n",
|
327 |
+
" \"expanded_access\": {\n",
|
328 |
+
" \"type\": \"string\",\n",
|
329 |
+
" \"description\": \"Whether expanded access is available. (e.g., 'YES', 'NO', 'UNKNOWN')\"\n",
|
330 |
+
" },\n",
|
331 |
+
" \"start_date_from\": {\n",
|
332 |
+
" \"type\": \"string\",\n",
|
333 |
+
" \"description\": \"Earliest start date allowed (format: YYYY-MM or YYYY-MM-DD)\"\n",
|
334 |
+
" },\n",
|
335 |
+
" \"start_date_to\": {\n",
|
336 |
+
" \"type\": \"string\",\n",
|
337 |
+
" \"description\": \"Latest start date allowed\"\n",
|
338 |
+
" },\n",
|
339 |
+
" \"completion_date_from\": {\n",
|
340 |
+
" \"type\": \"string\",\n",
|
341 |
+
" \"description\": \"Earliest completion date allowed\"\n",
|
342 |
+
" },\n",
|
343 |
+
" \"completion_date_to\": {\n",
|
344 |
+
" \"type\": \"string\",\n",
|
345 |
+
" \"description\": \"Latest completion date allowed\"\n",
|
346 |
+
" },\n",
|
347 |
+
" \"max_results\": {\n",
|
348 |
+
" \"type\": \"integer\",\n",
|
349 |
+
" \"description\": \"Maximum number of studies to return\"\n",
|
350 |
+
" }\n",
|
351 |
+
" },\n",
|
352 |
+
" \"required\": [\"query\"],\n",
|
353 |
+
" \"additionalProperties\": False\n",
|
354 |
+
" }\n",
|
355 |
+
"}"
|
356 |
+
]
|
357 |
+
},
|
358 |
+
{
|
359 |
+
"cell_type": "code",
|
360 |
+
"execution_count": 89,
|
361 |
+
"id": "657037fe-ba29-4ee4-b0e3-7eff225189b2",
|
362 |
+
"metadata": {},
|
363 |
+
"outputs": [],
|
364 |
+
"source": [
|
365 |
+
"# And this is included in a list of tools:\n",
|
366 |
+
"\n",
|
367 |
+
"tools = [{\"type\": \"function\", \"function\": search_function}]"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": 90,
|
373 |
+
"id": "f6fdb4aa-a9a8-41f9-b650-0419047816f5",
|
374 |
+
"metadata": {},
|
375 |
+
"outputs": [],
|
376 |
+
"source": [
|
377 |
+
"def chat(message, history):\n",
|
378 |
+
" messages = [{\"role\": \"system\", \"content\": system_message}] + history + [{\"role\": \"user\", \"content\": message}]\n",
|
379 |
+
"\n",
|
380 |
+
" # 🔄 First attempt: try to stream the LLM output\n",
|
381 |
+
" response_stream = openai.chat.completions.create(\n",
|
382 |
+
" model=MODEL,\n",
|
383 |
+
" messages=messages,\n",
|
384 |
+
" tools=tools,\n",
|
385 |
+
" tool_choice=\"auto\",\n",
|
386 |
+
" stream=True\n",
|
387 |
+
" )\n",
|
388 |
+
"\n",
|
389 |
+
" full_response = \"\"\n",
|
390 |
+
" tool_call_detected = False\n",
|
391 |
+
"\n",
|
392 |
+
" for chunk in response_stream:\n",
|
393 |
+
" choice = chunk.choices[0]\n",
|
394 |
+
" delta = choice.delta\n",
|
395 |
+
"\n",
|
396 |
+
" # 🧠 Detect tool call request during stream\n",
|
397 |
+
" if hasattr(delta, \"tool_calls\") and delta.tool_calls:\n",
|
398 |
+
" tool_call_detected = True\n",
|
399 |
+
" break # Exit streaming — can't continue past tool call\n",
|
400 |
+
"\n",
|
401 |
+
" if delta.content:\n",
|
402 |
+
" full_response += delta.content\n",
|
403 |
+
" yield full_response # Live stream to user\n",
|
404 |
+
"\n",
|
405 |
+
" # 🧰 Tool call fallback (non-streamed)\n",
|
406 |
+
" if tool_call_detected:\n",
|
407 |
+
" fallback = openai.chat.completions.create(\n",
|
408 |
+
" model=MODEL,\n",
|
409 |
+
" messages=messages,\n",
|
410 |
+
" tools=tools,\n",
|
411 |
+
" tool_choice=\"auto\" # No stream here, required to get tool_calls\n",
|
412 |
+
" )\n",
|
413 |
+
"\n",
|
414 |
+
" message = fallback.choices[0].message\n",
|
415 |
+
" print(\"Finish reason:\", fallback.choices[0].finish_reason)\n",
|
416 |
+
" print(\"Tool calls:\", message.tool_calls if hasattr(message, 'tool_calls') else None)\n",
|
417 |
+
"\n",
|
418 |
+
" # 🔧 Call the tool(s)\n",
|
419 |
+
" tool_responses = handle_tool_call(message)\n",
|
420 |
+
"\n",
|
421 |
+
" # Add the assistant tool call message and all corresponding tool responses\n",
|
422 |
+
" messages.append(message)\n",
|
423 |
+
" messages.extend(tool_responses)\n",
|
424 |
+
"\n",
|
425 |
+
" # 🧠 Now ask GPT to summarize the tool result(s)\n",
|
426 |
+
" final_response_stream = openai.chat.completions.create(\n",
|
427 |
+
" model=MODEL,\n",
|
428 |
+
" messages=messages,\n",
|
429 |
+
" stream=True\n",
|
430 |
+
" )\n",
|
431 |
+
"\n",
|
432 |
+
" final_output = \"\"\n",
|
433 |
+
" for chunk in final_response_stream:\n",
|
434 |
+
" delta = chunk.choices[0].delta\n",
|
435 |
+
" if delta.content:\n",
|
436 |
+
" final_output += delta.content\n",
|
437 |
+
" yield final_output # Stream final GPT summary\n",
|
438 |
+
"\n",
|
439 |
+
" # 🧯 Final fallback if nothing streamed\n",
|
440 |
+
" elif not full_response:\n",
|
441 |
+
" fallback = openai.chat.completions.create(\n",
|
442 |
+
" model=MODEL,\n",
|
443 |
+
" messages=messages,\n",
|
444 |
+
" tools=tools,\n",
|
445 |
+
" tool_choice=\"auto\"\n",
|
446 |
+
" )\n",
|
447 |
+
" yield fallback.choices[0].message.content"
|
448 |
+
]
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "code",
|
452 |
+
"execution_count": 91,
|
453 |
+
"id": "c5963527-a08f-4da4-997a-958fb4e09ab7",
|
454 |
+
"metadata": {},
|
455 |
+
"outputs": [],
|
456 |
+
"source": [
|
457 |
+
"def handle_tool_call(message):\n",
|
458 |
+
" import json\n",
|
459 |
+
"\n",
|
460 |
+
" tool_responses = []\n",
|
461 |
+
"\n",
|
462 |
+
" for tool_call in message.tool_calls:\n",
|
463 |
+
" arguments = json.loads(tool_call.function.arguments)\n",
|
464 |
+
" result = search_studies(arguments)\n",
|
465 |
+
"\n",
|
466 |
+
" tool_responses.append({\n",
|
467 |
+
" \"role\": \"tool\",\n",
|
468 |
+
" \"tool_call_id\": tool_call.id,\n",
|
469 |
+
" \"content\": result if isinstance(result, str) else json.dumps(result)\n",
|
470 |
+
" })\n",
|
471 |
+
"\n",
|
472 |
+
" return tool_responses"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"cell_type": "code",
|
477 |
+
"execution_count": 92,
|
478 |
+
"id": "e4529594-eae4-4042-be76-d56f26cb0467",
|
479 |
+
"metadata": {
|
480 |
+
"scrolled": true
|
481 |
+
},
|
482 |
+
"outputs": [
|
483 |
+
{
|
484 |
+
"name": "stdout",
|
485 |
+
"output_type": "stream",
|
486 |
+
"text": [
|
487 |
+
"* Running on local URL: http://127.0.0.1:7867\n",
|
488 |
+
"\n",
|
489 |
+
"To create a public link, set `share=True` in `launch()`.\n"
|
490 |
+
]
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"data": {
|
494 |
+
"text/html": [
|
495 |
+
"<div><iframe src=\"http://127.0.0.1:7867/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
|
496 |
+
],
|
497 |
+
"text/plain": [
|
498 |
+
"<IPython.core.display.HTML object>"
|
499 |
+
]
|
500 |
+
},
|
501 |
+
"metadata": {},
|
502 |
+
"output_type": "display_data"
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"data": {
|
506 |
+
"text/plain": []
|
507 |
+
},
|
508 |
+
"execution_count": 92,
|
509 |
+
"metadata": {},
|
510 |
+
"output_type": "execute_result"
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"name": "stdout",
|
514 |
+
"output_type": "stream",
|
515 |
+
"text": [
|
516 |
+
"Finish reason: tool_calls\n",
|
517 |
+
"Tool calls: [ChatCompletionMessageToolCall(id='call_OKLKZYud2q7uxJVw5NXn1TpH', function=Function(arguments='{\"query\":\"Asthma\",\"status\":\"RECRUITING\",\"start_date_from\":\"2022-01-01\",\"max_results\":10}', name='search_studies'), type='function')]\n",
|
518 |
+
"[DEBUG] Received query: {'query': 'Asthma', 'status': 'RECRUITING', 'start_date_from': '2022-01-01', 'max_results': 10} (type: <class 'dict'>)\n",
|
519 |
+
"Requesting: https://clinicaltrials.gov/api/v2/studies?query.cond=Asthma&filter.advanced=%28AREA[OverallStatus]RECRUITING%29&pageSize=10\n",
|
520 |
+
"Status Code: 200\n",
|
521 |
+
"Finish reason: tool_calls\n",
|
522 |
+
"Tool calls: [ChatCompletionMessageToolCall(id='call_5sqJJdsQTLkqh9FfziUkr7Up', function=Function(arguments='{\"query\":\"NCT06100289\"}', name='search_studies'), type='function')]\n",
|
523 |
+
"[DEBUG] Received query: {'query': 'NCT06100289'} (type: <class 'dict'>)\n",
|
524 |
+
"Requesting: https://clinicaltrials.gov/api/v2/studies?query.cond=NCT06100289&pageSize=3\n",
|
525 |
+
"Status Code: 200\n"
|
526 |
+
]
|
527 |
+
}
|
528 |
+
],
|
529 |
+
"source": [
|
530 |
+
"example_prompts = [\n",
|
531 |
+
" \"Show me trials in Taiwan studying Vedolizumab\",\n",
|
532 |
+
" \"List studies for Crohn's disease that started after 2015\",\n",
|
533 |
+
" \"Give me 5 completed trials on lung cancer in Japan\",\n",
|
534 |
+
" \"Find interventional Phase 3 studies for breast cancer in France\",\n",
|
535 |
+
" \"List observational studies in Asia with female participants over 65.\",\n",
|
536 |
+
" \"Get details for NCT06100289\",\n",
|
537 |
+
" \"Show studies that started after 2022 for Asthma and are still ongoing.\"\n",
|
538 |
+
"]\n",
|
539 |
+
"\n",
|
540 |
+
"gr.ChatInterface(\n",
|
541 |
+
" fn=chat,\n",
|
542 |
+
" type=\"messages\",\n",
|
543 |
+
" title=\"ClinicalTrials.gov Agent\",\n",
|
544 |
+
" description=(\n",
|
545 |
+
" \"Ask about medical conditions, NCT IDs, trial phases, study types, recruitment status, interventions, sponsors, age groups, sex, sampling methods, IPD sharing, expanded access, countries, locations, and date ranges (start/completion).\\n\\n\"\n",
|
546 |
+
" \"💡 You can also try one of the examples below to get started.\"\n",
|
547 |
+
" ),\n",
|
548 |
+
" chatbot=gr.Chatbot(label=\"CTGagent\", type=\"messages\"),\n",
|
549 |
+
" examples=example_prompts\n",
|
550 |
+
").launch(app_kwargs={\"title\": \"CTGagent\"})\n"
|
551 |
+
]
|
552 |
+
}
|
553 |
+
],
|
554 |
+
"metadata": {
|
555 |
+
"kernelspec": {
|
556 |
+
"display_name": "Python 3 (ipykernel)",
|
557 |
+
"language": "python",
|
558 |
+
"name": "python3"
|
559 |
+
},
|
560 |
+
"language_info": {
|
561 |
+
"codemirror_mode": {
|
562 |
+
"name": "ipython",
|
563 |
+
"version": 3
|
564 |
+
},
|
565 |
+
"file_extension": ".py",
|
566 |
+
"mimetype": "text/x-python",
|
567 |
+
"name": "python",
|
568 |
+
"nbconvert_exporter": "python",
|
569 |
+
"pygments_lexer": "ipython3",
|
570 |
+
"version": "3.11.11"
|
571 |
+
}
|
572 |
+
},
|
573 |
+
"nbformat": 4,
|
574 |
+
"nbformat_minor": 5
|
575 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
openai
|
2 |
+
python-dotenv
|
3 |
+
gradio
|
4 |
+
requests
|