# nasa_expert_agent/tools/nasa_neo_data_fetcher.py
# Author: hardknee
# Last commit: 320acdc (verified) - "Writing changes fixed locally"
import os
from calendar import monthrange
from typing import Generator, Optional

import requests
from dotenv import load_dotenv
from smolagents.tools import Tool
# Runs once at import time: loads variables from a local .env file (python-dotenv)
# into the process environment so that NASA_API_KEY can be read with os.getenv()
# when NasaNeoDataFetcher is constructed.
load_dotenv()
class NasaNeoDataFetcher(Tool):
    """smolagents tool that fetches Near Earth Object (NEO) data from NASA's NEO feed API.

    The feed endpoint returns an error for ranges longer than 7 days, so the
    requested range is split into chunks of at most 7 days (never crossing a
    month boundary) and one request is made per chunk.
    """

    name = "nasa_neo_data_fetcher"
    description = "Retrieves Neo data from the NASA API for the dates in your query then returns the available Neo data."
    inputs = {
        "start_date": {
            "type": "string",
            "description": "The start date of the range to query, formatted as YYYY-MM-DD.",
        },
        "end_date": {
            "type": "string",
            "description": "The end date of the range to query (inclusive), formatted as YYYY-MM-DD.",
        },
    }
    # NOTE(review): forward() returns a list, not a string; confirm that
    # "string" is the intended output_type for the agent framework.
    output_type = "string"

    def __init__(self):
        super().__init__()
        # Read from the process environment (which load_dotenv() may have
        # populated from a .env file at import time).
        self.api_key = os.getenv("NASA_API_KEY")
        self.root_url = "https://api.nasa.gov/neo/rest/v1/feed?"
        # Seconds to wait for each API request. Without a timeout, a stalled
        # connection would block the calling agent indefinitely.
        self.request_timeout = 30
        self.is_initialized = False

    def forward(self, start_date: str, end_date: str) -> list[tuple[str, dict]]:
        """Fetch Near Earth Object data from the NASA API for a date range.

        Args:
            start_date: First day of the range, formatted as YYYY-MM-DD.
            end_date: Last day of the range (inclusive), formatted as YYYY-MM-DD.

        Returns:
            A list of ``(date, value)`` tuples taken from the ``near_earth_objects``
            mapping of every successful API response. The value is whatever the API
            stores under that date key (presumably the list of NEO records for
            that day; confirm against the NeoWs docs).
        """
        return self._fetch_neo_data_in_chunks(start_date, end_date)

    def _get_nasa_neo_data(self, start_date: str, end_date: str) -> Optional[dict]:
        """Fetch a single date range (at most 7 days) from the NASA NEO feed endpoint.

        Args:
            start_date: First day of the range, formatted as YYYY-MM-DD.
            end_date: Last day of the range (inclusive), formatted as YYYY-MM-DD.

        Returns:
            The decoded JSON response as a dictionary, or ``None`` if the request
            failed (network error, timeout or non-200 status). Failures are
            reported with ``print`` rather than raised, so one bad chunk does not
            abort the remaining chunks.
        """
        # Let requests build and URL-encode the query string instead of
        # concatenating it by hand.
        params = {
            "start_date": start_date,
            "end_date": end_date,
            "api_key": self.api_key,
        }
        try:
            response = requests.get(
                self.root_url, params=params, timeout=self.request_timeout
            )
        except requests.RequestException as err:
            print("Error: ", err)
            return None
        if response.status_code != 200:
            print("Error: ", response.status_code)
            return None
        data = response.json()
        print(f"Successfully fetched data for {start_date} to {end_date}")
        return data

    def _split_date_range_into_chunks(
        self, year: int, month: int, first_day: int, last_day: int
    ) -> list:
        """Split days ``first_day..last_day`` of one month into chunks of at most 7 days.

        Args:
            year: Calendar year.
            month: Calendar month (1-12).
            first_day: First day of the month to include.
            last_day: Last day of the month to include (inclusive). It is clamped
                to the number of days in the month.

        Returns:
            A list of ``(start_date, end_date)`` string tuples formatted as
            YYYY-MM-DD, in chronological order. Empty when ``first_day > last_day``.
        """
        # Never ask the API for a day that does not exist (e.g. 2023-02-30).
        last_day = min(last_day, monthrange(year, month)[1])
        week_dates = []
        chunk_start = first_day
        while chunk_start <= last_day:
            # Truncate the final chunk at last_day so a chunk never ends before
            # it starts and never runs past the requested range.
            chunk_end = min(chunk_start + 6, last_day)
            week_dates.append(
                (
                    f"{year:04d}-{month:02d}-{chunk_start:02d}",
                    f"{year:04d}-{month:02d}-{chunk_end:02d}",
                )
            )
            chunk_start += 7
        return week_dates

    def _split_date_into_ymd(self, date: str) -> tuple[int, int, int]:
        """Parse a YYYY-MM-DD string into ``(year, month, day)`` integers.

        Raises:
            ValueError: If ``date`` does not have exactly three dash-separated
                integer parts.
        """
        parts = date.strip().split("-")
        if len(parts) != 3:
            raise ValueError(f"Expected a date formatted as YYYY-MM-DD, got {date!r}")
        year, month, day = (int(part) for part in parts)
        return year, month, day

    def _get_neo_data_chunks(
        self, *args: int
    ) -> Generator[Optional[dict], None, None]:
        """Fetch NEO data for part of one month, making one API request per 7-day chunk.

        Args:
            *args: ``(year, month, first_day, last_day)``, forwarded unchanged to
                ``_split_date_range_into_chunks``.

        Yields:
            The result of ``_get_nasa_neo_data`` for each chunk: a response
            dictionary, or ``None`` when that request failed.
        """
        for start_date, end_date in self._split_date_range_into_chunks(*args):
            yield self._get_nasa_neo_data(start_date=start_date, end_date=end_date)

    def _fetch_neo_data_in_chunks(
        self, start_date: str, end_date: str
    ) -> list[tuple[str, dict]]:
        """Fetch Near Earth Object data from the NASA API in chunks of at most 7 days.

        NB: The API returns an error for ranges longer than 7 days, so the range
        is walked month by month (rolling over year boundaries) and each month's
        slice is split into 7-day chunks.

        Args:
            start_date: First day of the range, formatted as YYYY-MM-DD.
            end_date: Last day of the range (inclusive), formatted as YYYY-MM-DD.

        Returns:
            A list of ``(date, value)`` tuples from the ``near_earth_objects``
            mapping of every successful response, in request order. Chunks whose
            request failed are skipped. The list is empty if ``start_date`` is
            after ``end_date``.
        """
        start_year, start_month, start_day = self._split_date_into_ymd(start_date)
        end_year, end_month, end_day = self._split_date_into_ymd(end_date)

        neo_data = []
        year, month = start_year, start_month
        # Tuple comparison orders (year, month) chronologically, so this visits
        # every month from the start month through the end month, across years.
        while (year, month) <= (end_year, end_month):
            is_first_month = (year, month) == (start_year, start_month)
            is_last_month = (year, month) == (end_year, end_month)
            first_day = start_day if is_first_month else 1
            last_day = end_day if is_last_month else monthrange(year, month)[1]

            for item in self._get_neo_data_chunks(year, month, first_day, last_day):
                # Failed requests yield None; skip them instead of crashing on `in`.
                if item and "near_earth_objects" in item:
                    neo_data.extend(item["near_earth_objects"].items())

            if month == 12:
                year, month = year + 1, 1
            else:
                month += 1
        return neo_data