batching

Spatial batching: assign features to spatially contiguous groups.

Group polygon features into spatially contiguous batches using KD-tree recursive bisection. This ensures that each batch's bounding box is compact, which is critical for efficient spatial subsetting of high-resolution source rasters (e.g., 3DEP at 10 m, gNATSGO at 30 m).

Without spatial batching, arbitrarily ordered fabric features would produce large bounding boxes that fetch far more raster data than needed, leading to excessive memory use and slow processing. The KD-tree approach (design.md section 5.5.1, Approach 5) provides O(n log n) partitioning with guaranteed spatial locality.
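The idea of recursive bisection can be sketched in a few lines. The block below is an illustration only: the function name `recursive_bisect_sketch`, the longest-axis split rule, and the stopping rule are assumptions made for the example, and the project's actual `_recursive_bisect` is not shown on this page and may differ. It is written in pure Python to stay dependency-free.

```python
def recursive_bisect_sketch(points, indices, max_depth, min_batch_size, depth=0):
    """Split `indices` into spatially compact groups (illustrative only)."""
    # Stop when the depth budget is spent, or when halving would
    # produce groups smaller than min_batch_size.
    if depth >= max_depth or len(indices) < 2 * min_batch_size:
        return [indices]

    xs = [points[i][0] for i in indices]
    ys = [points[i][1] for i in indices]
    # Assumed rule: split along the axis with the larger extent.
    axis = 0 if (max(xs) - min(xs)) >= (max(ys) - min(ys)) else 1

    # Median split: sort along the chosen axis and cut in the middle.
    ordered = sorted(indices, key=lambda i: points[i][axis])
    mid = len(ordered) // 2
    left, right = ordered[:mid], ordered[mid:]

    return recursive_bisect_sketch(
        points, left, max_depth, min_batch_size, depth + 1
    ) + recursive_bisect_sketch(
        points, right, max_depth, min_batch_size, depth + 1
    )


# Hypothetical centroids on a 4 x 4 grid, split to depth 2.
grid = [(float(x), float(y)) for x in range(4) for y in range(4)]
groups = recursive_bisect_sketch(
    grid, list(range(len(grid))), max_depth=2, min_batch_size=2
)
```

On this toy grid the first cut separates the western and eastern halves, and the second cut divides each half north/south, so every group occupies one compact quadrant.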

See Also

design.md : Section 5.5.1 (spatial batching approaches and trade-offs).
hydro_param.pipeline : Orchestrator that processes batches sequentially.

References

[1] design.md, section 5.5.1: KD-tree recursive bisection (Approach 5).

spatial_batch

spatial_batch(
    gdf: gpd.GeoDataFrame, batch_size: int = 500
) -> gpd.GeoDataFrame

Assign spatially contiguous batch IDs via KD-tree recursive bisection.

Group polygon features into spatially compact batches so that each batch's bounding box covers a small geographic area. This is the primary entry point for spatial batching in the pipeline, called during stage 1 (resolve fabric) before any data access occurs.

Compact bounding boxes are critical for memory efficiency: when the pipeline fetches source rasters clipped to a batch's bbox, a tight bbox means less data loaded into memory. For high-resolution datasets like gNATSGO (30 m, ~1.25 GB per variable), this prevents OOM errors.
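A toy comparison makes the bounding-box argument concrete. The sketch below uses an 8 x 8 grid of hypothetical centroids, not real fabric data, and compares the summed bbox area of four batches formed from an arbitrary ordering against four batches formed from spatial quadrants. The helper `bbox_area` and both batch layouts are assumptions made for the illustration.

```python
def bbox_area(points):
    """Area of the axis-aligned bounding box around a list of (x, y) points."""
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    return (max(xs) - min(xs)) * (max(ys) - min(ys))


# Hypothetical feature centroids on an 8 x 8 grid.
grid = [(x, y) for x in range(8) for y in range(8)]

# Arbitrary ordering: each batch takes every 4th feature, so its
# members are scattered across the whole domain.
arbitrary = [grid[k::4] for k in range(4)]

# Spatial batching: one batch per quadrant.
spatial = [
    [p for p in grid if (p[0] < 4) == west and (p[1] < 4) == south]
    for west in (True, False)
    for south in (True, False)
]

arbitrary_total = sum(bbox_area(b) for b in arbitrary)  # 4 batches x 28 = 112
spatial_total = sum(bbox_area(b) for b in spatial)      # 4 batches x 9 = 36
```

Both layouts hold 16 features per batch, but the scattered batches cover roughly three times the total bbox area, which in a raster pipeline translates into correspondingly more pixels fetched per batch.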

Parameters

gdf : GeoDataFrame
    Target fabric with polygon geometries (Polygon or MultiPolygon). May be in any CRS -- centroids are computed for spatial grouping only (approximate centroids are sufficient).
batch_size : int, default 500
    Target number of features per batch. The actual batch sizes will vary due to the recursive bisection algorithm (typically within [batch_size/2, batch_size*2]).

Returns

GeoDataFrame
    Copy of input with a batch_id column (int) added. Batch IDs are sequential integers starting from 0. Features within the same batch are spatially contiguous.

Notes

For fabrics with geographic CRS (e.g., EPSG:4326), the centroid computation emits a UserWarning about geographic CRS accuracy. This warning is suppressed because only approximate centroids are needed for spatial grouping -- the batch boundaries do not need to be geometrically precise.
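The suppression pattern relies only on the standard `warnings` module and can be tried without GeoPandas. In the sketch below, `fake_centroid` is a hypothetical stand-in that emits a warning whose text is modeled on the GeoPandas message (the exact wording is an assumption). The filter pattern is the one used in the function's source.

```python
import warnings


def fake_centroid():
    """Hypothetical stand-in for a centroid call on a geographic CRS."""
    warnings.warn(
        "Geometry is in a geographic CRS. "
        "Results from 'centroid' are likely incorrect.",
        UserWarning,
    )
    return (0.0, 0.0)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Same pattern as in spatial_batch: matched as a regex against the
    # start of the warning message.
    warnings.filterwarnings("ignore", message=".*geographic CRS.*centroid.*")
    point = fake_centroid()                           # suppressed
    warnings.warn("unrelated warning", UserWarning)   # still recorded
```

Because the filter is scoped to the `with` block and keyed on the message text, unrelated warnings raised during the same call are still visible.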

The recursion depth is computed as max(1, ceil(log2(n_features // batch_size))), using integer division, so the bisection yields at most 2**depth batches, which is approximately the right number. The min_batch_size is set to batch_size // 2 (at least 1) to prevent excessive fragmentation.
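The depth arithmetic is easy to check by hand. The sketch below mirrors the integer arithmetic in the function's source listing (floor division before the logarithm). The helper name `batching_plan` is hypothetical and exists only for this illustration. It does not model the short-circuit that returns a single batch when all features fit in `batch_size`.

```python
import math


def batching_plan(n_features, batch_size):
    """Return (max_depth, max_batches, min_batch_size) for the given inputs."""
    n_batches = max(1, n_features // batch_size)
    max_depth = max(1, math.ceil(math.log2(n_batches)))
    min_batch_size = max(1, batch_size // 2)
    # A binary recursion of depth d can produce at most 2**d leaves.
    return max_depth, 2 ** max_depth, min_batch_size


# 765 // 200 = 3, ceil(log2(3)) = 2, so at most 4 batches of >= 100 features.
plan = batching_plan(765, 200)  # (2, 4, 100)
```

This matches the documented example, in which roughly 765 features with `batch_size=200` end up in 4 batches.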

Examples:

>>> import geopandas as gpd
>>> from hydro_param.batching import spatial_batch
>>> fabric = gpd.read_file("nhru.gpkg")
>>> batched = spatial_batch(fabric, batch_size=200)
>>> batched["batch_id"].nunique()
4  # for ~765 features with batch_size=200

See Also

_recursive_bisect : The recursive partitioning algorithm.
hydro_param.pipeline : Uses batch IDs to iterate over spatial groups.

Source code in src/hydro_param/batching.py
def spatial_batch(
    gdf: gpd.GeoDataFrame,
    batch_size: int = 500,
) -> gpd.GeoDataFrame:
    """Assign spatially contiguous batch IDs via KD-tree recursive bisection.

    Group polygon features into spatially compact batches so that each
    batch's bounding box covers a small geographic area. This is the
    primary entry point for spatial batching in the pipeline, called
    during stage 1 (resolve fabric) before any data access occurs.

    Compact bounding boxes are critical for memory efficiency: when the
    pipeline fetches source rasters clipped to a batch's bbox, a tight
    bbox means less data loaded into memory. For high-resolution datasets
    like gNATSGO (30 m, ~1.25 GB per variable), this prevents OOM errors.

    Parameters
    ----------
    gdf : gpd.GeoDataFrame
        Target fabric with polygon geometries (Polygon or MultiPolygon).
        May be in any CRS -- centroids are computed for spatial grouping
        only (approximate centroids are sufficient).
    batch_size : int
        Target number of features per batch. The actual batch sizes will
        vary due to the recursive bisection algorithm (typically within
        ``[batch_size/2, batch_size*2]``). Default is 500.

    Returns
    -------
    gpd.GeoDataFrame
        Copy of input with a ``batch_id`` column (int) added. Batch IDs
        are sequential integers starting from 0. Features within the
        same batch are spatially contiguous.

    Notes
    -----
    For fabrics with geographic CRS (e.g., EPSG:4326), the centroid
    computation emits a ``UserWarning`` about geographic CRS accuracy.
    This warning is suppressed because only approximate centroids are
    needed for spatial grouping -- the batch boundaries do not need to be
    geometrically precise.

    The recursion depth is computed as
    ``max(1, ceil(log2(n_features // batch_size)))`` (integer division),
    so the bisection yields at most ``2 ** depth`` batches. The
    ``min_batch_size`` is set to ``batch_size // 2`` (at least 1) to
    prevent excessive fragmentation.

    Examples
    --------
    >>> import geopandas as gpd
    >>> from hydro_param.batching import spatial_batch
    >>> fabric = gpd.read_file("nhru.gpkg")
    >>> batched = spatial_batch(fabric, batch_size=200)
    >>> batched["batch_id"].nunique()
    4  # for ~765 features with batch_size=200

    See Also
    --------
    _recursive_bisect : The recursive partitioning algorithm.
    hydro_param.pipeline : Uses batch IDs to iterate over spatial groups.
    """
    if gdf.empty:
        result = gdf.copy()
        result["batch_id"] = np.array([], dtype=int)
        return result

    # Short-circuit: single batch when all features fit
    if len(gdf) <= batch_size:
        result = gdf.copy()
        result["batch_id"] = 0
        logger.info(
            "Spatial batching: %d features → 1 batch (all fit in batch_size=%d)",
            len(gdf),
            batch_size,
        )
        return result

    # Geographic CRS centroid warning is expected — we only need
    # approximate centroids for spatial grouping, not precision.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message=".*geographic CRS.*centroid.*")
        # Compute centroids once and reuse them for both coordinates.
        centroid_points = gdf.geometry.centroid
        centroids = np.column_stack(
            [centroid_points.x.values, centroid_points.y.values]
        )

    n_batches = max(1, len(gdf) // batch_size)
    max_depth = max(1, int(np.ceil(np.log2(n_batches))))

    batches = _recursive_bisect(
        centroids,
        np.arange(len(gdf)),
        max_depth=max_depth,
        min_batch_size=max(1, batch_size // 2),
    )

    batch_ids = np.empty(len(gdf), dtype=int)
    for batch_id, indices in enumerate(batches):
        batch_ids[indices] = batch_id

    result = gdf.copy()
    result["batch_id"] = batch_ids

    logger.info(
        "Spatial batching: %d features → %d batches (target size=%d, actual range=%d-%d)",
        len(gdf),
        len(batches),
        batch_size,
        min(len(b) for b in batches),
        max(len(b) for b in batches),
    )

    return result