Skip to content

Commit

Permalink
gh-114563: C decimal falls back to pydecimal for unsupported format s…
Browse files Browse the repository at this point in the history
…trings (GH-114879)

Immediate merits:
* eliminate complex workarounds for 'z' format support
  (NOTE: mpdecimal recently added 'z' support, so this becomes
  efficient in the long term.)
* fix 'z' format memory leak
* fix 'z' format applied to 'F'
* fix missing '#' format support

Suggested and prototyped by Stefan Krah.

Fixes gh-114563, gh-91060

Co-authored-by: Stefan Krah <[email protected]>
  • Loading branch information
belm0 and Stefan Krah authored Feb 12, 2024
1 parent 235cacf commit 72340d1
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 122 deletions.
22 changes: 22 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,13 @@ def test_formatting(self):
('z>z6.1f', '-0.', 'zzz0.0'),
('x>z6.1f', '-0.', 'xxx0.0'),
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char

# issue 114563 ('z' format on F type in cdecimal)
('z3,.10F', '-6.24E-323', '0.0000000000'),

# issue 91060 ('#' format in cdecimal)
('#', '0', '0.'),

# issue 6850
('a=-7.0', '0.12345', 'aaaa0.1'),
Expand Down Expand Up @@ -5726,6 +5733,21 @@ def test_c_signaldict_segfault(self):
with self.assertRaisesRegex(ValueError, err_msg):
sd.copy()

def test_format_fallback_capitals(self):
# Fallback to _pydecimal formatting (triggered by `#` format which
# is unsupported by mpdecimal) should honor the current context.
x = C.Decimal('6.09e+23')
self.assertEqual(format(x, '#'), '6.09E+23')
with C.localcontext(capitals=0):
self.assertEqual(format(x, '#'), '6.09e+23')

def test_format_fallback_rounding(self):
y = C.Decimal('6.09')
self.assertEqual(format(y, '#.1f'), '6.1')
with C.localcontext(rounding=C.ROUND_DOWN):
self.assertEqual(format(y, '#.1f'), '6.0')


@requires_docstrings
@requires_cdecimal
class SignatureTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix several :func:`format()` bugs when using the C implementation of :class:`~decimal.Decimal`:
* memory leak in some rare cases when using the ``z`` format option (coerce negative 0)
* incorrect output when applying the ``z`` format option to type ``F`` (fixed-point with capital ``NAN`` / ``INF``)
* incorrect output when applying the ``#`` format option (alternate form)
184 changes: 62 additions & 122 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ typedef struct {
/* Convert rationals for comparison */
PyObject *Rational;

/* Invariant: NULL or pointer to _pydecimal.Decimal */
PyObject *PyDecimal;

PyObject *SignalTuple;

struct DecCondMap *signal_map;
Expand Down Expand Up @@ -3336,56 +3339,6 @@ dotsep_as_utf8(const char *s)
return utf8;
}

/* copy of libmpdec _mpd_round() */
static void
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
const mpd_context_t *ctx, uint32_t *status)
{
mpd_ssize_t exp = a->exp + a->digits - prec;

if (prec <= 0) {
mpd_seterror(result, MPD_Invalid_operation, status);
return;
}
if (mpd_isspecial(a) || mpd_iszero(a)) {
mpd_qcopy(result, a, status);
return;
}

mpd_qrescale_fmt(result, a, exp, ctx, status);
if (result->digits > prec) {
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
}
}

/* Locate negative zero "z" option within a UTF-8 format spec string.
* Returns pointer to "z", else NULL.
* The portion of the spec we're working with is [[fill]align][sign][z] */
static const char *
format_spec_z_search(char const *fmt, Py_ssize_t size) {
char const *pos = fmt;
char const *fmt_end = fmt + size;
/* skip over [[fill]align] (fill may be multi-byte character) */
pos += 1;
while (pos < fmt_end && *pos & 0x80) {
pos += 1;
}
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
} else {
/* fill not present-- skip over [align] */
pos = fmt;
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
}
}
/* skip over [sign] */
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
pos += 1;
}
return pos < fmt_end && *pos == 'z' ? pos : NULL;
}

static int
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
{
Expand All @@ -3411,6 +3364,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
return 0;
}

/*
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
*/
static PyObject *
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt, decimal_state *state)
{
PyObject *result;
PyObject *pydec;
PyObject *u;

if (state->PyDecimal == NULL) {
state->PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
if (state->PyDecimal == NULL) {
return NULL;
}
}

u = dec_str(dec);
if (u == NULL) {
return NULL;
}

pydec = PyObject_CallOneArg(state->PyDecimal, u);
Py_DECREF(u);
if (pydec == NULL) {
return NULL;
}

result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
Py_DECREF(pydec);

if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
/* Do not confuse users with the _pydecimal exception */
PyErr_Clear();
PyErr_SetString(PyExc_ValueError, "invalid format string");
}

