# -*- coding: utf-8 -*-
"""
Coordinates and dimensions utilities
"""
# Copyright 2020-2021 Shom
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Mapping
import xarray as xr
from .__init__ import XoaError, xoa_warn
from . import misc
from . import cf as xcf
@misc.ERRORS.format_function_docstring
def get_lon(da, errors="raise"):
    """Get the longitude coordinate

    The search is delegated to the CF specs associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lat
    get_depth
    get_altitude
    get_level
    get_vertical
    get_time
    xoa.cf.CFSpecs.search_coord
    """
    cfspecs = xcf.get_cf_specs(da)
    return cfspecs.search(da, 'lon', errors=errors)
def is_lon(da, loc="any"):
    """Tell if a data array is identified as longitudes

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match` (``"any"`` by default)

    Return
    ------
    bool

    See also
    --------
    is_lat
    is_depth
    is_altitude
    is_level
    is_time
    xoa.cf.CFCoordSpecs.match
    """
    return xcf.get_cf_specs(da).coords.match(da, "lon", loc=loc)
@misc.ERRORS.format_function_docstring
def get_lat(da, errors="raise"):
    """Get the latitude coordinate

    The search is delegated to the CF specs associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lon
    get_depth
    get_altitude
    get_level
    get_vertical
    get_time
    xoa.cf.CFSpecs.search_coord
    """
    cfspecs = xcf.get_cf_specs(da)
    return cfspecs.search(da, 'lat', errors=errors)
