""" Utilities for rasterio SeeAlso util_gdal.py """ import kwimage import numpy as np # import pygeos import shapely as shp import shapely.geometry import shapely.ops import ubelt as ub import warnings from contextlib import ExitStack from dataclasses import dataclass # from tempenv import TemporaryEnvironment from typing import Union # import pyproj import rasterio import rasterio.features import rasterio.mask from rasterio import Affine, MemoryFile from rasterio.enums import Resampling try: from line_profiler import profile except Exception: profile = ub.identity def _ensure_open( raster: Union[rasterio.DatasetReader, str]) -> rasterio.DatasetReader: if not isinstance(raster, rasterio.DatasetReader) or raster.closed: # workaround for # https://rasterio.readthedocs.io/en/latest/faq.html#why-can-t-rasterio-find-proj-db-rasterio-from-pypi-versions-1-2-0 # with TemporaryEnvironment({'PROJ_LIB': None, 'PROJ_DEBUG': '3'}): if 1: return rasterio.open(raster) else: return raster def _swapxy(poly: shp.geometry.Polygon) -> shp.geometry.Polygon: return kwimage.Polygon.from_shapely(poly).swap_axes().to_shapely() @profile def mask(raster: Union[rasterio.DatasetReader, str], default_nodata=None, save=False, convex_hull=False, as_poly=True, tolerance=None, max_polys=None, use_overview=0): """ Compute a raster's valid data mask in pixel coordinates. Note that this is the rasterio mask, which for multi-band rasters is the binary OR of the individual band masks. This is different from the gdal mask, which is always per-band. Args: raster (str): Path to a dataset (raster image file) default_nodata (int): if raster's nodata value is None, default to this save (bool): if True and raster's nodata value is None, write the default to it. If False, performance overhead is incurred from creating a tempfile convex_hull (bool): if True, return the convex hull of the mask image or poly as_poly (bool): if True, return the mask as a shapely Polygon or MultiPolygon instead of a raster image, in (w, h) order (opposite of Python convention). tolerance (int): if specified, simplifies the valid polygon. use_overview (int): if non-zero uses the closest overview if it is available. This increases computation time, but gives a better polygon when use_overview is closer to 0. Note, the polygon is rescaled to ensure it is returned in the input pixel space, not the overview space. Returns: If as_poly, a shapely Polygon or MultiPolygon bounding the valid data region(s) in pixel coordinates. Else, a uint8 raster mask of the same shape as the input, where 255 == valid and 0 == invalid. Ignore: raster = '/home/joncrall/data/dvc-repos/smart_watch_dvc/drop1/../drop1/_assets/google-cloud/LS/LC08_L1TP_016039_20160216_20170224_01_T1/LC08_L1TP_016039_20160216_20170224_01_T1_B1.TIF' # noqa from geowatch.utils.util_raster import * mask_img = mask(raster, as_poly=False) Example: >>> # xdoctest: +REQUIRES(--network) >>> from geowatch.utils.util_raster import * >>> # FIXME; this demo path no longer has any nodata values >>> # Find a better demo with nodata >>> from geowatch.demo.landsat_demodata import grab_landsat_product >>> path = grab_landsat_product()['bands'][0] >>> # >>> mask_img = mask(path, as_poly=False) >>> import kwimage as ki >>> assert mask_img.shape == ki.load_image_shape(path)[:2] >>> got = set(np.unique(mask_img)) >>> print(f'got={got}') >>> # assert got == {0, 255} # cant do this until nodata is fixed >>> # >>> mask_poly = mask(path, as_poly=True) >>> import shapely >>> assert isinstance(mask_poly, shapely.geometry.Polygon) >>> # xdoctest: +REQUIRES(--show) >>> import kwplot >>> kwplot.autompl() >>> kwplot.figure(fnum=1, doclf=True) >>> imdata = kwimage.imread(path) >>> imdata = kwimage.normalize_intensity(imdata) >>> kw_poly = kwimage.Polygon.coerce(mask_poly.buffer(0).simplify(10)) >>> canvas = imdata.copy() >>> mask_alpha = kwimage.ensure_alpha_channel(mask_img, alpha=(mask_img > 0)) >>> canvas = kwimage.overlay_alpha_layers([mask_alpha, canvas]) >>> canvas = kw_poly.scale(0.9, about='center').draw_on(canvas, color='green', alpha=0.6) >>> kw_poly.scale(1.1, about='center').draw(alpha=0.5, color='red', setlim=True) >>> kwplot.imshow(canvas) Example: >>> # Test how the "save" functionality modifies the data >>> import kwimage >>> from geowatch.utils.util_raster import * >>> import pathlib >>> dpath = ub.Path.appdir('geowatch/tests/empty_raster').ensuredir() >>> raster = dpath / 'empty.tif' >>> ub.delete(raster) >>> kwimage.imwrite(raster, np.zeros((3, 3, 5))) >>> info1 = ub.cmd('gdalinfo {}'.format(raster)) >>> nodata = 0 >>> mask_img = mask(raster, as_poly=False) >>> print('mask_img = {!r}'.format(mask_img)) >>> info2 = ub.cmd('gdalinfo {}'.format(raster)) >>> mask_poly = mask(raster, as_poly=True) >>> info3 = ub.cmd('gdalinfo {}'.format(raster)) >>> print(info1['out']) >>> print(info2['out']) >>> print(info3['out']) """ scale_factor = None with warnings.catch_warnings(): warnings.filterwarnings( 'ignore', category=rasterio.errors.NotGeoreferencedWarning) # workaround for # https://rasterio.readthedocs.io/en/latest/faq.html#why-can-t-rasterio-find-proj-db-rasterio-from-pypi-versions-1-2-0 # Do we need the temporary env anymore? # with TemporaryEnvironment({'PROJ_LIB': None, 'PROJ_DEBUG': '3'}): if True: img = _ensure_open(raster) img_height = img.height img_width = img.width # Work at the coarsest overview level for speed overviews = { tuple(img.overviews(bandx)) for bandx in range(1, img.count + 1) } if len(overviews) == 1: overview_levels = ub.peek(overviews) if len(overview_levels): img.close() # Open image with a higher overview level # https://github.com/rasterio/rasterio/issues/1504 if use_overview < 0: use_overview = len(overview_levels) + use_overview requested_overview = min(max(use_overview, 0), len(overview_levels) - 1) img = rasterio.open(img.name, 'r', overview_level=requested_overview) scale_factor = overview_levels[requested_overview] try: mask_img = None if default_nodata is None: nodata = img.nodata use_disk_nodata = True else: nodata = default_nodata use_disk_nodata = False if nodata is None: # Not specified, and not introspectable # TODO: early return # if as_poly: # pass # else: mask_img = np.full((img.height, img.width), fill_value=255, dtype=np.uint8) else: if save: raise NotImplementedError( 'Dont update here. It can be unsafe. ' 'Probably should be done in a separate script') # if needs_nodata: # # The image was closed, so we must open a new one # if save: # img.close() # # Add on necessary information in footer # with rasterio.open(raster, 'r+') as img: # img.nodata = nodata # img = rasterio.open(raster, 'r') # else: # profile = img.profile.copy() # profile['nodata'] = nodata # # TODO could optimize this with rasterio.shutil.copy # # or https://rasterio.readthedocs.io/en/latest/topics/windowed-rw.html#blocks # data = img.read() # img.close() # profile.update(nodata=nodata) # memfile = MemoryFile() # img = memfile.open(**profile) # img.write(data) if use_disk_nodata: mask_img = img.dataset_mask() else: # simulate 0 = nodata, 255=valid data # operate inplace when possible imdata = img.read(1, out_dtype=np.uint8) np.not_equal(imdata, nodata, dtype=np.uint8, out=imdata) np.multiply(imdata, 255, out=imdata) mask_img = imdata finally: img.close() if convex_hull: from skimage.morphology import convex_hull_image with warnings.catch_warnings(): warnings.filterwarnings( 'ignore', 'Input image is entirely zero', category=UserWarning) mask_img = convex_hull_image(mask_img).astype(np.uint8) if not as_poly: return mask_img if mask_img is None: raise AssertionError('mask image was None') # mask has values 0 and 255 polys = [] for poly, val in rasterio.features.shapes(mask_img, connectivity=4): if val > 0: polys.append(shp.geometry.shape(poly)) if max_polys is not None and len(polys) > max_polys: break if tolerance is not None: scaled_tolerance = tolerance if scale_factor is not None: scaled_tolerance = tolerance / scale_factor polys = [poly.buffer(0).simplify(scaled_tolerance) for poly in polys] mask_poly = shp.ops.unary_union(polys).buffer(0) if tolerance is not None: mask_poly = mask_poly.simplify(scaled_tolerance) # do this again to fix any weirdness from union if convex_hull: mask_poly = mask_poly.convex_hull if scale_factor is not None: # Move from area space into point space? # mask_poly = shapely.affinity.translate(mask_poly, xoff=-0.5, yoff=-0.5) import shapely.affinity mask_poly = shapely.affinity.scale(mask_poly, xfact=scale_factor, yfact=scale_factor, origin=(0.0, 0.0)) # Using overviews to compute a polygon has slack. # Buffer to account for this. mask_poly = mask_poly.buffer(scale_factor) # Clip to the bounds bounds = shapely.geometry.box(0, 0, img_width, img_height) mask_poly = mask_poly.intersection(bounds) if tolerance is not None: mask_poly = mask_poly.simplify(tolerance) return mask_poly @dataclass class ResampledRaster(ExitStack): """ Context manager to rescale a raster on the fly using rasterio This changes the number of pixels in the raster while maintaining its geographic bounds, that is, it changes the raster's GSD. Args: raster: a DatasetReader (the object returned by rasterio.open) or path to a dataset scale: factor to upscale the resolution, aka downscale the GSD, by read: if True, read and return the resampled data (an expensive operation if scale>1) else, return the resampled dataset's .profile attribute (metadata) resampling: resampling algorithm, from rasterio.enums.Resampling [1] Example: >>> # xdoctest: +REQUIRES(--slow) >>> # xdoctest: +REQUIRES(--network) >>> from geowatch.utils.util_raster import * >>> from geowatch.demo.landsat_demodata import grab_landsat_product >>> path = grab_landsat_product()['bands'][0] >>> # >>> current_gsd_meters = 60 >>> desired_gsd_meters = 10 >>> scale = current_gsd_meters / desired_gsd_meters >>> # >>> with rasterio.open(path) as f: >>> old_profile = f.profile >>> # >>> # can instantiate this class in a with-block >>> with ResampledRaster(path, scale=scale, read=False) as f: >>> pass >>> # >>> # or have it stick around and change the resampling on the fly >>> resampled = ResampledRaster(path, scale=scale, read=False) >>> # >>> # the computation only happens when you invoke 'with' >>> with resampled as new_profile: >>> assert new_profile['width'] == int(old_profile['width'] * scale) >>> assert new_profile['crs'] == old_profile['crs'] >>> # >>> resampled.scale = scale / 2 >>> resampled.read = True >>> # >>> with resampled as new: >>> assert new.profile['width'] == int(old_profile['width'] * scale / 2) >>> assert new.profile['crs'] == old_profile['crs'] >>> # do other stuff with new References: https://gis.stackexchange.com/a/329439 https://rasterio.readthedocs.io/en/latest/topics/reading.html https://rasterio.readthedocs.io/en/latest/topics/profiles.html [1] https://rasterio.readthedocs.io/en/latest/api/rasterio.enums.html#rasterio.enums.Resampling """ raster: Union[str, rasterio.DatasetReader] scale: float = 2 read: bool = True resampling: Resampling = Resampling.bilinear def __post_init__(self): super().__init__() def __enter__(self): self.raster = _ensure_open(self.raster) t = self.raster.transform # rescale the metadata transform = Affine(t.a / self.scale, t.b, t.c, t.d, t.e / self.scale, t.f) height = int(np.ceil(self.raster.height * self.scale)) width = int(np.ceil(self.raster.width * self.scale)) profile = self.raster.profile profile.update(transform=transform, driver='GTiff', height=height, width=width) if self.read: # Note changed order of indexes, arrays are band, row, col order # not row, col, band data = self.raster.read( out_shape=(self.raster.count, height, width), resampling=self.resampling, ) # enter_context is from contextlib.ExitStack, which takes care of # closing these memfile = self.enter_context(MemoryFile()) with memfile.open(**profile) as dataset: # Open as DatasetWriter dataset.write(data) del data dataset = self.enter_context( memfile.open()) # Reopen as DatasetReader return dataset else: return profile def __exit__(self, *exc): pass