diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c index 376269e5002..40fdecae75f 100644 --- a/lib/compress/zstd_compress.c +++ b/lib/compress/zstd_compress.c @@ -945,9 +945,11 @@ size_t ZSTD_CCtx_refPrefix_advanced( { RETURN_ERROR_IF(cctx->streamStage != zcss_init, stage_wrong); ZSTD_clearAllDicts(cctx); - cctx->prefixDict.dict = prefix; - cctx->prefixDict.dictSize = prefixSize; - cctx->prefixDict.dictContentType = dictContentType; + if (prefix != NULL && prefixSize > 0) { + cctx->prefixDict.dict = prefix; + cctx->prefixDict.dictSize = prefixSize; + cctx->prefixDict.dictContentType = dictContentType; + } return 0; } diff --git a/tests/fuzz/dictionary_round_trip.c b/tests/fuzz/dictionary_round_trip.c index a0e7037c0ea..ce3cd672797 100644 --- a/tests/fuzz/dictionary_round_trip.c +++ b/tests/fuzz/dictionary_round_trip.c @@ -32,6 +32,7 @@ static size_t roundTripTest(void *result, size_t resultCapacity, { ZSTD_dictContentType_e dictContentType = ZSTD_dct_auto; FUZZ_dict_t dict = FUZZ_train(src, srcSize, producer); + int const refPrefix = FUZZ_dataProducer_uint32Range(producer, 0, 1) != 0; size_t cSize; if (FUZZ_dataProducer_uint32Range(producer, 0, 15) == 0) { int const cLevel = FUZZ_dataProducer_int32Range(producer, kMinClevel, kMaxClevel); @@ -46,17 +47,27 @@ static size_t roundTripTest(void *result, size_t resultCapacity, FUZZ_setRandomParameters(cctx, srcSize, producer); /* Disable checksum so we can use sizes smaller than compress bound. */ FUZZ_ZASSERT(ZSTD_CCtx_setParameter(cctx, ZSTD_c_checksumFlag, 0)); - FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced( + if (refPrefix) + FUZZ_ZASSERT(ZSTD_CCtx_refPrefix_advanced( + cctx, dict.buff, dict.size, + dictContentType)); + else + FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced( cctx, dict.buff, dict.size, (ZSTD_dictLoadMethod_e)FUZZ_dataProducer_uint32Range(producer, 0, 1), dictContentType)); cSize = ZSTD_compress2(cctx, compressed, compressedCapacity, src, srcSize); } FUZZ_ZASSERT(cSize); - FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced( - dctx, dict.buff, dict.size, - (ZSTD_dictLoadMethod_e)FUZZ_dataProducer_uint32Range(producer, 0, 1), - dictContentType)); + if (refPrefix) + FUZZ_ZASSERT(ZSTD_DCtx_refPrefix_advanced( + dctx, dict.buff, dict.size, + dictContentType)); + else + FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced( + dctx, dict.buff, dict.size, + (ZSTD_dictLoadMethod_e)FUZZ_dataProducer_uint32Range(producer, 0, 1), + dictContentType)); { size_t const ret = ZSTD_decompressDCtx( dctx, result, resultCapacity, compressed, cSize);