Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-114563: C decimal falls back to pydecimal for unsupported format strings #114879

Merged
merged 12 commits into from
Feb 12, 2024
22 changes: 22 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,13 @@ def test_formatting(self):
('z>z6.1f', '-0.', 'zzz0.0'),
('x>z6.1f', '-0.', 'xxx0.0'),
('🖤>z6.1f', '-0.', '🖤🖤🖤0.0'), # multi-byte fill char
belm0 marked this conversation as resolved.
Show resolved Hide resolved
('\x00>z6.1f', '-0.', '\x00\x00\x000.0'), # null fill char

# issue 114563 ('z' format on F type in cdecimal)
('z3,.10F', '-6.24E-323', '0.0000000000'),

# issue 91060 ('#' format in cdecimal)
('#', '0', '0.'),

# issue 6850
('a=-7.0', '0.12345', 'aaaa0.1'),
Expand Down Expand Up @@ -5726,6 +5733,21 @@ def test_c_signaldict_segfault(self):
with self.assertRaisesRegex(ValueError, err_msg):
sd.copy()

def test_format_fallback_capitals(self):
# Fallback to _pydecimal formatting (triggered by `#` format which
# is unsupported by mpdecimal) should honor the current context.
x = C.Decimal('6.09e+23')
self.assertEqual(format(x, '#'), '6.09E+23')
with C.localcontext(capitals=0):
self.assertEqual(format(x, '#'), '6.09e+23')

def test_format_fallback_rounding(self):
y = C.Decimal('6.09')
self.assertEqual(format(y, '#.1f'), '6.1')
with C.localcontext(rounding=C.ROUND_DOWN):
self.assertEqual(format(y, '#.1f'), '6.0')


@requires_docstrings
@requires_cdecimal
class SignatureTest(unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Fix several `format()` bugs when using the C implementation of `Decimal`:
* memory leak in some rare cases when using the `z` format option (coerce negative 0)
belm0 marked this conversation as resolved.
Show resolved Hide resolved
* incorrect output when applying the `z` format option to type `F` (fixed-point with capital `NAN` / `INF`)
* incorrect output when applying the `#` format option (alternate form)
184 changes: 62 additions & 122 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ typedef struct {
/* Convert rationals for comparison */
PyObject *Rational;

/* Invariant: NULL or pointer to _pydecimal.Decimal */
PyObject *PyDecimal;

PyObject *SignalTuple;

struct DecCondMap *signal_map;
Expand Down Expand Up @@ -3336,56 +3339,6 @@ dotsep_as_utf8(const char *s)
return utf8;
}

/* copy of libmpdec _mpd_round() */
static void
_mpd_round(mpd_t *result, const mpd_t *a, mpd_ssize_t prec,
const mpd_context_t *ctx, uint32_t *status)
{
mpd_ssize_t exp = a->exp + a->digits - prec;

if (prec <= 0) {
mpd_seterror(result, MPD_Invalid_operation, status);
return;
}
if (mpd_isspecial(a) || mpd_iszero(a)) {
mpd_qcopy(result, a, status);
return;
}

mpd_qrescale_fmt(result, a, exp, ctx, status);
if (result->digits > prec) {
mpd_qrescale_fmt(result, result, exp+1, ctx, status);
}
}

/* Locate negative zero "z" option within a UTF-8 format spec string.
* Returns pointer to "z", else NULL.
* The portion of the spec we're working with is [[fill]align][sign][z] */
static const char *
format_spec_z_search(char const *fmt, Py_ssize_t size) {
char const *pos = fmt;
char const *fmt_end = fmt + size;
/* skip over [[fill]align] (fill may be multi-byte character) */
pos += 1;
while (pos < fmt_end && *pos & 0x80) {
pos += 1;
}
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
} else {
/* fill not present-- skip over [align] */
pos = fmt;
if (pos < fmt_end && strchr("<>=^", *pos) != NULL) {
pos += 1;
}
}
/* skip over [sign] */
if (pos < fmt_end && strchr("+- ", *pos) != NULL) {
pos += 1;
}
return pos < fmt_end && *pos == 'z' ? pos : NULL;
}

