# -*- coding: utf8 -*-
"""
Kriging adapted from vacumm's module
https://github.com/VACUMM/vacumm/blob/master/lib/python/vacumm/misc/grid/kriging.py
"""
from __future__ import absolute_import
from multiprocessing import Pool, cpu_count
import warnings
import numpy as np
import xarray as xr
from . import exceptions
from . import misc
from . import geo as xgeo
from . import coords as xcoords
def _get_blas_func_(name):
import scipy.linalg.blas
return scipy.linalg.blas.get_blas_funcs(name)
def _get_lapack_func_(name):
import scipy.linalg.lapack
return scipy.linalg.lapack.get_lapack_funcs(name)
def _dgemv_(a, x):
    """Matrix-vector product of `a` and `x` computed by the BLAS ``gemv`` routine"""
    # alpha is fixed to 1 so that the result is the plain product
    return _get_blas_func_('gemv')(1.0, a, x)
def _symm_(a, b):
    """Matrix-matrix product of `a` and `b`

    Despite its name, this helper calls the general BLAS ``gemm`` routine,
    not the symmetric ``symm`` one.
    """
    # alpha is fixed to 1 so that the result is the plain product
    return _get_blas_func_('gemm')(1.0, a, b)
def _syminv_(a):
"""Invert a real symmetric definite matrix"""
return np.linalg.pinv(a)
n = a.shape[0]
jj, ii = np.triu_indices(n)
up = a[jj, ii]
pptri = _get_lapack_func_("pptri")
res = pptri(n, up)
if isinstance(res, tuple):
info = res[1]
if info:
raise exceptions.KrigingError(f'Error during call to Lapack DPPTRI (info={info})')
return res[0]
else:
return res
[docs]
def get_xyz(obj):
    """Extract flat lon/lat positions and data values from a data array or dataset

    The object is first stacked along a single "npts" dimension.

    Parameters
    ----------
    obj: xarray.DataArray, xarray.Dataset
        If a data array, it must have valid longitude and latitude coordinates.
        If a dataset, it must have a single variable as in the data array case.

    Return
    ------
    numpy.array
        Longitudes as 1D array
    numpy.array
        Latitudes as 1D array
    numpy.array
        Values reshaped as a 2D array whose last dimension has the size of
        the longitudes. None if `obj` is not a data array.
    """
    # Xarray side: stack, then find the names of the geographic coordinates
    stacked = xcoords.geo_stack(obj, "npts")
    lon_name = xcoords.get_lon(stacked).name
    lat_name = xcoords.get_lat(stacked).name

    # Numpy side
    x = stacked.coords[lon_name].values
    y = stacked.coords[lat_name].values
    z = stacked.values.reshape(-1, x.size) if isinstance(stacked, xr.DataArray) else None
    return x, y, z
