Source code for xarrayutils.plotting

from tkinter import Y
import numpy as np
import xarray as xr
import warnings

# import mpl and change the backend before other mpl imports
try:
    import matplotlib as mpl
    from matplotlib.transforms import blended_transform_factory

    mpl.use("Agg")
    import matplotlib.pyplot as plt

    mpl = True
except ImportError:
    raise RuntimeError(
        "The `plotting` module requires `matplotlib`. Install using conda install -c conda-forge matplotlib "
    )

try:
    import gsw
except:
    gsw = None

import string

try:
    import cartopy
except ImportError:
    cartopy = None


[docs]def xr_violinplot(ds, ax=None, x_dim="xt_ocean", width=1, color="0.5"): """Wrapper of matplotlib violinplot for xarray.DataArray. Parameters ---------- ds : xr.DataArray Input data. ax : matplotlib.axis Plotting axis (the default is None). x_dim : str dimension that defines the x-axis of the plot (the default is 'xt_ocean'). width : float Scaling width of each violin (the default is 1). color : type Color of the violin (the default is '0.5'). Returns ------- type Description of returned object. """ x = ds[x_dim].data.copy() y = [ds.loc[{x_dim: xx}].data for xx in x] y = [data[~np.isnan(data)] for data in y] # check if all are nan idx = [len(dat) == 0 for dat in y] x = [xx for xx, ii in zip(x, idx) if not ii] y = [yy for yy, ii in zip(y, idx) if not ii] if ax is None: ax = plt.gca() vp = ax.violinplot( y, x, widths=width, showextrema=False, showmedians=False, showmeans=True ) [item.set_facecolor(color) for item in vp["bodies"]] for item in ["cmaxes", "cmins", "cbars", "cmedians", "cmeans"]: if item in vp.keys(): vp[item].set_edgecolor(color) return vp
[docs]def axis_arrow(ax, x_loc, text, arrowprops={}, **kwargs): """Puts an arrow pointing at `x_loc` onto (but outside of ) the xaxis of a plot.For now only works on xaxis and on the top. Modify when necessary Parameters ---------- ax : matplotlib.axis axis to plot on. x_loc : type Position of the arrow (in units of `ax` x-axis). text : str Text next to arrow. arrowprops: dict Additional arguments to pass to arrowprops. See mpl.axes.annotate for details. kwargs: additional keyword arguments passed to ax.annotate """ ar_props = dict(dict(fc="k", lw=1.5, ec=None)) ar_props.update(arrowprops) tform = blended_transform_factory(ax.transData, ax.transAxes) ax.annotate( text, xy=[x_loc, 1], xytext=(x_loc, 1.25), xycoords=tform, textcoords=tform, ha="center", va="center", arrowprops=ar_props, **kwargs, )
[docs]def letter_subplots(axes, start_idx=0, box_color=None, labels=None, **kwargs): """Adds panel letters in boxes to each element of `axes` in the upper left corner. Parameters ---------- axes : list, array_like List or array of matplotlib axes objects. start_idx : type Starting index in the alphabet (e.g. 0 is 'a'). box_color : type Color of the box behind each letter (the default is None). labels: list List of strings used as labels (if None (default), uses lowercase alphabet followed by uppercase alphabet) **kwargs : type kwargs passed to matplotlib.axis.text """ if labels is None: labels = list(string.ascii_letters) for ax, letter in zip(axes.flat, labels[start_idx:]): t = ax.text( 0.1, 0.85, letter + ")", horizontalalignment="center", verticalalignment="center", transform=ax.transAxes, **kwargs, ) if box_color: t.set_bbox(dict(facecolor=box_color, alpha=0.5, edgecolor=None))
[docs]def map_util_plot( ax, land_color="0.7", coast_color="0.3", lake_alpha=0.5, labels=False ): """Helper tool to add good default map to cartopy axes. Parameters ---------- ax : cartopy.geoaxes (not sure this is right) The axis to plot on (must be a cartopy axis). land_color : type Color of land fill (the default is '0.7'). coast_color : type Color of costline (the default is '0.3'). lake_alpha : type Transparency of lakes (the default is 0.5). labels : type Not implemented. """ if cartopy is None: raise RuntimeError( "Mapping functions require `cartopy`. Install using conda install -c conda-forge cartopy " ) # I could default to plt.gca() for ax, but does it work when I just pass # the axis object as positonal argument? ax.add_feature(cartopy.feature.LAND, color=land_color) ax.add_feature(cartopy.feature.COASTLINE, edgecolor=coast_color) ax.add_feature(cartopy.feature.LAKES, alpha=lake_alpha)
# add option for gridlines and labelling
[docs]def same_y_range(axes): """Adjusts multiple axes so that the range of y values is the same everywhere, but not the actual values. Parameters ---------- axes : np.array An array of matplotlib.axes objects produced by e.g. plt.subplots() """ ylims = [ax.get_ylim() for ax in axes.flat] yranges = [lim[1] - lim[0] for lim in ylims] # find the max range yrange_max = np.max(yranges) # determine the difference from max range for other ranges y_range_missing = [yrange_max - rang for rang in yranges] # define new ylims by expanding with (missing range / 2) at each end y_lims_new = [ np.array(lim) + np.array([-1, 1]) * yrm / 2 for lim, yrm in zip(ylims, y_range_missing) ] for ax, lim in zip(axes.flat, y_lims_new): ax.set_ylim(lim)
[docs]def center_lim(ax, which="y"): if which == "y": lim = np.array(ax.get_ylim()) ax.set_ylim(np.array([-1, 1]) * abs(lim).max()) elif which == "x": lim = np.array(ax.get_xlim()) ax.set_xlim(np.array([-1, 1]) * abs(lim).max()) elif which in ["xy", "yx"]: center_lim(ax, "x") center_lim(ax, "y") else: raise ValueError("`which` is not in (`x,`y`, `xy`) found %s" % which)
[docs]def depth_logscale(ax, yscale=400, ticks=None): if ticks is None: ticks = [0, 100, 250, 500, 1000, 2500, 5000] ax.set_yscale("symlog", linthreshy=yscale) ticklabels = [str(a) for a in ticks] ax.set_yticks(ticks) ax.set_yticklabels(ticklabels) ax.invert_yaxis()
[docs]def shaded_line_plot( da, dim, ax=None, horizontal=True, spreads=None, alphas=[0.25, 0.4], spread_style="std", line_kwargs=dict(), fill_kwargs=dict(), **kwargs, ): """Produces a line plot with shaded intervals based on the spread of `da` in `dim`. Parameters ---------- da : xr.DataArray The input data. Needs to be 2 dimensional, so that when `dim` is reduced, it is a line plot. dim : str Dimension of `da` which is used to calculate spread ax : matplotlib.axes Matplotlib axes object to plot on (the default is plt.gca()). horizontal : bool Determines if the plot is horizontal or vertical (e.g. x is plotted on the y-axis). spread : np.array, optional Values specifying the 'spread-values', dependent on `spread_style`. Defaults to shading the range of 1 and 2 standard deviations in `dim` alpha: np.array, optional Transparency values of the shaded ranges. Defaults to [0.5,0.15]. spread_style : str Metric used to define spread on `dim`. Options: 'std': Calculates standard deviation along `dim` and shading indicates multiples of std centered on the mean 'quantile': Calculates quantile ranges. An input of `spread=[0.2,0.5]` would show an inner shading for the 40th-60th percentile, and an outer shading for the 25th-75th percentile, centered on the 50th quantile (~median). Must be within [0,100]. line_kwargs : dict optional parameters for line plot. fill_kwargs : dict optional parameters for std fill plot. **kwargs Keyword arguments passed to both line plot and fill_between. Example ------ """ # check input if isinstance(spreads, float) or isinstance(spreads, int): spreads = [spreads] if isinstance(alphas, float): alphas = [alphas] if isinstance(dim, float): dim = [dim] # set axis if not ax: ax = plt.gca() # Option to plot a straight line when the dim is not present (TODO) # check if the data is 2 dimensional dims = da.mean(dim).dims if len(dims) != 1: raise ValueError( f"`da` must be 1 dimensional after reducing over {dim}. Found {dims}" ) # assemble plot elements xdim = dims[0] x = da[xdim] # define the line plot values if spread_style == "std": y = da.mean(dim) if spreads is None: spreads = [1, 3] elif spread_style in ["quantile", "percentile"]: y = da.quantile(0.5, dim) if spreads is None: spreads = [0.5, 0.8] else: raise ValueError( f"Got unknown option ['{spread_style}'] for `spread_style`. Supported options are : ['std', 'quantile']" ) # set line kwargs line_defaults = {} line_defaults.update(line_kwargs) if horizontal: ll = ax.plot(x, y, **line_defaults) else: ll = ax.plot(y, x, **line_defaults) # now loop over the spreads: fill_defaults = {"facecolor": ll[-1].get_color(), "edgecolor": "none"} # Apply defaults but respect input fill_defaults.update(fill_kwargs) ff = [] spreads = list(np.flip(spreads)) alphas = list(np.flip(alphas)) # np.flip(this ensures that the shadings are drawn from outer to inner otherwise they blend too much into each other for spread, alpha in zip(spreads, alphas): f_kwargs = {k: v for k, v in fill_defaults.items()} f_kwargs["alpha"] = alpha if spread_style == "std": y_std = da.std(dim) # i could probably precompute that. y_spread = y_std * spread y_lower = y - (y_spread / 2) y_upper = y + (y_spread / 2) elif spread_style in ["quantile", "percentile"]: y_lower = da.quantile(0.5 - (spread / 2), dim) y_upper = da.quantile(0.5 + (spread / 2), dim) if horizontal: ff.append(ax.fill_between(x.data, y_lower.data, y_upper.data, **f_kwargs)) else: ff.append(ax.fill_betweenx(x.data, y_lower.data, y_upper.data, **f_kwargs)) return ll, ff
[docs]def plot_line_shaded_std( x, y, std_y, horizontal=True, ax=None, line_kwargs=dict(), fill_kwargs=dict() ): """Plot wrapper to draw line for y and shaded patch according to std_y. The shading represents one std on each side of the line... Parameters ---------- x : numpy.array or xr.DataArray Coordinate. y : numpy.array or xr.DataArray line data. std_y : numpy.array or xr.DataArray std corresponding to y. horizontal : bool Determines if the plot is horizontal or vertical (e.g. x is plotted on the y-axis). ax : matplotlib.axes Matplotlib axes object to plot on (the default is plt.gca()). line_kwargs : dict optional parameters for line plot. fill_kwargs : dict optional parameters for std fill plot. Returns ------- (ll, ff) Tuple of line and patch objects. """ warnings.warn( "This is an outdated function. Use `shaded_line_plot` instead", DeprecationWarning, ) line_defaults = {} # Set plot defaults into the kwargs if not ax: ax = plt.gca() # Apply defaults but respect input line_defaults.update(line_kwargs) if horizontal: ll = ax.plot(x, y, **line_defaults) else: ll = ax.plot(y, x, **line_defaults) fill_defaults = { "facecolor": ll[-1].get_color(), "alpha": 0.35, "edgecolor": "none", } # Apply defaults but respect input fill_defaults.update(fill_kwargs) if horizontal: ff = ax.fill_between(x, y - std_y, y + std_y, **fill_defaults) else: ff = ax.fill_betweenx(x, y - std_y, y + std_y, **fill_defaults) return ll, ff
[docs]def box_plot(box, ax=None, split_detection="True", **kwargs): """plots box despite coordinate discontinuities. INPUT ----- box: np.array Defines the box in the coordinates of the current axis. Describing the box corners [x1, x2, y1, y2] ax: matplotlib.axis axis for plotting. Defaults to plt.gca() kwargs: optional anything that can be passed to plot can be put as kwarg """ if len(box) != 4: raise RuntimeError( "'box' must be a 4 element np.array, \ describing the box corners [x1, x2, y1, y2]" ) xlim = plt.gca().get_xlim() ylim = plt.gca().get_ylim() x_split = False y_split = False if ax is None: ax = plt.gca() if split_detection: if np.diff([box[0], box[1]]) < 0: x_split = True if np.diff([box[2], box[3]]) < 0: y_split = True if y_split and not x_split: ax.plot( [box[0], box[0], box[1], box[1], box[0]], [ylim[1], box[2], box[2], ylim[1], ylim[1]], **kwargs, ) ax.plot( [box[0], box[0], box[1], box[1], box[0]], [ylim[0], box[3], box[3], ylim[0], ylim[0]], **kwargs, ) elif x_split and not y_split: ax.plot( [xlim[1], box[0], box[0], xlim[1], xlim[1]], [box[2], box[2], box[3], box[3], box[2]], **kwargs, ) ax.plot( [xlim[0], box[1], box[1], xlim[0], xlim[0]], [box[2], box[2], box[3], box[3], box[2]], **kwargs, ) elif x_split and y_split: ax.plot([xlim[1], box[0], box[0]], [box[2], box[2], ylim[1]], **kwargs) ax.plot([xlim[0], box[1], box[1]], [box[2], box[2], ylim[1]], **kwargs) ax.plot([xlim[1], box[0], box[0]], [box[3], box[3], ylim[0]], **kwargs) ax.plot([xlim[0], box[1], box[1]], [box[3], box[3], ylim[0]], **kwargs) elif not x_split and not y_split: ax.plot( [box[0], box[0], box[1], box[1], box[0]], [box[2], box[3], box[3], box[2], box[2]], **kwargs, )
[docs]def dict2box(di, xdim="lon", ydim="lat"): return np.array([di[xdim].start, di[xdim].stop, di[ydim].start, di[ydim].stop])
[docs]def box_plot_dict(di, xdim="lon", ydim="lat", **kwargs): """plot box from xarray selection dict e.g. `{'xdim':slice(a, b), 'ydim':slice(c,d), ...}`""" # extract box from dict box = dict2box(di, xdim=xdim, ydim=ydim) # plot box_plot(box, **kwargs)
[docs]def draw_dens_contours_teos10( sigma="sigma0", add_labels=True, ax=None, density_grid=20, dens_interval=1.0, salt_on_x=True, slim=None, tlim=None, contour_kwargs={}, c_label_kwargs={}, **kwargs, ): """draws density contours on the current plot. Assumes that the salinity and temperature values are given as SA and CT. Needs documentation...""" if gsw is None: raise RuntimeError( "`gsw` is not available. Install with `conda install -c conda-forge gsw`" ) if ax is None: ax = plt.gca() if sigma not in ["sigma%i" % s for s in range(5)]: raise ValueError( "Sigma function has to be one of `sigma0`...`sigma4` \ is: %s" % (sigma) ) # get salt (default: xaxis) and temp (default: yaxis) limits if salt_on_x: if not (slim is None): slim = ax.get_xlim() if not (tlim is None): tlim = ax.get_ylim() x = np.linspace(*(slim + [density_grid])) y = np.linspace(*(tlim + [density_grid])) else: if not tlim: tlim = ax.get_xlim() if not slim: slim = ax.get_ylim() x = np.linspace(*(slim + [density_grid])) y = np.linspace(*(tlim + [density_grid])) if salt_on_x: ss, tt = np.meshgrid(x, y) else: tt, ss = np.meshgrid(x, y) sigma_func = getattr(gsw, sigma) sig = sigma_func(ss, tt) levels = np.arange(np.floor(sig.min()), np.ceil(sig.max()), dens_interval) c_kwarg_defaults = dict( levels=levels, colors="0.4", linestyles="--", linewidths=0.5 ) c_kwarg_defaults.update(kwargs) c_kwarg_defaults.update(contour_kwargs) c_label_kwarg_defaults = dict(fmt="%.02f") c_label_kwarg_defaults.update(kwargs) c_label_kwarg_defaults.update(c_label_kwargs) ch = ax.contour(x, y, sig, **c_kwarg_defaults) ax.clabel(ch, **c_label_kwarg_defaults) if add_labels: plt.text( 0.05, 0.05, "$\sigma_{%s}$" % (sigma[-1]), fontsize=14, verticalalignment="center", horizontalalignment="center", transform=ax.transAxes, color=c_kwarg_defaults["colors"], )
[docs]def tsdiagram( salt, temp, color=None, size=None, lon=None, lat=None, pressure=None, convert_teos10=True, ts_kwargs={}, ax=None, fig=None, draw_density_contours=True, draw_cbar=True, add_labels=True, **kwargs, ): if ax is None: ax = plt.gca() if fig is None: fig = plt.gcf() if convert_teos10: temp_label = "Conservative Temperature [$^{\circ}C$]" salt_label = "Absolute Salinity [$g/kg$]" if any([a is None for a in [lon, lat, pressure]]): raise ValueError( "when converting to teos10 variables, \ input for lon, lat and pressure is needed" ) else: salt = gsw.SA_from_SP(salt, pressure, lon, lat) temp = gsw.CT_from_pt(salt, temp) else: temp_label = "Potential Temperature [$^{\circ}C$]" salt_label = "Practical Salinity [$g/kg$]" if add_labels: ax.set_xlabel(salt_label) ax.set_ylabel(temp_label) scatter_kw_defaults = dict(s=size, c=color) scatter_kw_defaults.update(kwargs) s = ax.scatter(salt, temp, **scatter_kw_defaults) if draw_density_contours: draw_dens_contours_teos10(ax=ax, **ts_kwargs) if draw_cbar and color is not None: if isinstance(color, str) or isinstance(color, tuple): pass elif ( isinstance(color, list) or isinstance(color, np.ndarray) or isinstance(color, xr.DataArray) ): fig.colorbar(s, ax=ax) else: raise RuntimeError("`color` not recognized. %s" % type(color)) return s
[docs]def linear_piecewise_scale( cut, scale, ax=None, axis="y", scaled_half="upper", add_cut_line=False ): """This function sets a piecewise linear scaling for a given axis to highlight e.g. processes in the upper ocean vs deep ocean. Parameters ---------- cut : float value along the chosen axis used as transition between the two linear scalings. scale : float scaling coefficient for the chosen axis portion (determined by `axis` and `scaled_half`). A higher number means the chosen portion of the axis will be more compressed. Must be positive. 0 means no compression. ax : matplotlib.axis, optional The plot axis object. Defaults to current matplotlib axis axis : str, optional Which axis of the plot to act on. * 'y' (Default) * 'x' scaled_half: str, optional Determines which half of the axis is scaled (compressed). * 'upper' (default). Values larger than `cut` are compressed * 'lower'. Values smaller than `cut` are compressed Returns ------- ax_scaled : matplotlib.axis """ if ax is None: ax = plt.gca() if scale < 0: raise ValueError(f"`Scale can not be negative. Got value of {scale}") if scale == 0: # do nothing return ax else: if scaled_half == "upper": def inverse(x): return np.piecewise( x, [x <= cut, x > cut], [lambda x: x + (scale * (x - cut)), lambda x: x], ) def forward(x): return np.piecewise( x, [x <= cut, x > cut], [lambda x: x + (scale * (x - cut)), lambda x: x], ) elif scaled_half == "lower": def inverse(x): return np.piecewise( x, [x >= cut, x < cut], [lambda x: x + (scale * (x - cut)), lambda x: x], ) def forward(x): return np.piecewise( x, [x >= cut, x < cut], [lambda x: x + (scale * (x - cut)), lambda x: x], ) else: raise ValueError( f"`scaled_half` value not recognized. Must be ['upper', 'lower']. Got {scaled_half}" ) if axis == "y": axlim = ax.get_ylim() ax.set_yscale("function", functions=(forward, inverse)) ax.set_ylim(axlim) elif axis == "x": axlim = ax.get_xlim() ax.set_xscale("function", functions=(forward, inverse)) ax.set_xlim(axlim) else: raise ValueError( f"`axis` value not recognized. Must be ['x', 'y']. Got {axis}" ) return ax