static int
dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const char **valuestr)
{
Expand All @@ -3411,6 +3364,48 @@ dict_get_item_string(PyObject *dict, const char *key, PyObject **valueobj, const
return 0;
}

/*
* Fallback _pydecimal formatting for new format specifiers that mpdecimal does
* not yet support. As documented, libmpdec follows the PEP-3101 format language:
* https://www.bytereef.org/mpdecimal/doc/libmpdec/assign-convert.html#to-string
*/
static PyObject *
pydec_format(PyObject *dec, PyObject *context, PyObject *fmt, decimal_state *state)
{
PyObject *result;
PyObject *pydec;
PyObject *u;

if (state->PyDecimal == NULL) {
state->PyDecimal = _PyImport_GetModuleAttrString("_pydecimal", "Decimal");
if (state->PyDecimal == NULL) {
return NULL;
}
}

u = dec_str(dec);
if (u == NULL) {
return NULL;
}

pydec = PyObject_CallOneArg(state->PyDecimal, u);
belm0 marked this conversation as resolved.
Show resolved Hide resolved
Py_DECREF(u);
if (pydec == NULL) {
return NULL;
}

result = PyObject_CallMethod(pydec, "__format__", "(OO)", fmt, context);
Py_DECREF(pydec);

if (result == NULL && PyErr_ExceptionMatches(PyExc_ValueError)) {
/* Do not confuse users with the _pydecimal exception */
PyErr_Clear();
PyErr_SetString(PyExc_ValueError, "invalid format string");
}

return result;
}

/* Formatted representation of a PyDecObject. */
static PyObject *
dec_format(PyObject *dec, PyObject *args)
Expand All @@ -3423,16 +3418,11 @@ dec_format(PyObject *dec, PyObject *args)
PyObject *fmtarg;
PyObject *context;
mpd_spec_t spec;
char const *fmt;
char *fmt_copy = NULL;
char *fmt;
serhiy-storchaka marked this conversation as resolved.
Show resolved Hide resolved
char *decstring = NULL;
uint32_t status = 0;
int replace_fillchar = 0;
int no_neg_0 = 0;
Py_ssize_t size;
mpd_t *mpd = MPD(dec);
mpd_uint_t dt[MPD_MINALLOC_MAX];
mpd_t tmp = {MPD_STATIC|MPD_STATIC_DATA,0,0,0,MPD_MINALLOC_MAX,dt};


decimal_state *state = get_module_state_by_def(Py_TYPE(dec));
Expand All @@ -3442,7 +3432,7 @@ dec_format(PyObject *dec, PyObject *args)
}

if (PyUnicode_Check(fmtarg)) {
fmt = PyUnicode_AsUTF8AndSize(fmtarg, &size);
fmt = (char *)PyUnicode_AsUTF8AndSize(fmtarg, &size);
if (fmt == NULL) {
return NULL;
}
Expand All @@ -3454,35 +3444,15 @@ dec_format(PyObject *dec, PyObject *args)
}
}

/* NOTE: If https://github.com/python/cpython/pull/29438 lands, the
* format string manipulation below can be eliminated by enhancing
* the forked mpd_parse_fmt_str(). */
if (size > 0 && fmt[0] == '\0') {
/* NUL fill character: must be replaced with a valid UTF-8 char
before calling mpd_parse_fmt_str(). */
replace_fillchar = 1;
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
fmt = dec_strdup(fmt, size);
if (fmt == NULL) {
return NULL;
}
fmt_copy[0] = '_';
}
/* Strip 'z' option, which isn't understood by mpd_parse_fmt_str().
* NOTE: fmt is always null terminated by PyUnicode_AsUTF8AndSize() */
char const *z_position = format_spec_z_search(fmt, size);
if (z_position != NULL) {
no_neg_0 = 1;
size_t z_index = z_position - fmt;
if (fmt_copy == NULL) {
fmt = fmt_copy = dec_strdup(fmt, size);
if (fmt_copy == NULL) {
return NULL;
}
}
/* Shift characters (including null terminator) left,
overwriting the 'z' option. */
memmove(fmt_copy + z_index, fmt_copy + z_index + 1, size - z_index);
size -= 1;
fmt[0] = '_';
}
}
else {
Expand All @@ -3492,10 +3462,13 @@ dec_format(PyObject *dec, PyObject *args)
}