[docs]
def empirical_variogram(
    da: xr.DataArray,
    nbin=30,
    nbin0=10,
    nmax=1500,
    dist_units="m",
    distmax=None,
    errfunc=None,
):
    """Compute the semi-variogram from data

    Parameters
    ----------
    da: xarray.DataArray
        Data array with lon and lat coordinates.
    nbin: int, optional
        Number of distance bins. Bin edges are evenly spaced in terms of
        rank of the sorted unique distances. It is capped to the number of
        unique distances.
    nbin0: int, None, optional
        If set to a number > 1 and lower than the number of distances
        in the first bin, the first bin is split into `nbin0` sub-bins.
        Set it to ``None`` or a number <= 1 to keep the first bin as is.
    nmax: optional
        Above this number, size of the sample is reduced by a crude undersampling.
    dist_units: str, int, xoa.geo.distance_units
        Distance units as one of: {xgeo.distance_units.rst_with_links}
    distmax: optional
        Max distance to consider. It is compared to distances
        after their conversion to `dist_units`.
    errfunc: optional
        Callable function to compute "errors" like square
        root difference between to z values. It take two arguments and
        defaults to :math:`(z1-z0)^2/2`.

    Return
    ------
    xarray.DataArray
        Values as 1D array with name "semivariogram" and with "dist" as distance coordinate
        whose "units" attribute is set from `dist_units`
    """
    da = xcoords.geo_stack(da, "npts", rename=True, reset_index=True)
    npts = da.sizes["npts"]

    # Undersample?
    if npts > nmax:
        # A slice step must be an integer: round up so that at most nmax points are kept
        samp = int(np.ceil(npts / nmax))
        da = da.isel(npts=slice(None, None, samp))

    # Distances between all pairs of points (strict upper triangle only)
    dist_units = xgeo.distance_units[dist_units]
    dd = xgeo.get_distances(da).values
    if dist_units == xgeo.distance_units.km:
        dd *= 1e-3
    iitriu = np.triu_indices(dd.shape[0], 1)
    d = dd[iitriu]
    del dd

    # Max distance
    if distmax:
        iiclose = d <= distmax
        d = d[iiclose]
    else:
        iiclose = ...  # no-op selection

    # Variogram
    if errfunc is None:

        def errfunc(a0, a1):
            return 0.5 * (a1 - a0) ** 2

    z = np.atleast_2d(da.values)
    v = np.asarray([errfunc(*np.meshgrid(zi, zi))[iitriu][iiclose] for zi in z])

    # Unique (and sorted) distances
    d, iiuni = np.unique(d, return_index=True)
    v = v[:, iiuni]

    # Compute edges
    # - classic bins
    nbin = min(d.size, nbin)
    iiedges = np.linspace(0, d.size - 1, nbin + 1).astype('l').tolist()
    # - more details in the first bin
    if nbin0 and 1 < nbin0 < iiedges[1]:  # split first bin
        iiedges = np.linspace(0.0, iiedges[1], nbin0 + 1).astype('l')[:-1].tolist() + iiedges[1:]
        nbin = nbin - 1 + nbin0  # len(iiedges)-1

    # Compute histogram
    db = np.empty(nbin)
    vb = np.empty(nbin)
    for ib in range(nbin):
        iib = slice(iiedges[ib], iiedges[ib + 1] + 1)
        db[ib] = d[iib].mean()
        vb[ib] = v[:, iib].mean()

    # Dataarray
    # Attributes are accumulated: setting the units must not drop the long name
    attrs = {"long_name": "Semi-variogram"}
    if "long_name" in da.attrs:
        attrs["long_name"] = "Semi-variogram of " + da.attrs["long_name"]
    if "units" in da.attrs:  # TODO: pint support
        attrs["units"] = da.attrs["units"] + "^2"
    return xr.DataArray(
        vb,
        dims="dist",
        coords={"dist": ("dist", db, {"long_name": "Distance", "units": str(dist_units)})},
        attrs=attrs,
        name="semivariogram",
    )


# Insert the list of valid distance units into the docstring
# (docstrings are None when python is run with -OO)
if empirical_variogram.__doc__:
    empirical_variogram.__doc__ = empirical_variogram.__doc__.format(xgeo=xgeo)
[docs]
class variogram_model_types(misc.IntEnumChoices, metaclass=misc.DefaultEnumMeta):
    """Supported types of variograms"""

    # NOTE(review): presumably the default member is the first one declared,
    # which would explain why ``exponential`` comes before ``linear`` despite
    # its higher value -- confirm against misc.DefaultEnumMeta before reordering.

    #: Exponential (default)
    exponential = 1
    #: Linear
    linear = 0
    #: Gaussian
    gaussian = 2
    #: Spherical
    spherical = 3
[docs]
def get_variogram_model_func(mtype, n, s, r, nrelmax=0.2):
    """Get the variogram model function from its name

    Parameters
    ----------
    mtype: int, str, variogram_model_types
        Type of variogram model.
    n: float
        Nugget. It is clipped to the ``[0, nrelmax * s]`` interval.
    s: float
        Sill. Negative values are clipped to zero.
    r: float
        Range. Negative values are clipped to zero.
    nrelmax: float
        Maximal value of the nugget relative to the sill.

    Return
    ------
    callable
        Function that takes distances as its single argument and
        returns the variogram model values.
    """
    # NOTE: VariogramModel builds its list of parameter names from the
    # ``__code__.co_varnames`` of this function: do not introduce local
    # variables or nested named functions here.
    mtype = variogram_model_types[mtype]
    # Clip the sill first so that the nugget cap is computed from a valid sill
    s = max(s, 0)
    r = max(r, 0)
    n = min(max(n, 0), nrelmax * s)
    if mtype.name == 'linear':
        return lambda h: n + (s - n) * ((h / r) * (h <= r) + 1 * (h > r))
    if mtype.name == 'exponential':
        return lambda h: n + (s - n) * (1 - np.exp(-3 * h / r))
    if mtype.name == 'gaussian':
        return lambda h: n + (s - n) * (1 - np.exp(-3 * h**2 / r**2))
    if mtype.name == 'spherical':
        return lambda h: n + (s - n) * ((1.5 * h / r - 0.5 * (h / r) ** 3) * (h <= r) + 1 * (h > r))
