Skip to content

Commit b0aa389

Browse files
thodson-usgs and claude committed
Validate monitoring_location_id format in waterdata functions
Passing an integer (e.g. 5129115) or a bare string without an agency prefix (e.g. "dog") to any waterdata function silently wasted an API call and returned empty data. Now all ten public functions that accept monitoring_location_id raise before touching the network:

- TypeError if the value is not a string or list of strings
- ValueError if any string doesn't match the 'AGENCY-ID' format (e.g. 'USGS-01646500')

Closes #188.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent c4d0f84 commit b0aa389

3 files changed

Lines changed: 112 additions & 1 deletion

File tree

dataretrieval/waterdata/api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
)
2525
from dataretrieval.waterdata.utils import (
2626
SAMPLES_URL,
27+
_check_monitoring_location_id,
2728
_check_profiles,
2829
_default_headers,
2930
_get_args,
@@ -205,6 +206,7 @@ def get_daily(
205206
... approval_status = "Approved",
206207
... time = "2024-01-01/.."
207208
"""
209+
_check_monitoring_location_id(monitoring_location_id)
208210
service = "daily"
209211
output_id = "daily_id"
210212

@@ -371,6 +373,7 @@ def get_continuous(
371373
... time="2021-01-01T00:00:00Z/2022-01-01T00:00:00Z",
372374
... )
373375
"""
376+
_check_monitoring_location_id(monitoring_location_id)
374377
service = "continuous"
375378
output_id = "continuous_id"
376379

@@ -662,6 +665,7 @@ def get_monitoring_locations(
662665
... properties=["monitoring_location_id", "state_name", "country_name"],
663666
... )
664667
"""
668+
_check_monitoring_location_id(monitoring_location_id)
665669
service = "monitoring-locations"
666670
output_id = "monitoring_location_id"
667671

@@ -878,6 +882,7 @@ def get_time_series_metadata(
878882
... begin="1990-01-01/..",
879883
... )
880884
"""
885+
_check_monitoring_location_id(monitoring_location_id)
881886
service = "time-series-metadata"
882887
output_id = "time_series_id"
883888

@@ -1050,6 +1055,7 @@ def get_latest_continuous(
10501055
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
10511056
... )
10521057
"""
1058+
_check_monitoring_location_id(monitoring_location_id)
10531059
service = "latest-continuous"
10541060
output_id = "latest_continuous_id"
10551061

@@ -1224,6 +1230,7 @@ def get_latest_daily(
12241230
... monitoring_location_id=["USGS-05114000", "USGS-09423350"]
12251231
... )
12261232
"""
1233+
_check_monitoring_location_id(monitoring_location_id)
12271234
service = "latest-daily"
12281235
output_id = "latest_daily_id"
12291236

@@ -1397,6 +1404,7 @@ def get_field_measurements(
13971404
... time = "P20Y"
13981405
... )
13991406
"""
1407+
_check_monitoring_location_id(monitoring_location_id)
14001408
service = "field-measurements"
14011409
output_id = "field_measurement_id"
14021410

@@ -1850,6 +1858,7 @@ def get_stats_por(
18501858
... )
18511859
"""
18521860
# Build argument dictionary, omitting None values
1861+
_check_monitoring_location_id(monitoring_location_id)
18531862
params = _get_args(locals(), exclude={"expand_percentiles"})
18541863

18551864
return get_stats_data(
@@ -1979,6 +1988,7 @@ def get_stats_date_range(
19791988
... )
19801989
"""
19811990
# Build argument dictionary, omitting None values
1991+
_check_monitoring_location_id(monitoring_location_id)
19821992
params = _get_args(locals(), exclude={"expand_percentiles"})
19831993

19841994
return get_stats_data(
@@ -2144,6 +2154,7 @@ def get_channel(
21442154
... monitoring_location_id="USGS-02238500",
21452155
... )
21462156
"""
2157+
_check_monitoring_location_id(monitoring_location_id)
21472158
service = "channel-measurements"
21482159
output_id = "channel_measurements_id"
21492160

dataretrieval/waterdata/utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1109,6 +1109,57 @@ def _check_profiles(
11091109
)
11101110

11111111

1112+
_MONITORING_LOCATION_ID_RE = re.compile(r"^.+-.+$")
1113+
1114+
1115+
def _check_monitoring_location_id(
1116+
monitoring_location_id: str | list[str] | None,
1117+
) -> None:
1118+
"""Validate the format of a monitoring_location_id value.
1119+
1120+
Parameters
1121+
----------
1122+
monitoring_location_id : str, list of str, or None
1123+
One or more monitoring location identifiers.
1124+
1125+
Raises
1126+
------
1127+
TypeError
1128+
If any identifier is not a string (e.g. an integer was passed).
1129+
ValueError
1130+
If any string identifier does not follow the required
1131+
``'AGENCY-ID'`` format (e.g. ``'USGS-01646500'``).
1132+
"""
1133+
if monitoring_location_id is None:
1134+
return
1135+
1136+
if not isinstance(monitoring_location_id, (str, list)):
1137+
raise TypeError(
1138+
f"monitoring_location_id must be a string or list of strings, "
1139+
f"not {type(monitoring_location_id).__name__}. "
1140+
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{monitoring_location_id}'."
1141+
)
1142+
1143+
ids = (
1144+
[monitoring_location_id]
1145+
if isinstance(monitoring_location_id, str)
1146+
else monitoring_location_id
1147+
)
1148+
1149+
for id_ in ids:
1150+
if not isinstance(id_, str):
1151+
raise TypeError(
1152+
f"monitoring_location_id must be a string or list of strings, "
1153+
f"not {type(id_).__name__}. "
1154+
f"Expected format: 'AGENCY-ID', e.g., 'USGS-{id_}'."
1155+
)
1156+
if not _MONITORING_LOCATION_ID_RE.match(id_):
1157+
raise ValueError(
1158+
f"Invalid monitoring_location_id: {id_!r}. "
1159+
f"Expected 'AGENCY-ID' format, e.g., 'USGS-01646500'."
1160+
)
1161+
1162+
11121163
def _get_args(
11131164
local_vars: dict[str, Any], exclude: set[str] | None = None
11141165
) -> dict[str, Any]:

tests/waterdata_test.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
get_stats_por,
2222
get_time_series_metadata,
2323
)
24-
from dataretrieval.waterdata.utils import _check_profiles
24+
from dataretrieval.waterdata.utils import _check_monitoring_location_id, _check_profiles
2525

2626

2727
def mock_request(requests_mock, request_url, file_path):
@@ -380,3 +380,52 @@ def test_get_channel():
380380
assert df.shape[0] > 470
381381
assert df.shape[1] == 27 # if geopandas installed, 21 columns if not
382382
assert "channel_measurements_id" in df.columns
383+
384+
385+
class TestCheckMonitoringLocationId:
    """Input-validation checks for ``_check_monitoring_location_id``.

    These guard against the regression reported in GitHub issue #188.
    """

    def test_valid_string(self):
        """One well-formed identifier is accepted silently."""
        good_id = "USGS-01646500"
        _check_monitoring_location_id(good_id)

    def test_valid_list(self):
        """Several well-formed identifiers in a list are accepted silently."""
        good_ids = ["USGS-01646500", "USGS-02238500"]
        _check_monitoring_location_id(good_ids)

    def test_none_passes(self):
        """``None`` is accepted because the parameter is optional."""
        _check_monitoring_location_id(None)

    def test_integer_raises_type_error(self):
        """A bare integer is rejected with a TypeError naming the type."""
        bad_value = 5129115
        with pytest.raises(TypeError, match="not int"):
            _check_monitoring_location_id(bad_value)

    def test_integer_in_list_raises_type_error(self):
        """An integer element inside a list is rejected with a TypeError."""
        mixed = ["USGS-01646500", 5129115]
        with pytest.raises(TypeError, match="not int"):
            _check_monitoring_location_id(mixed)

    def test_missing_agency_prefix_raises_value_error(self):
        """A string lacking the ``AGENCY-`` prefix is rejected."""
        bad_value = "dog"
        with pytest.raises(ValueError, match="Invalid monitoring_location_id"):
            _check_monitoring_location_id(bad_value)

    def test_bare_site_number_raises_value_error(self):
        """A site number given as a string without an agency is rejected."""
        bad_value = "01646500"
        with pytest.raises(ValueError, match="Invalid monitoring_location_id"):
            _check_monitoring_location_id(bad_value)

    def test_get_daily_integer_id_raises(self):
        """``get_daily`` raises TypeError for an integer identifier."""
        kwargs = {"monitoring_location_id": 5129115, "parameter_code": "00060"}
        with pytest.raises(TypeError):
            get_daily(**kwargs)

    def test_get_daily_malformed_id_raises(self):
        """``get_daily`` raises ValueError for a malformed string identifier."""
        kwargs = {"monitoring_location_id": "dog", "parameter_code": "00060"}
        with pytest.raises(ValueError):
            get_daily(**kwargs)

0 commit comments

Comments
 (0)