diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index cf3ae2615d..fa875f0be1 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -16,6 +16,7 @@ FixedLocator, LinearLocator, LogLocator, + SymmetricalLogLocator, MaxNLocator, MultipleLocator, EngFormatter, @@ -222,16 +223,48 @@ def spacer(x): return new - def tick(self, locator=None): + def tick(self, locator: Locator | None = None): + """ + Configure the selection of ticks for the scale's axis or legend. + + .. note:: + This API is under construction and will be enhanced over time. + At the moment, it is probably not very useful. + + Parameters + ---------- + locator: :class:`matplotlib.ticker.Locator` subclass + Pre-configured matplotlib locator; other parameters will not be used. + Returns + ------- + Copy of self with new tick configuration. + + """ new = copy(self) new._tick_params = { "locator": locator, } return new - def label(self, formatter=None): + def label(self, formatter: Formatter | None = None): + """ + Configure the selection of labels for the scale's axis or legend. + + .. note:: + This API is under construction and will be enhanced over time. + At the moment, it is probably not very useful. + Parameters + ---------- + formatter: :class:`matplotlib.ticker.Formatter` subclass + Pre-configured matplotlib formatter; other parameters will not be used. + + Returns + ------- + Copy of self with new tick configuration. + + """ new = copy(self) new._label_params = { "formatter": formatter, @@ -396,7 +429,7 @@ def tick( Parameters ---------- - locator: matplotlib Locator + locator: :class:`matplotlib.ticker.Locator` subclass Pre-configured matplotlib locator; other parameters will not be used. at : sequence of floats Place ticks at these specific locations (in data units). @@ -452,7 +485,7 @@ def tick( def label( self, - formatter=None, *, + formatter: Formatter | None = None, *, like: str | Callable | None = None, base: int | None = None, unit: str | None = None, @@ -462,7 +495,7 @@ def label( Parameters ---------- - formatter : matplotlib Formatter + formatter: :class:`matplotlib.ticker.Formatter` subclass Pre-configured formatter to use; other parameters will be ignored. like : str or callable Either a format pattern (e.g., `".2f"`), a format string with fields named @@ -499,17 +532,21 @@ def label( } return new + def _parse_for_log_params(self, transform): + + log_base = symlog_thresh = None + if isinstance(transform, str): + m = re.match(r"^log(\d*)", transform) + if m is not None: + log_base = float(m[1] or 10) + m = re.match(r"symlog(\d*)", transform) + if m is not None: + symlog_thresh = float(m[1] or 1) + return log_base, symlog_thresh + def _get_locators(self, locator, at, upto, count, every, between, minor): - # TODO what about symlog? - if isinstance(self.transform, str): - m = re.match(r"log(\d*)", self.transform) - log_transform = m is not None - log_base = m[1] or 10 if m is not None else None - forward, inverse = self._get_transform() - else: - log_transform = False - log_base = forward = inverse = None + log_base, symlog_thresh = self._parse_for_log_params(self.transform) if locator is not None: major_locator = locator @@ -517,7 +554,7 @@ def _get_locators(self, locator, at, upto, count, every, between, minor): # TODO raise if locator is passed with any other parameters elif upto is not None: - if log_transform: + if log_base: major_locator = LogLocator(base=log_base, numticks=upto) else: major_locator = MaxNLocator(upto, steps=[1, 1.5, 2, 2.5, 3, 5, 10]) @@ -527,7 +564,8 @@ def _get_locators(self, locator, at, upto, count, every, between, minor): # This is rarely useful (unless you are setting limits) major_locator = LinearLocator(count) else: - if log_transform: + if log_base or symlog_thresh: + forward, inverse = self._get_transform() lo, hi = forward(between) ticks = inverse(np.linspace(lo, hi, num=count)) else: @@ -546,12 +584,17 @@ def _get_locators(self, locator, at, upto, count, every, between, minor): major_locator = FixedLocator(at) else: - major_locator = LogLocator(log_base) if log_transform else AutoLocator() + if log_base: + major_locator = LogLocator(log_base) + elif symlog_thresh: + major_locator = SymmetricalLogLocator(linthresh=symlog_thresh, base=10) + else: + major_locator = AutoLocator() if minor is None: - minor_locator = LogLocator(log_base, subs=None) if log_transform else None + minor_locator = LogLocator(log_base, subs=None) if log_base else None else: - if log_transform: + if log_base: subs = np.linspace(0, log_base, minor + 2)[1:-1] minor_locator = LogLocator(log_base, subs=subs) else: @@ -561,11 +604,11 @@ def _get_locators(self, locator, at, upto, count, every, between, minor): def _get_formatter(self, locator, formatter, like, base, unit): - # TODO this has now been copied in a few places - if base is None and isinstance(self.transform, str): - # TODO handle symlog too - m = re.match(r"log(\d*)", self.transform) - base = m[1] or 10 if m is not None else None + log_base, symlog_thresh = self._parse_for_log_params(self.transform) + if base is None: + if symlog_thresh: + log_base = 10 + base = log_base if formatter is not None: return formatter @@ -623,11 +666,12 @@ def tick( """ Configure the selection of ticks for the scale's axis or legend. - This API is under construction and will be enhanced over time. + .. note:: + This API is under construction and will be enhanced over time. Parameters ---------- - locator: matplotlib Locator + locator: :class:`matplotlib.ticker.Locator` subclass Pre-configured matplotlib locator; other parameters will not be used. upto : int Choose "nice" locations for ticks, but do not exceed this number. @@ -656,11 +700,12 @@ def label( """ Configure the appearance of tick labels for the scale's axis or legend. - This API is under construction and will be enhanced over time. + .. note:: + This API is under construction and will be enhanced over time. Parameters ---------- - formatter : matplotlib Formatter + formatter: :class:`matplotlib.ticker.Formatter` subclass Pre-configured formatter to use; other parameters will be ignored. concise : bool If True, use :class:`matplotlib.dates.ConciseDateFormatter` to make diff --git a/tests/_core/test_scales.py b/tests/_core/test_scales.py index 57f458554b..24d12e6aac 100644 --- a/tests/_core/test_scales.py +++ b/tests/_core/test_scales.py @@ -212,6 +212,17 @@ def test_log_tick_every(self, x): with pytest.raises(RuntimeError, match="`every` not supported"): Continuous(transform="log").tick(every=2) + def test_symlog_tick_default(self, x): + + s = Continuous(transform="symlog")._setup(x, Coordinate()) + a = PseudoAxis(s._matplotlib_scale) + a.set_view_interval(-1050, 1050) + ticks = a.major.locator() + assert ticks[0] == -ticks[-1] + pos_ticks = np.sort(np.unique(np.abs(ticks))) + assert np.allclose(np.diff(np.log10(pos_ticks[1:])), 1) + assert pos_ticks[0] == 0 + def test_label_formatter(self, x): fmt = mpl.ticker.FormatStrFormatter("%.3f")