[docs]
class VariogramModel(object):
    """Class used when fitting a variogram model to data to better control params

    Parameters
    ----------
    mtype: int, str, variogram_model_types
        Type of variogram model.
    dist_units: int, str, xoa.geo.distance_units
        Units of the distances passed to the model.
    **frozen_params:
        Variogram parameters that must be frozen.
    """

    # Names of the variogram parameters: the arguments of
    # get_variogram_model_func, except the leading model type and ``nrelmax``.
    # Only the arguments are read (up to ``co_argcount``) since ``co_varnames``
    # also lists the local variables of the function.
    param_names = list(
        get_variogram_model_func.__code__.co_varnames[
            1 : get_variogram_model_func.__code__.co_argcount
        ]
    )
    param_names.remove('nrelmax')

    def __init__(self, mtype, dist_units="m", **frozen_params):
        self._dist_units = xgeo.distance_units[dist_units]
        self.mtype = variogram_model_types[mtype]
        self._frozen_params = {}
        self._estimated_params = {}
        self._fit = None
        self._fit_err = None
        self.set_params(**frozen_params)
        self._ev = None  # last empirical variogram used by fit()

    def __str__(self):
        clsname = self.__class__.__name__
        mtype = self.mtype.name
        dist_units = self._dist_units
        sp = ', '.join(f"{name}={self[name]}" for name in self.param_names)
        return f"<{clsname}('{mtype}', dist_units='{dist_units}', {sp})>"

    def __repr__(self):
        return str(self)

    @property
    def dist_units(self):
        """Distance units of type :class:`~xoa.geo.distance_units`"""
        return self._dist_units

    @property
    def frozen_params(self):
        """Frozen parameters"""
        return {
            name: self._frozen_params[name]
            for name in self.param_names
            if name in self._frozen_params
        }

    def get_estimated_params(self):
        """Get parameters that were estimated (not frozen)

        Return
        ------
        dict
            Values are None for parameters that have not been estimated yet.
        """
        return {
            name: self._estimated_params.get(name)
            for name in self.param_names
            if name not in self._frozen_params
        }

    def set_estimated_params(self, overwrite=True, **params):
        """Set the value of non-frozen parameters

        Parameters
        ----------
        overwrite: bool
            If False, parameters that already have an estimated value are left untouched.
        **params:
            Parameter values. Frozen and unknown parameters are ignored.
        """
        for name in self.param_names:
            if name in self._frozen_params or name not in params:
                continue
            if overwrite or name not in self._estimated_params:
                self._estimated_params[name] = params[name]

    def _set_estimated_params_from_dict(self, params):
        """Property setter: ``model.estimated_params = dict(...)``"""
        self.set_estimated_params(**params)

    estimated_params = property(
        get_estimated_params,
        _set_estimated_params_from_dict,
        doc='Estimated parameters as :class:`dict`',
    )

    def set_params(self, **params):
        """Freeze some parameters

        Unknown parameters and None values are ignored.
        """
        self._frozen_params.update(
            {p: v for (p, v) in params.items() if p in self.param_names and v is not None}
        )

    def get_params(self, **params):
        """Get current parameters with optional update

        Parameters
        ----------
        params:
            Extra parameters to alter currents values.
            Unknown parameters and None values are ignored.

        Return
        ------
        dict
            All parameters, in the order of :attr:`param_names`.
        """
        these_params = {**self.frozen_params, **self.estimated_params}
        these_params.update(
            {p: v for (p, v) in params.items() if p in self.param_names and v is not None}
        )
        return {name: these_params[name] for name in self.param_names}

    def get_param(self, name):
        """Get a single parameter

        Parameters
        ----------
        name: str
            A valid parameter name

        Return
        ------
        float, None
            Returns None if the parameter is not frozen and has not been estimated yet.
        """
        if name not in self.param_names:
            raise exceptions.KrigingError(
                f"Invalid param name: {name}. Please use one of: " + ", ".join(self.param_names)
            )
        if name in self._frozen_params:
            return self._frozen_params[name]
        return self._estimated_params.get(name)

    __getitem__ = get_param

    def get_params_array(self):
        """Get the :attr:`estimated_params` as an array

        Return
        ------
        numpy.array

        Raise
        -----
        KrigingError
            When at least one of the non-frozen parameters has no value yet.
        """
        pp = list(self.estimated_params.values())
        if any(p is None for p in pp):
            raise exceptions.KrigingError(
                "Not all parameters are estimated: {}".format(self.estimated_params)
            )
        return np.array(pp)

    def set_params_array(self, pp):
        """Set the :attr:`estimated_param` with an array

        Parameters
        ----------
        pp: numpy.array
            Array of estimated parameters, in the order of :attr:`estimated_params`

        Return
        ------
        dict
            All the current parameters
        """
        for i, name in enumerate(self.estimated_params):
            self._estimated_params[name] = pp[i]
        return self.params

    @property
    def params(self):
        """Current variogram model parameters"""
        return self.get_params()

    def apply(self, d, pp=None):
        """Call the variogram model function

        Parameters
        ----------
        d: array
            Distances
        pp: array, None
            If not None, the estimated parameters are first set to these values.
        """
        return self.get_func(pp)(d)

    __call__ = apply

    def get_func(self, pp=None):
        """Get the variogram model function using `pp` variable arguments

        When `pp` is provided, it is stored as the new estimated parameters.
        """
        if pp is not None:
            params = self.set_params_array(pp)
        else:
            params = self.get_params()
        if any(value is None for value in params.values()):
            raise exceptions.KrigingError(
                "Not all parameters are estimated: {}".format(self.estimated_params)
            )
        return get_variogram_model_func(self.mtype, **params)

    def fit(self, da: xr.DataArray, **kwargs):
        """Estimate parameters from data

        Parameters
        ----------
        da: xarray.DataArray, xarray.Dataset
            Either data with geographic coordinates, or an empirical variogram
            named "semivariogram" or "variogram" with a "dist" coordinate.
            With a dataset, the first variable is used.
        kwargs: dict
            Extra keywords are passed to :func:`empirical_variogram`
        """
        # We need a data array
        if isinstance(da, xr.Dataset):
            if len(da) == 0:
                raise exceptions.KrigingError(
                    "No variable found in the dataset for estimating the variogram parameters."
                )
            if len(da) > 1:
                exceptions.xoa_warn(
                    "Multiple candidate variables found in the dataset for estimating "
                    "the variogram parameters. Keeping only the first one."
                )
            da = da[list(da)[0]]

        # Empirical variogram
        if da.name in ("semivariogram", "variogram") and "dist" in da.coords:
            if (
                "units" in da.dist.attrs
                and da.dist.attrs["units"] in xgeo.distance_units
                and xgeo.distance_units[da.dist.attrs["units"]] != self._dist_units
            ):
                exceptions.xoa_warn(
                    "Incompatible distance units: {} vs {}".format(
                        self._dist_units, da.dist.attrs["units"]
                    )
                )
            ev = da
        else:
            kwargs["dist_units"] = self._dist_units
            ev = empirical_variogram(da, **kwargs)
        dist = ev.dist.values
        values = ev.values
        self._ev = ev

        # First guess of parameters, only for those that have no value yet
        imax = np.ma.argmax(values)
        self.set_estimated_params(n=0.0, s=values[imax], r=dist[imax], overwrite=False)
        pp0 = self.get_params_array()

        # Fitting
        from scipy.optimize import minimize

        def func(pp):
            return ((values - self(dist, pp)) ** 2).sum()

        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', 'divide by zero encountered in divide')
            self._fit = minimize(
                func, pp0, bounds=[(np.finfo('d').eps, None)] * len(pp0), method='L-BFGS-B'
            )
        pp = self._fit['x']
        self._fit_err = np.sqrt(func(pp)) / values.size
        # func() stores the last evaluated parameters: make sure the solution is kept
        self.set_params_array(pp)

    def plot(self, rmax=None, nr=100, show_params=True, **kwargs):
        """Plot the semivariogram

        Parameters
        ----------
        rmax: float
            Max distance of the curve, in the distance units of the model.
            It defaults to the max distance of the last fitted empirical variogram,
            or to 1.2 times the range if no fit was performed.
        nr: int
            Number of points to plot the curve
        show_params: bool, dict
            Show a text box that contains the variogram parameters in the lower right corner.
        kwargs: dict
            Extra keyword are passed to the `xarray.DataArray.plot` callable accessor

        Return
        ------
        The output of the `xarray.DataArray.plot` call
        """
        # Distances: an explicit rmax always takes precedence
        if rmax is None:
            if self._ev is not None:
                rmax = float(self._ev.dist.max())
            else:
                rmax = self["r"] * 1.2
        dist = np.linspace(0, rmax, nr)

        # Array and plot
        du = str(self._dist_units)
        mv = xr.DataArray(
            self.get_func()(dist),
            dims='dist',
            coords=[("dist", dist, {"long_name": "Distance", "units": du})],
            attrs={"long_name": self.mtype.name.title() + " fit"},
        )
        kwargs.setdefault("label", mv.long_name)
        p = mv.plot(**kwargs)

        # Text box for params
        if show_params:
            params = self.params
            text = [
                "r[ange] = {:<g} {}".format(params["r"], du),
                "n[ugget] = {:<g}".format(params["n"]),
                "s[ill] = {:<g}".format(params["s"]),
            ]
            maxlen = max(len(t) for t in text)
            text = "\n".join(t.ljust(maxlen) for t in text)
            axes = p[0].axes
            axes.text(
                0.98,
                0.04,
                text,
                transform=axes.transAxes,
                family="monospace",
                bbox=dict(facecolor=(1, 1, 1, 0.5)),
                ha="right",
            )
        return p