return result;
}

/* Formatted representation of a PyDecObject. */
static PyObject *
dec_format(PyObject *dec, PyObject *args)
Expand All @@ -3423,16 +3418,11 @@ dec_format(PyObject *dec, PyObject *args)
PyObject *fmtarg;
PyObject *context;
mpd_spec_t spec;
char const *fmt;
char *fmt_copy = NULL;
char *fmt;
char *decstring = NULL;
uint32_t status = 0;
int replace_fillchar = 0;
int no_neg_0 = 0;
Py_ssize_t size;
mpd_t *mpd = MPD(dec);
mpd_uint_t dt[MPD_MINALLOC_MAX];
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};


decimal_state *state = get_module_state_by_def(Py_TYPE(dec));
Expand All @@ -3442,7 +3432,7 @@ dec_format(PyObject *dec, PyObject *args)
}

if (PyUnicode_Check(fmtarg)) {
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
if (fmt == NULL) {
return NULL;
}
Expand All @@ -3454,35 +3444,15 @@ dec_format(PyObject *dec, PyObject *args)
}
}

/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
* format string manipulation below can be eliminated by enhancing
* the forked mpd_parse_fmt_str(). */
if (size > 0 && fmt[0] == '\0') {
/* NUL fill character: must be replaced with a valid UTF-8 char
before calling mpd_parse_fmt_str(). */
replace_fillchar = 1;
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
fmt = dec_strdup(fmt, size);
if (fmt == NULL) {
return NULL;
}
fmt_copy[0] = '_';
}
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
char const *z_position = format_spec_z_search(fmt, size);
if (z_position != NULL) {
no_neg_0 = 1;
size_t z_index = z_position - fmt;
if (fmt_copy == NULL) {
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
return NULL;
}
}
/* Shift characters (including null terminator) left,
overwriting the 'z' option. */
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
size -= 1;
fmt[0] = '_';
}
}
else {
Expand All @@ -3492,10 +3462,13 @@ dec_format(PyObject *dec, PyObject *args)
}

if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
PyErr_SetString(PyExc_ValueError,
"invalid format string");
goto finish;
if (replace_fillchar) {
PyMem_Free(fmt);
}

return pydec_format(dec, context, fmtarg, state);
}

if (replace_fillchar) {
/* In order to avoid clobbering parts of UTF-8 thousands separators or
decimal points when the substitution is reversed later, the actual
Expand Down Expand Up @@ -3548,45 +3521,8 @@ dec_format(PyObject *dec, PyObject *args)
}
}

if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
/* Round into a temporary (carefully mirroring the rounding
of mpd_qformat_spec()), and check if the result is negative zero.
If so, clear the sign and format the resulting positive zero. */
mpd_ssize_t prec;
mpd_qcopy(&tmp, mpd, &status);
if (spec.prec >= 0) {
switch (spec.type) {
case 'f':
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case '%':
tmp.exp += 2;
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case 'g':
prec = (spec.prec == 0) ? 1 : spec.prec;
if (tmp.digits > prec) {
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
}
break;
case 'e':
if (!mpd_iszero(&tmp)) {
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
}
break;
}
}
if (status & MPD_Errors) {
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
goto finish;
}
if (mpd_iszero(&tmp)) {
mpd_set_positive(&tmp);
mpd = &tmp;
}
}

decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
if (decstring == NULL) {
if (status & MPD_Malloc_error) {
PyErr_NoMemory();
Expand All @@ -3609,7 +3545,7 @@ dec_format(PyObject *dec, PyObject *args)
Py_XDECREF(grouping);
Py_XDECREF(sep);
Py_XDECREF(dot);
if (fmt_copy) PyMem_Free(fmt_copy);
if (replace_fillchar) PyMem_Free(fmt);
if (decstring) mpd_free(decstring);
return result;
}
Expand Down Expand Up @@ -5987,6 +5923,9 @@ _decimal_exec(PyObject *m)
Py_CLEAR(collections_abc);
Py_CLEAR(MutableMapping);

/* For format specifiers not yet supported by libmpdec */
state->PyDecimal = NULL;

/* Add types to the module */
CHECK_INT(PyModule_AddType(m, state->PyDec_Type));
CHECK_INT(PyModule_AddType(m, state->PyDecContext_Type));
Expand Down Expand Up @@ -6192,6 +6131,7 @@ decimal_clear(PyObject *module)
Py_CLEAR(state->extended_context_template);
Py_CLEAR(state->Rational);
Py_CLEAR(state->SignalTuple);
Py_CLEAR(state->PyDecimal);

PyMem_Free(state->signal_map);
PyMem_Free(state->cond_map);
Expand Down

0 comments on commit 72340d1

Please sign in to comment.