def is_lat(da, loc="any"):
    """Tell if a data array is identified as latitudes

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match`

    Return
    ------
    bool

    See also
    --------
    is_lon
    is_depth
    is_altitude
    is_level
    is_time
    xoa.cf.CFCoordSpecs.match
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.match(da, "lat", loc=loc)
@misc.ERRORS.format_function_docstring
def get_depth(da, errors="raise"):
    """Get or compute the depth coordinate

    If a depth variable cannot be found, it tries to compute it either
    from sigma-like coordinates or from layer thicknesses, depending on
    the ``[vertical] type`` entry of the current CF specs:

    - ``"z"``: no inference is attempted;
    - ``"sigma"``: only sigma decoding is attempted;
    - ``"dz2depth"``: only layer thickness integration is attempted;
    - ``None``: both are attempted silently, sigma first.

    Inference is only possible on an object that has ``data_vars``,
    like a dataset.

    Parameters
    ----------
    da: xarray.DataArray, xarray.Dataset
        Array or dataset to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lon
    get_lat
    get_time
    get_altitude
    get_level
    get_vertical
    xoa.cf.CFSpecs.search_coord
    xoa.sigma.decode_cf_sigma
    xoa.grid.decode_cf_dz2depth
    """
    cfspecs = xcf.get_cf_specs(da)
    errors = misc.ERRORS[errors]
    ztype = cfspecs["vertical"]["type"]

    # From variable
    depth = cfspecs.search(da, 'depth', errors="ignore")
    if depth is not None:
        return depth

    # No inference when explicitly disabled or without data variables
    if ztype == "z" or not hasattr(da, "data_vars"):
        msg = "No depth coordinate found"
        if errors == "raise":
            raise XoaError(msg)
        if errors == "warn":  # stay silent when errors="ignore"
            xoa_warn(msg)
        return None

    # Decode the dataset from sigma coordinates
    if ztype == "sigma" or ztype is None:
        # When the type is unknown, failures are not reported since
        # another decoding method is tried next
        err = "ignore" if ztype is None else errors
        from .sigma import decode_cf_sigma
        da = decode_cf_sigma(da, errors=err)
        if "depth" in da:
            return da.depth

    # Decode the dataset from layer thicknesses
    if ztype == "dz2depth" or ztype is None:
        err = "ignore" if ztype is None else errors
        from .grid import decode_cf_dz2depth
        da = decode_cf_dz2depth(da, errors=err)
        if "depth" in da:
            return da.depth

    msg = "Can't infer depth coordinate from dataset"
    if errors == "raise":
        raise XoaError(msg)
    if errors == "warn":  # stay silent when errors="ignore"
        xoa_warn(msg)
    return None
def is_depth(da, loc="any"):
    """Tell if a data array is identified as depths

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match`

    Return
    ------
    bool

    See also
    --------
    is_lon
    is_lat
    is_altitude
    is_level
    is_time
    xoa.cf.CFCoordSpecs.match
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.match(da, "depth", loc=loc)
@misc.ERRORS.format_function_docstring
def get_altitude(da, errors="raise"):
    """Get the altitude coordinate

    The search is delegated to the CF specs associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lon
    get_lat
    get_depth
    get_level
    get_vertical
    get_time
    xoa.cf.CFSpecs.search_coord
    """
    cfspecs = xcf.get_cf_specs(da)
    return cfspecs.search(da, 'altitude', errors=errors)
def is_altitude(da, loc="any"):
    """Tell if a data array is identified as altitudes

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match`

    Return
    ------
    bool

    See also
    --------
    is_lon
    is_lat
    is_depth
    is_level
    is_time
    xoa.cf.CFCoordSpecs.match
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.match(da, "altitude", loc=loc)
@misc.ERRORS.format_function_docstring
def get_level(da, errors="raise"):
    """Get the level coordinate

    The search is delegated to the coordinate specs of the CF specs
    associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lon
    get_lat
    get_depth
    get_altitude
    get_vertical
    get_time
    xoa.cf.CFSpecs.search_coord
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.search(da, 'level', errors=errors)
def is_level(da, loc="any"):
    """Tell if a data array is identified as levels

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match`

    Return
    ------
    bool

    See also
    --------
    is_lon
    is_lat
    is_depth
    is_altitude
    is_time
    xoa.cf.CFCoordSpecs.match
    """
    # Same CF name as the one searched by get_level
    return xcf.get_cf_specs(da).coords.match(da, "level", loc=loc)
@misc.ERRORS.format_function_docstring
def get_vertical(da, errors="raise"):
    """Get either depth or altitude

    A depth coordinate is searched first, then an altitude coordinate.
    Unlike :func:`get_depth`, no depth inference is attempted.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    Raises
    ------
    xoa.cf.XoaCFError
        When nothing is found and ``errors="raise"``

    See also
    --------
    get_lon
    get_lat
    get_depth
    get_altitude
    get_level
    get_time
    xoa.cf.CFSpecs.search_coord
    """
    # Use the specs associated with the array, like all other getters
    cfspecs = xcf.get_cf_specs(da)
    height = cfspecs.search(da, 'depth', errors="ignore")
    if height is None:
        height = cfspecs.search(da, 'altitude', errors="ignore")
    if height is not None:
        return height
    errors = misc.ERRORS[errors]
    msg = "No vertical coordinate found"
    if errors == "raise":
        raise xcf.XoaCFError(msg)
    if errors == "warn":
        xoa_warn(msg)
    return None
@misc.ERRORS.format_function_docstring
def get_time(da, errors="raise"):
    """Get the time coordinate

    The search is delegated to the coordinate specs of the CF specs
    associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}

    Return
    ------
    xarray.DataArray or None

    See also
    --------
    get_lon
    get_lat
    get_depth
    get_altitude
    get_level
    get_vertical
    xoa.cf.CFSpecs.search_coord
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.search(da, 'time', errors=errors)
def is_time(da, loc="any"):
    """Tell if a data array is identified as time

    Parameters
    ----------
    da: xarray.DataArray
        Array to check
    loc: str
        Location specification forwarded as is to
        :meth:`xoa.cf.CFCoordSpecs.match` (``"any"`` by default)

    Return
    ------
    bool

    See also
    --------
    is_lon
    is_lat
    is_depth
    is_altitude
    is_level
    xoa.cf.CFCoordSpecs.match
    """
    # Match against the coordinate specs, like the other is_* helpers
    return xcf.get_cf_specs(da).coords.match(da, "time", loc=loc)
@misc.ERRORS.format_function_docstring
def get_cf_coords(da, coord_names, errors="raise"):
    """Get several coordinates at once

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    coord_names: list(str)
        CF names of the coordinates to search for
    {errors}

    Return
    ------
    list(xarray.DataArray)
        One item per requested name, in the same order

    See also
    --------
    xoa.cf.CFSpecs.search_coord
    """
    cfspecs = xcf.get_cf_specs(da)
    found = []
    for name in coord_names:
        found.append(cfspecs.search_coord(da, name, errors=errors))
    return found
@misc.ERRORS.format_function_docstring
def get_dims(da, dim_types, allow_positional=False, positions='tzyx', errors="warn"):
    """Get the data array dimension names from their type

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    dim_types: str, list
        Letters among "x", "y", "z", "t" and "f".
    allow_positional: bool
        Fall back to the positional dimension of types if unknown.
    positions: str
        Default position per type starting from the end.
    {errors}

    Return
    ------
    tuple
        Tuple of dimension names or None when a dimension is not found

    See also
    --------
    xoa.cf.CFSpecs.get_dims
    """
    cfspecs = xcf.get_cf_specs(da)
    return cfspecs.get_dims(
        da,
        dim_types,
        allow_positional=allow_positional,
        positions=positions,
        errors=errors,
    )
@misc.ERRORS.format_function_docstring
def get_xdim(da, errors="warn", **kwargs):
    """Get the dimension of X type

    It is a simple call to :func:`get_dims` with ``dim_types="x"``

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}
    kwargs: dict
        Extra parameters passed to :func:`get_dims`,
        like ``allow_positional`` or ``positions``

    Return
    ------
    str or None
        The dimension name or None

    See also
    --------
    get_dims
    """
    # kwargs were previously documented but silently dropped
    dims = get_dims(da, "x", errors=errors, **kwargs)
    if dims:
        return dims[0]
    return None
@misc.ERRORS.format_function_docstring
def get_ydim(da, errors="warn", **kwargs):
    """Get the dimension of Y type

    It is a simple call to :func:`get_dims` with ``dim_types="y"``

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}
    kwargs: dict
        Extra parameters passed to :func:`get_dims`,
        like ``allow_positional`` or ``positions``

    Return
    ------
    str or None
        The dimension name or None

    See also
    --------
    get_dims
    """
    # kwargs were previously documented but silently dropped
    dims = get_dims(da, "y", errors=errors, **kwargs)
    if dims:
        return dims[0]
    return None
@misc.ERRORS.format_function_docstring
def get_zdim(da, errors="warn", **kwargs):
    """Get the dimension of Z type

    It is a simple call to :func:`get_dims` with ``dim_types="z"``

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}
    kwargs: dict
        Extra parameters passed to :func:`get_dims`,
        like ``allow_positional`` or ``positions``

    Return
    ------
    str or None
        The dimension name or None

    See also
    --------
    get_dims
    """
    # kwargs were previously documented but silently dropped
    dims = get_dims(da, "z", errors=errors, **kwargs)
    if dims:
        return dims[0]
    return None
@misc.ERRORS.format_function_docstring
def get_tdim(da, errors="warn", **kwargs):
    """Get the dimension of T type

    It is a simple call to :func:`get_dims` with ``dim_types="t"``

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}
    kwargs: dict
        Extra parameters passed to :func:`get_dims`,
        like ``allow_positional`` or ``positions``

    Return
    ------
    str or None
        The dimension name or None

    See also
    --------
    get_dims
    """
    # kwargs were previously documented but silently dropped
    dims = get_dims(da, "t", errors=errors, **kwargs)
    if dims:
        return dims[0]
    return None
@misc.ERRORS.format_function_docstring
def get_fdim(da, errors="warn", **kwargs):
    """Get the dimension of F type

    It is a simple call to :func:`get_dims` with ``dim_types="f"``

    Parameters
    ----------
    da: xarray.DataArray
        Array to scan
    {errors}
    kwargs: dict
        Extra parameters passed to :func:`get_dims`,
        like ``allow_positional`` or ``positions``

    Return
    ------
    str or None
        The dimension name or None

    See also
    --------
    get_dims
    """
    # kwargs were previously documented but silently dropped
    dims = get_dims(da, "f", errors=errors, **kwargs)
    if dims:
        return dims[0]
    return None
class transpose_modes(misc.IntEnumChoices, metaclass=misc.DefaultEnumMeta):
    """Supported :func:`transpose` modes

    ``basic`` and ``xarray`` share the value of ``classic``.
    """

    #: Basic xarray transpose with :meth:`xarray.DataArray.transpose`
    classic = basic = xarray = 0

    #: Transpose skipping incompatible dimensions
    compat = -1

    #: Transpose adding missing dimensions with a size of 1
    insert = 1

    #: Transpose resizing to missing dimensions.
    #: Note that dims must be an array or a dict of sizes
    #: otherwise new dimensions will have a size of 1.
    resize = 2
def transpose(da, dims, mode='compat'):
    """Transpose an array

    Parameters
    ----------
    da: xarray.DataArray
        Array to transpose
    dims: tuple(str), xarray.DataArray, dict
        Target dimensions, an object with ``dims`` and ``sizes``
        attributes like an array, or a mapping of dimension sizes.
        It may contain an ``Ellipsis``.
    mode: str, int
        Transpose mode as one of the following:

        {transpose_modes.rst_with_links}

    Return
    ------
    xarray.DataArray
        Transposed array

    Example
    -------
    .. ipython:: python

        @suppress
        import xarray as xr, numpy as np
        @suppress
        from xoa.coords import transpose
        a = xr.DataArray(np.ones((2, 3, 4)), dims=('y', 'x', 't'))
        b = xr.DataArray(np.ones((10, 3, 2)), dims=('m', 'y', 'x'))
        # classic
        transpose(a, (Ellipsis, 'y', 'x'), mode='classic')
        # insert
        transpose(a, ('m', 'y', 'x', 'z'), mode='insert')
        transpose(a, b, mode='insert')
        # resize
        transpose(a, b, mode='resize')
        transpose(a, b.sizes, mode='resize') # with dict
        # compat mode
        transpose(a, ('y', 'x'), mode='compat').dims
        transpose(a, b.dims, mode='compat').dims
        transpose(a, b, mode='compat').dims # same as with b.dims

    See also
    --------
    xarray.DataArray.transpose
    """
    # Normalise the target into a sequence of names and optional sizes
    if hasattr(dims, 'dims'):
        sizes, dims = dims.sizes, dims.dims
    elif isinstance(dims, Mapping):
        sizes, dims = dims, list(dims.keys())
    else:
        sizes = None
    mode = str(transpose_modes[mode])
    kw = dict(transpose_coords=True)

    if mode == "classic":
        return da.transpose(*dims, **kw)

    target = []
    new_sizes = {}
    has_ellipsis = False
    for dim in dims:
        if dim is Ellipsis:
            has_ellipsis = True
        elif dim in da.dims:
            pass
        elif mode == "insert":
            new_sizes[dim] = 1
        elif mode == "resize":
            if sizes is not None and dim in sizes:
                size = sizes[dim]
            else:
                xoa_warn(f"new dim '{dim}' in transposition is set to one"
                         " since no size is provided to it")
                size = 1
            new_sizes[dim] = size
        else:
            # compat mode: unknown dimensions are skipped
            continue
        target.append(dim)

    if new_sizes:
        da = da.expand_dims(new_sizes)

    # Input dimensions absent from the target are flushed to the left
    if not has_ellipsis and set(target) < set(da.dims):
        target.insert(0, ...)

    return da.transpose(*target, **kw)


transpose.__doc__ = transpose.__doc__.format(**locals())
def get_dim_types(da, unknown=None, asdict=False):
    """Get dimension types

    The work is delegated to the coordinate specs of the CF specs
    associated with `da`.

    Parameters
    ----------
    da: xarray.DataArray or tuple(str)
        Data array or tuple of dimensions
    unknown:
        Value to assign to unknown types
    asdict: bool
        Get the result as a dictionary

    Return
    ------
    tuple or dict
    """
    coord_specs = xcf.get_cf_specs(da).coords
    return coord_specs.get_dim_types(da, unknown=unknown, asdict=asdict)
def get_order(da):
    """Like :func:`get_dim_types` but returning a string

    Unknown dimension types are represented by ``"-"``.
    """
    dim_types = get_dim_types(da, unknown="-", asdict=False)
    return "".join(dim_types)
def reorder(da, order):
    """Transpose an array to match a given order

    Parameters
    ----------
    da: xarray.DataArray
        Data array to transpose
    order: str, tuple, dict
        A combination of x, y, z, t, f and - symbols.
        Letters refer to the dimension type, and ``-`` matches a
        dimension whose type is unknown (as in :func:`get_order`).
        A tuple or dict, like the output of :func:`get_dim_types`, is
        accepted too: entries that are not one of these letters
        (for instance ``None``) are treated as ``-``.
        Symbols are matched from the end, and each dimension is matched
        at most once.

    Return
    ------
    xarray.DataArray

    Raises
    ------
    XoaError
        When a symbol matches none of the remaining dimensions.
    """
    # Convert from dim_types
    if isinstance(order, dict):
        order = tuple(order.values())
    if isinstance(order, tuple):
        # Guard with isinstance: "None in 'ftzyx'" raises a TypeError
        order = ''.join(
            o if isinstance(o, str) and len(o) == 1 and o in "ftzyx" else '-'
            for o in order)

    # From order to dims; unknown types are labelled "-" so that
    # the "-" symbol can match them
    dim_types = get_dim_types(da, unknown="-", asdict=True)
    ndim = len(dim_types)
    to_dims = ()
    for i, o in enumerate(order[::-1]):
        # The last remaining dimension is placed on the left by transpose
        if i + 1 == ndim:
            break
        for dim in da.dims:
            # Skip dims already assigned to avoid duplicates in to_dims
            if dim not in to_dims and dim_types[dim] == o:
                to_dims = (dim,) + to_dims
                break
        else:
            raise XoaError(
                f"Coordinate type not found: {o}. Dims are: {da.dims}")

    # Final transpose
    return transpose(da, to_dims)
def get_coords_compat_with_dims(da, include_dims=None, exclude_dims=None):
    """Return the coordinates that are compatible with dims

    Parameters
    ----------
    da: xarray.DataArray
        Data array
    include_dims: str, iterable(str)
        If provided and not empty, the coordinates must have at least one
        of these dimensions
    exclude_dims: str, iterable(str)
        If provided and not empty, the coordinates must not have any of
        these dimensions

    Return
    ------
    list(xarray.DataArray)
        List of coordinate arrays, in the order of ``da.coords``
    """
    # Accept a single name or any iterable of names
    if isinstance(include_dims, str):
        include_dims = {include_dims}
    elif include_dims is not None:
        include_dims = set(include_dims)
    if isinstance(exclude_dims, str):
        exclude_dims = {exclude_dims}
    elif exclude_dims is not None:
        exclude_dims = set(exclude_dims)

    coords = []
    for coord in da.coords.values():
        dims = set(coord.dims)
        if include_dims and dims.isdisjoint(include_dims):
            continue
        if exclude_dims and not dims.isdisjoint(exclude_dims):
            continue
        coords.append(coord)
    return coords
def change_index(da, dim, values):
    """Change the values of a dataset or data array index

    The attributes of the current ``dim`` coordinate, if any, are kept and
    updated with those of `values` when it has an ``attrs`` attribute.
    The input object and its coordinate attributes are not modified.

    Parameters
    ----------
    da: xarray.Dataset, xarray.DataArray
    dim: str
        Name of the dimension whose coordinate is replaced
    values: array_like
        New coordinate values

    Return
    ------
    xarray.Dataset, xarray.DataArray

    See also
    --------
    xarray.DataArray.reset_index
    xarray.DataArray.assign_coords
    """
    # Copy to avoid mutating the attributes of the input coordinate
    attrs = dict(da.coords[dim].attrs) if dim in da.coords else {}
    if hasattr(values, "attrs"):
        attrs.update(values.attrs)
    if dim in da.indexes:
        da = da.reset_index([dim], drop=True)
    coord = xr.DataArray(values, dims=dim, attrs=attrs)
    return da.assign_coords({dim: coord})
def drop_dim_coords(da, dim):
    """Drop coords that have a particular dim

    Parameters
    ----------
    da: xarray.Dataset, xarray.DataArray
    dim: str
        Coordinates depending on this dimension are dropped

    Return
    ------
    xarray.Dataset, xarray.DataArray
    """
    names = [name for name, coord in da.coords.items() if dim in coord.dims]
    # drop_vars replaces the deprecated drop method
    return da.drop_vars(names)
class positive_attr(misc.IntEnumChoices, metaclass=misc.DefaultEnumMeta):
    """Allowed value for the positive attribute argument

    ``guess`` shares the value of ``infer``.
    """

    #: Infer it from the axis coordinate
    infer = guess = 0

    #: Coordinates are increasing up
    up = 1

    #: Coordinates are increasing down
    down = -1
def get_positive_attr(da, zdim=None):
    """Get the positive attribute of a dataset or data array

    The ``positive`` attribute is searched on the `zdim` coordinate when it
    exists. Otherwise it is searched on all coordinates and, for a dataset,
    on all data variables. The first value found is normalised with
    :class:`positive_attr`. When none is found, the ``[vertical] positive``
    entry of the current CF specs is returned.

    Parameters
    ----------
    da: xarray.Dataset, xarray.DataArray
    zdim: None, str
        The index coordinate name that is supposed to have this attribute,
        which is usually the vertical dimension.
        When None, it is looked up with :func:`get_dims`.

    Return
    ------
    None, "up" or "down"
    """
    # Targets
    if zdim is None:
        zdim = get_dims(da, "z", errors="ignore")
        if zdim:
            zdim = zdim[0]
    if zdim and zdim in da.coords:
        targets = [da.coords[zdim]]
    else:
        targets = list(da.coords.values())
        if isinstance(da, xr.Dataset):
            targets.extend(da.data_vars.values())

    # Loop on targets
    for target in targets:
        if "positive" in target.attrs:
            # Read the value from the target that actually carries it
            positive = target.attrs["positive"]
            return positive_attr[positive].name

    # Fall back to current CFSpecs
    cfspecs = xcf.get_cf_specs(da)
    return cfspecs["vertical"]["positive"]