if (!mpd_parse_fmt_str(&spec, fmt, CtxCaps(context))) {
PyErr_SetString(PyExc_ValueError,
"invalid format string");
goto finish;
if (replace_fillchar) {
PyMem_Free(fmt);
}

return pydec_format(dec, context, fmtarg, state);
}

if (replace_fillchar) {
/* In order to avoid clobbering parts of UTF-8 thousands separators or
decimal points when the substitution is reversed later, the actual
Expand Down Expand Up @@ -3548,45 +3521,8 @@ dec_format(PyObject *dec, PyObject *args)
}
}

if (no_neg_0 && mpd_isnegative(mpd) && !mpd_isspecial(mpd)) {
/* Round into a temporary (carefully mirroring the rounding
of mpd_qformat_spec()), and check if the result is negative zero.
If so, clear the sign and format the resulting positive zero. */
mpd_ssize_t prec;
mpd_qcopy(&tmp, mpd, &status);
if (spec.prec >= 0) {
switch (spec.type) {
case 'f':
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case '%':
tmp.exp += 2;
mpd_qrescale(&tmp, &tmp, -spec.prec, CTX(context), &status);
break;
case 'g':
prec = (spec.prec == 0) ? 1 : spec.prec;
if (tmp.digits > prec) {
_mpd_round(&tmp, &tmp, prec, CTX(context), &status);
}
break;
case 'e':
if (!mpd_iszero(&tmp)) {
_mpd_round(&tmp, &tmp, spec.prec+1, CTX(context), &status);
}
break;
}
}
if (status & MPD_Errors) {
PyErr_SetString(PyExc_ValueError, "unexpected error when rounding");
goto finish;
}
if (mpd_iszero(&tmp)) {
mpd_set_positive(&tmp);
mpd = &tmp;
}
}

decstring = mpd_qformat_spec(mpd, &spec, CTX(context), &status);
decstring = mpd_qformat_spec(MPD(dec), &spec, CTX(context), &status);
if (decstring == NULL) {
if (status & MPD_Malloc_error) {
PyErr_NoMemory();
Expand All @@ -3609,7 +3545,7 @@ dec_format(PyObject *dec, PyObject *args)
Py_XDECREF(grouping);
Py_XDECREF(sep);
Py_XDECREF(dot);
if (fmt_copy) PyMem_Free(fmt_copy);
if (replace_fillchar) PyMem_Free(fmt);
if (decstring) mpd_free(decstring);
return result;
}
Expand Down Expand Up @@ -5987,6 +5923,9 @@ _decimal_exec(PyObject *m)
Py_CLEAR(collections_abc);
Py_CLEAR(MutableMapping);

/* For format specifiers not yet supported by libmpdec */
state->PyDecimal = NULL;

/* Add types to the module */
CHECK_INT(PyModule_AddType(m, state->PyDec_Type));
CHECK_INT(PyModule_AddType(m, state->PyDecContext_Type));
Expand Down Expand Up @@ -6192,6 +6131,7 @@ decimal_clear(PyObject *module)
Py_CLEAR(state->extended_context_template);
Py_CLEAR(state->Rational);
Py_CLEAR(state->SignalTuple);
Py_CLEAR(state->PyDecimal);

PyMem_Free(state->signal_map);
PyMem_Free(state->cond_map);
Expand Down
Loading