[docs]
class kriging_types(misc.IntEnumChoices, metaclass=misc.DefaultEnumMeta):
    """Supported kriging types"""

    # NOTE(review): presumably the default member is the first one declared,
    # which would explain why ``ordinary`` comes before ``simple`` despite
    # its higher value -- confirm against misc.DefaultEnumMeta before reordering.

    #: Ordinary kriging (default)
    ordinary = 1
    #: Simple kriging
    simple = 0
[docs]
class Kriger(object):
    """Kriger that supports clusterization to limit memory

    Big input cloud of points (size > ``npmax``)
    are split into smaller clusters using cluster analysis of distance with
    function :func:`~xoa.geo.clusterize`.

    The problem is solved in this way:

    #. Input points are split in clusters if necessary.
    #. The input variogram matrix is inverted
       for each cluster, possibly using
       :mod:`multiprocessing` if ``nproc>1``.
    #. Values are computed at output positions
       each using the inverted matrix of cluster.
    #. Final value is a weighted average of
       the values estimated using each cluster.
       Weights are inversely proportional to the squared error.

    Parameters
    ----------
    da: xarray.DataArray, xarray.Dataset
        Input positions and optionally data.
    krigtype: optional
        Kriging type: {kriging_types.rst_with_links}.
    variogram_func: callable, str, VariogramModel
        Callable to be used as a variogram function.
        It is either a function, an instance of :class:`VariogramModel`
        or a variogram model type as a string.
        A :class:`VariogramModel` is always fitted to `da`.
    npmax: optional
        Maximal number of points to be used simultaneously for kriging.
        When the number of input points is greater than this value,
        clusterization is applied.
    nproc: optional
        Number of processes to use to invert matrices.
        Set it to a number <2 to switch off parallelisation.
        It defaults to the number of cpus and is capped to the number of clusters.
    exact: optional
        If True, variogram is exactly zero when distance is zero.
    dist_units: int, str, xoa.geo.distance_units, optional
        Distance units. It defaults to those of `variogram_func`
        when it is a :class:`VariogramModel`.
    mean: float, optional
        Mean value removed before simple kriging and added back afterwards.
        It defaults to the mean of `da`. It is forced to zero for the other kriging types.
    farvalue: optional
        Currently not used.
    kwargs: dict
        Currently not used.
    """

    def __init__(
        self,
        da,
        krigtype,
        variogram_func,
        npmax=None,
        nproc=None,
        exact=False,
        dist_units=None,
        mean=None,
        farvalue=None,
        **kwargs,
    ):
        # Kriging type
        self.krigtype = kriging_types[krigtype]

        # Variogram function
        if isinstance(variogram_func, str):
            variogram_func = VariogramModel(variogram_func, dist_units=dist_units)
        if isinstance(variogram_func, VariogramModel):
            if dist_units is None:
                dist_units = variogram_func.dist_units
            else:
                dist_units = xgeo.distance_units[dist_units]
                if dist_units != variogram_func.dist_units:
                    exceptions.xoa_warn(
                        f"Incompatible distance units: {dist_units} vs {variogram_func.dist_units}"
                    )
            variogram_func.fit(da)
        self._variogram_func = variogram_func
        self._dist_units = xgeo.distance_units[dist_units]

        # Clusters
        if npmax is None:
            npmax = np.inf
        self._clusters = xgeo.clusterize(da, npmax=npmax, split=True)
        # Coordinates that do not depend on the position are kept for the output
        self._unstacked_coords = {
            cname: cdat
            for cname, cdat in self.clusters[0].coords.items()
            if "npts" not in cdat.dims
        }

        # Number of cores for parallel computing
        if nproc is None:
            nproc = cpu_count()
        else:
            nproc = max(1, min(cpu_count(), nproc))
        self.nproc = min(nproc, self.nclust)

        # Other parameters
        self.exact = exact
        if self.krigtype != kriging_types.simple:
            mean = 0.0
        elif mean is None:
            mean = float(da.mean())
        self.mean = mean

    @property
    def clusters(self):
        """List of clusters of input points as returned by :func:`~xoa.geo.clusterize`"""
        return self._clusters

    @property
    def nclust(self):
        """Number of clusters"""
        return len(self.clusters)

    @property
    def npmax(self):
        """Max number of points per cluster"""
        return self.clusters[0].npmax

    @property
    def dist_units(self):
        """Distance units"""
        return self._dist_units

    @property
    def variogram_func(self):
        """Variogram function or callable, like a :class:`VariogramModel` instance"""
        return self._variogram_func

    @property
    def Ainv(self):
        """Get the inverse of the A matrix

        One matrix per cluster, as a list of Fortran-ordered arrays.
        The result is computed once and cached.
        """
        # Already computed
        if hasattr(self, '_Ainv'):
            return self._Ainv

        # Variogram function
        vgf = self.variogram_func

        # Loop on clusters
        Ainv = []
        AA = []
        ordinary = self.krigtype == kriging_types.ordinary
        plus1 = int(ordinary)
        for clust in self.clusters:

            # Get distance between input points, in the same units as those
            # used for the distances to the output points in interp()
            dd = xgeo.get_distances(clust, units=self._dist_units).values

            # Form A
            npts = clust.sizes["npts"]
            A = np.empty((npts + plus1, npts + plus1))
            A[:npts, :npts] = vgf(dd)
            if self.exact:
                np.fill_diagonal(A, 0)
                A[:npts, :npts][np.isclose(A[:npts, :npts], 0.0)] = 0.0
            if ordinary:
                # Unbiasedness constraint of ordinary kriging
                A[-1] = 1
                A[:, -1] = 1
                A[-1, -1] = 0

            # Invert right now for single processor, so that A can be freed
            if self.nproc == 1:
                Ainv.append(_syminv_(A))
            else:
                AA.append(A)

        # Multiprocessing inversion
        if self.nproc > 1:
            # The context manager releases the workers even if an inversion fails
            with Pool(self.nproc) as pool:
                Ainv = pool.map(_syminv_, AA, chunksize=1)

        # Fortran arrays
        Ainv = [np.asfortranarray(ainv, 'd') for ainv in Ainv]
        self._Ainv = Ainv
        return Ainv

    def interp(self, dso: xr.Dataset, block=None, name=None):
        """Interpolate to the `dso` positions

        Parameters
        ----------
        dso: xr.Dataset
            Dataset that contains lon and lat coordinates
        block: None, float
            Block kriging: search radius used to find the neighbours of each output point.
            It is compared to distances computed from the raw output coordinate values.
            The variogram values of neighbours are averaged.
        name: str, None:
            Default name of the output data array.

        Return
        ------
        xarray.Dataset
            A dataset that contains the interpolated values and associated errors
        """
        # Inits
        dso = xcoords.geo_stack(dso, rename=False, drop=True)
        sname, xname, yname = dso.encoding["geo_stack"]
        stacked_coords = {
            cname: cdat for cname, cdat in dso.coords.items() if sname in cdat.dims
        }
        vgf = self.variogram_func
        nptso = dso.sizes[sname]
        vname = self.clusters[0].encoding["clust_var_names"][0]
        so = self.clusters[0][vname].shape[:-1] + (nptso,)
        dimso = self.clusters[0][vname].dims[:-1] + (sname,)
        dimsoe = (sname,)
        zo = np.zeros(so, 'd')
        wo = np.zeros(nptso, 'd')
        if block:
            # Neighbours only depend on output positions: search them once
            from scipy.spatial import cKDTree

            xyo = np.dstack([dso[xname].values, dso[yname].values])[0]
            tree = cKDTree(xyo)
            neighbours = tree.query_ball_tree(tree, block)

        # Loop on clusters
        Ainv = self.Ainv
        ordinary = self.krigtype == kriging_types.ordinary
        simple = self.krigtype == kriging_types.simple
        plus1 = int(ordinary)
        for ic, clust in enumerate(self.clusters):  # TODO: multiproc here?

            # Distances to output points
            dd = xgeo.get_distances(clust, dso, units=self._dist_units).values

            # Form B
            npts = clust.sizes["npts"]
            B = np.empty((npts + plus1, nptso))
            B[:npts] = vgf(dd)
            if ordinary:
                B[-1] = 1
            if self.exact:
                B[:npts][np.isclose(B[:npts], 0.0)] = 0.0
            del dd

            # Block kriging
            if block:
                Bb = B.copy()
                for i, iineigh in enumerate(neighbours):
                    # Average over neighbours only, separately for each row of B
                    Bb[:, i] = B[:, iineigh].mean(axis=1)
                B = Bb

            # Compute weights
            W = np.ascontiguousarray(_symm_(Ainv[ic], np.asfortranarray(B, 'd')))

            # Interpolate
            zc = clust[vname].values
            if simple:
                zc = zc - self.mean
            z = zc @ W[:npts]
            if simple:
                z += self.mean

            # Get error
            e = (W * B).sum(axis=0)
            del W, B

            # Weighted contribution based on errors
            w = 1 / e**2
            if self.nclust > 1:
                z *= w
            wo += w
            del w
            zo += z
            del z

        # Error
        eo = 1 / np.sqrt(wo)

        # Normalization
        if self.nclust > 1:
            zo /= wo

        # Format
        coords = self._unstacked_coords.copy()
        coords.update(stacked_coords)
        dao = xr.DataArray(zo, dims=dimso, coords=coords, attrs=self.clusters[0].attrs)
        daoe = xr.DataArray(
            eo, dims=dimsoe, coords=stacked_coords, attrs={"long_name": "Squared error"}
        )
        if name is None:
            name = vname
        if name is None:
            name = "data"
        return xr.Dataset({name: dao, name + "_error": daoe}).unstack()

    __call__ = interp


# Insert the list of valid kriging types into the docstring
# (docstrings are None when python is run with -OO)
if Kriger.__doc__:
    Kriger.__doc__ = Kriger.__doc__.format(kriging_types=kriging_types)
[docs]
def krig(dsi, dso, krigtype="ordinary", **kwargs):
    """Quickly krig data

    The "block" and "name" keywords are passed to :meth:`Kriger.interp`,
    and all the others to the :class:`Kriger` initialisation.
    """
    # Split the keywords between the interpolation call and the kriger setup
    kwinterp = {key: kwargs.pop(key) for key in ("block", "name") if key in kwargs}
    kriger = Kriger(dsi, krigtype, **kwargs)
    return kriger.interp(dso, **kwinterp)