From 955a2fbf4e2b0140c6954a6344bf129fc07a7d27 Mon Sep 17 00:00:00 2001 From: yych0745 <1398089567@qq.com> Date: Tue, 7 Jan 2025 17:24:45 +0800 Subject: [PATCH 01/32] Add performance and accuracy test code for FP8 GEMM operations --- sgl-kernel/benchmark/bench_fp8_gemm.py | 71 ++++++++++++++++++++++++++ sgl-kernel/tests/test_fp8_gemm.py | 59 +++++++++++++++++++++ 2 files changed, 130 insertions(+) create mode 100644 sgl-kernel/benchmark/bench_fp8_gemm.py create mode 100644 sgl-kernel/tests/test_fp8_gemm.py diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py new file mode 100644 index 00000000000..d4bc2fdb91a --- /dev/null +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -0,0 +1,71 @@ +import torch +import torch.nn.functional as F +import triton + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + + +def to_int8(tensor: torch.Tensor) -> torch.Tensor: + return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=["vllm-fp8", "torch-fp8"], + line_names=["vllm-fp8", "torch-fp8"], + styles=[("green", "-"), ("blue", "-")], + ylabel="GB/s", + plot_name="int8 scaled matmul", + args={}, + ) +) +def benchmark(batch_size, provider): + M, N, K = batch_size, 8192, 21760 + a = torch.ones((M, K), device="cuda") * 5.0 + b = torch.ones((N, K), device="cuda") * 5.0 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + quantiles = [0.5, 0.2, 0.8] + + if provider == "vllm-fp8": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: vllm_scaled_mm( + a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, torch.bfloat16 + ), + quantiles=quantiles, + ) + if provider == "torch-fp8": + scale_a_2d = scale_a_fp8.float().unsqueeze(1) # [M, 1] + scale_b_2d = scale_b_fp8.float().unsqueeze(0) # [1, N] + try: + out = torch.empty( + (a_fp8.shape[0], b_fp8.shape[0]), device="cuda", dtype=torch.bfloat16 + ) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch._scaled_mm( + a_fp8, + b_fp8, + out=out, + out_dtype=torch.bfloat16, + scale_a=scale_a_2d, + scale_b=scale_b_2d, + use_fast_accum=True, + ), + quantiles=quantiles, + ) + except RuntimeError as e: + print("Error details:", e) + raise + gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +benchmark.run(print_data=True, show_plots=True, save_path="bench_int8_res") \ No newline at end of file diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py new file mode 100644 index 00000000000..a233b3b435a --- /dev/null +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -0,0 +1,59 @@ +import unittest + +import torch +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant + + +def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): + o = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + + o = o.to(torch.float32) + temp1 = o * scale_a.view(-1, 1) + temp2 = temp1 * scale_b.view(1, -1) + final = temp2.to(out_dtype) + + return final + + +class TestInt8Gemm(unittest.TestCase): + def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): + a = torch.randn((M, K), device=device) * 5 + b = torch.randn((N, K), device=device) * 5 + + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + if with_bias: + bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 + else: + bias = None + o1 = torch.empty((a.shape[0], b.shape[1]), device="cuda", dtype=torch.bfloat16) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + b_fp8 = b_fp8.t() + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) + o = torch_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) + o1 = vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) + max_val = max(o.abs().max().item(), o1.abs().max().item()) + rtol = 2e-2 + atol = max_val * rtol + torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) + print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") + + def test_accuracy(self): + Ms = [1, 128, 512, 1024, 4096] + Ns = [16, 128, 512, 1024, 4096] + Ks = [512, 1024, 4096, 8192, 16384] + bias_opts = [True, False] + out_dtypes = [torch.bfloat16] + for M in Ms: + for N in Ns: + for K in Ks: + for with_bias in bias_opts: + for out_dtype in out_dtypes: + self._test_accuracy_once( + M, N, K, with_bias, out_dtype, "cuda" + ) + + +if __name__ == "__main__": + unittest.main() From 30bdf20c81cdddf9eab4a9daba47742ab1e7fe17 Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Wed, 8 Jan 2025 19:25:23 +0800 Subject: [PATCH 02/32] support w8a8 fp8 --- sgl-kernel/CMakeLists.txt | 1 + sgl-kernel/setup.py | 1 + sgl-kernel/src/sgl-kernel/__init__.py | 2 + .../src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 571 ++++++++++++++++++ .../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 6 + sgl-kernel/src/sgl-kernel/csrc/utils.hpp | 5 + sgl-kernel/src/sgl-kernel/ops/__init__.py | 11 + sgl-kernel/tests/test_fp8_gemm.py | 28 +- 8 files changed, 615 insertions(+), 10 deletions(-) create mode 100644 sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 3c267a4de50..c2bfd356c3d 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -32,6 +32,7 @@ add_library(_kernels SHARED src/sgl-kernel/csrc/trt_reduce_kernel.cu src/sgl-kernel/csrc/moe_align_kernel.cu src/sgl-kernel/csrc/int8_gemm_kernel.cu + src/sgl-kernel/csrc/fp8_gemm_kernel.cu src/sgl-kernel/csrc/sgl_kernel_ops.cu ) diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index c93e87f6bad..3a60f6ba0a6 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -50,6 +50,7 @@ def update_wheel_platform_tag(): "src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu", "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", ], include_dirs=include_dirs, diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 892808f1ee1..2a4a2bd5177 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -3,6 +3,7 @@ custom_reduce, init_custom_reduce, int8_scaled_mm, + fp8_scaled_mm, moe_align_block_size, ) @@ -12,4 +13,5 @@ "custom_dispose", "custom_reduce", "int8_scaled_mm", + "fp8_scaled_mm", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu new file mode 100644 index 00000000000..79532893063 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -0,0 +1,571 @@ +/* + * Copyright (c) 2022-2024, Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wstrict-aliasing" +#endif // __GNUC__ + +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/conv/convolution.h" +// Order matters here, packed_stride.hpp is missing cute and convolution includes +#include "cutlass/util/packed_stride.hpp" + +#ifdef __GNUC__ // Check if the compiler is GCC or Clang +#pragma GCC diagnostic pop +#endif // __GNUC__ + +// #include "fp8_rowwise_gemm_kernel_template_sm89.h" +// #include "fp8_rowwise_gemm_kernel_template_sm90.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + + +#include "utils.hpp" +using namespace cute; + +template +struct DeviceGemmFp8RowwiseSm90 +{ + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // Auxiliary matrix configuration and other fusion types + using ElementBias = float; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using TileScheduler = TileSchedulerType; + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBias, ElementBias, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + using ComputeBias = cutlass::epilogue::fusion::Sm90Compute; + + using EVTComputeBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = EVTCompute1; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastDefaultSchedule; + using MainLoopSchedule = cute::conditional_t; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +// template +struct DeviceGemmFp8RowwiseSm89 +{ + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + using ElementA = ElementType; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementType; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = OutElementType; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + + using ElementOutput = OutElementType; + using LayoutOutput = cutlass::layout::RowMajor; + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + using ElementAccumulator = AccumElementType; + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm89; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + // Number of epilogue stages in EVT + static constexpr int EVTEpilogueStages = 1; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout; + + // Definition of EVT + using accSrc = cutlass::epilogue::threadblock::VisitorAccFetch; + + using ComputeBScale = cutlass::epilogue::threadblock::VisitorCompute; + using bScaleSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using EpilogueBScale = cutlass::epilogue::threadblock::Sm80EVT; + + using ComputeAScale = cutlass::epilogue::threadblock::VisitorCompute; + using aScaleSrc = cutlass::epilogue::threadblock::VisitorColBroadcast>; + using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; + + // // With bias + // using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + // using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute; + // using EpilogueAScaleWithBias = cutlass::epilogue::threadblock::Sm80EVT; + + + using dTar = cutlass::epilogue::threadblock::VisitorAuxStore>; + using EpilogueStore = cutlass::epilogue::threadblock::Sm80EVT; + // using EpilogueStore = cutlass::platform::conditional, + // cutlass::epilogue::threadblock::Sm80EVT>::type; + + + using EpilogueOp = EpilogueStore; + + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + + +template +typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) +{ + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + + // int const lda = k; + // int const ldb = k; + // int const ldc = n; + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + + int64_t lda = a.stride(0); + int64_t ldb = b.stride(1); + int64_t ldc = out.stride(0); + + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode + {m, n, k}, // Problem size + 1, // Split-k factor + {}, // Epilogue args + ptr_a, // a pointer + ptr_b, // b pointer + nullptr, // c pointer (unused) + nullptr, // d pointer (unused) + m * k, // batch stride a (unused) + n * k, // batch stride b (unused) + m * n, // batch stride c (unused) + m * n, // batch stride d (unused) + lda, // stride a + ldb, // stride b + ldc, // stride c (unused) + ldc); // stride d (unused) + + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), + {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + return args; +} + +template +void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) +{ + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + Gemm gemm_op; + // CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + // auto status = gemm_op.run(args, workspace.data_ptr(), stream); + auto status = gemm_op(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) + // return typedFp8RowwiseGemmKernelLauncher( + // Gemm{}, args, D, A, B, C_bias, workspace, workspaceBytes, stream, occupancy); +} + + +template +void s89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(32), next_pow_2(m)); // next power of 2 + + uint32_t const n = out.size(1); + uint32_t const np2 = next_pow_2(n); + + if (mp2 <= 16) { + // M in [1, 16] + if (np2 <= 8192) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 24576) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 32) { + // M in (16, 32] + if (np2 <= 8192) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 64) { + // M in (32, 64] + if (np2 <= 8192) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 128) { + // M in (64, 128] + if (np2 <= 8192) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 16384) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } + } else if (mp2 <= 256) { + // M in (128, 256] + if (np2 <= 4096) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } else { + // M in (256, inf) + if (np2 <= 4096) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } else if (np2 <= 8192) { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + } else { + return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + } + } +} + +template +typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) +{ + using ElementT = typename Gemm::ElementA; + using ElementOutput = typename Gemm::ElementD; + using ElementComputeEpilogue = float; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + int32_t m = a.size(0); + int32_t n = b.size(1); + int32_t k = a.size(1); + ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); + ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); + ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); + ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + + // TODO: confirm correctess + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1)); + StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1)); + StrideC stride_c; + StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1)); + typename Gemm::Arguments args + = {cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, {ptr_a, stride_a, ptr_b, stride_b}, + {{}, // epilogue.thread + nullptr, stride_c, ptr_d, stride_d}}; + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + return args; +} + +template +void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) +{ + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using MainloopScheduleType = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + using TileSchedulerType = void; + using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm; + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); + + // Launch the CUTLASS GEMM kernel. + Gemm gemm_op; + // CUTLASS_CHECK(gemm_op.can_implement(args)); + + size_t workspace_size = gemm_op.get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + + auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement = gemm_op.can_implement(args); + TORCH_CHECK(can_implement == cutlass::Status::kSuccess) + + auto status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess) + // cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); + // CUTLASS_CHECK(status); +// return typedFp8RowwiseGemmKernelLauncher( +// Gemm{}, args, D, A, B, C_bias, workspace, workspaceBytes, stream, occupancy); +// #else // COMPILE_HOPPER_TMA_GEMMS +// throw std::runtime_error( +// "[TensorRT-LLm Error][Fp8RowwiseGemmKernelLauncherSm90] Please recompile with support for hopper by passing " +// "90-real as an arch to build_wheel.py."); +// #endif // COMPILE_HOPPER_TMA_GEMMS +} + +template +void s90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + uint32_t const m = a.size(0); + uint32_t const mp2 = + std::max(static_cast(64), next_pow_2(m)); // next power of 2 + + if (mp2 <= 64) { + // m in [1, 64] + return launch_sm90_fp8_scaled_mm, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); + } else if (mp2 <= 128) { + // m in (64, 128] + return launch_sm90_fp8_scaled_mm, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + } else { + // m in (128, inf) + return launch_sm90_fp8_scaled_mm, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + } +} + +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias) { + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment"); +// TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment"); +//TODO: % 8 + TORCH_CHECK(mat_b.size(1) % 16 == 0, "mat_b.size(1) must be multiple of 16 for memory alignment"); // out.stride(0) + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b msut be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + + auto sm_version = getSMVersion(); + + if (sm_version >= 90) { + if (out_dtype == torch::kBFloat16) { + s90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + s90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + s89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + s89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } + } else { + TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability: ", sm_version); + } + + return out; +} diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index 6ed543e6c54..b12d324cc62 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -17,6 +17,10 @@ torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& ma const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias); +torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, const torch::Dtype& out_dtype, + const c10::optional& bias); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -26,4 +30,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); // int8_scaled_mm m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); + // fp8_scaled_mm + m.def("fp8_scaled_mm", &fp8_scaled_mm, "FP8 scaled matmul (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp index 2fed2d60c03..5820b1350ab 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/utils.hpp +++ b/sgl-kernel/src/sgl-kernel/csrc/utils.hpp @@ -44,3 +44,8 @@ inline int getSMVersion() { CHECK_CUDA_SUCCESS(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device)); return sm_major * 10 + sm_minor; } + +inline uint32_t next_pow_2(uint32_t const num) { + if (num <= 1) return num; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index e388ae35653..f339997b027 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -2,6 +2,7 @@ from sgl_kernel.ops._kernels import dispose as _dispose from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm +from sgl_kernel.ops._kernels import fp8_scaled_mm as _fp8_scaled_mm from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size @@ -48,3 +49,13 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): out_dtype, bias, ) + +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return _fp8_scaled_mm( + mat_a, + mat_b, + scales_a, + scales_b, + out_dtype, + bias, + ) diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py index a233b3b435a..c303bef1d1e 100644 --- a/sgl-kernel/tests/test_fp8_gemm.py +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -3,6 +3,7 @@ import torch from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant +from sgl_kernel import fp8_scaled_mm def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): @@ -16,15 +17,16 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): return final -class TestInt8Gemm(unittest.TestCase): +class TestFp8Gemm(unittest.TestCase): def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): - a = torch.randn((M, K), device=device) * 5 - b = torch.randn((N, K), device=device) * 5 + a = torch.randn((M, K), device=device) + b = torch.randn((N, K), device=device) - scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) - scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) * 0.01 + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) * 0.01 if with_bias: - bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 + # bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 + bias = torch.randn((N,), device="cuda", dtype=out_dtype) else: bias = None o1 = torch.empty((a.shape[0], b.shape[1]), device="cuda", dtype=torch.bfloat16) @@ -32,9 +34,10 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): b_fp8 = b_fp8.t() a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) o = torch_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) - o1 = vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) + # o1 = vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) + o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) max_val = max(o.abs().max().item(), o1.abs().max().item()) - rtol = 2e-2 + rtol = 4e-3 atol = max_val * rtol torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") @@ -43,8 +46,13 @@ def test_accuracy(self): Ms = [1, 128, 512, 1024, 4096] Ns = [16, 128, 512, 1024, 4096] Ks = [512, 1024, 4096, 8192, 16384] - bias_opts = [True, False] - out_dtypes = [torch.bfloat16] + # Ms = [128] + # Ns = [512] + # Ks = [4096] + # bias_opts = [True, False] + bias_opts = [False] + out_dtypes = [torch.bfloat16, torch.float16] + # out_dtypes = [torch.float16] for M in Ms: for N in Ns: for K in Ks: From 4cac9fb925f58e3e90fa5c7053ad10d42afa099b Mon Sep 17 00:00:00 2001 From: HandH1998 <1335248067@qq.com> Date: Thu, 9 Jan 2025 17:41:46 +0800 Subject: [PATCH 03/32] support bias --- .../src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 505 +++++++++--------- sgl-kernel/tests/test_fp8_gemm.py | 21 +- 2 files changed, 262 insertions(+), 264 deletions(-) diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu index 79532893063..ef88110e925 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -1,32 +1,9 @@ -/* - * Copyright (c) 2022-2024, Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - * - * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_template.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm89.h +// https://github.com/NVIDIA/TensorRT-LLM/blob/v0.16.0/cpp/tensorrt_llm/kernels/cutlass_kernels/fp8_rowwise_gemm/fp8_rowwise_gemm_kernel_template_sm90.h #pragma once -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#endif // __GNUC__ - #include #include @@ -35,13 +12,6 @@ // Order matters here, packed_stride.hpp is missing cute and convolution includes #include "cutlass/util/packed_stride.hpp" -#ifdef __GNUC__ // Check if the compiler is GCC or Clang -#pragma GCC diagnostic pop -#endif // __GNUC__ - -// #include "fp8_rowwise_gemm_kernel_template_sm89.h" -// #include "fp8_rowwise_gemm_kernel_template_sm90.h" - #include "cutlass/cutlass.h" #include "cutlass/gemm/device/gemm.h" #include "cutlass/gemm/device/gemm_universal_adapter.h" @@ -59,127 +29,11 @@ #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/device/gemm_universal_adapter.h" - #include "utils.hpp" using namespace cute; -template -struct DeviceGemmFp8RowwiseSm90 -{ - static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); - - // A matrix configuration - using ElementA = ElementType; // Element type for A matrix operand - using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand - static constexpr int AlignmentA - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A - // matrix in units of elements (up to 16 bytes) - - // B matrix configuration - using ElementB = ElementType; // Element type for B matrix operand - using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand - static constexpr int AlignmentB - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B - // matrix in units of elements (up to 16 bytes) - - // C/D matrix configuration - using ElementC = void; // Element type for C matrix operands - using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands - static constexpr int AlignmentC - = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in - // units of elements (up to 16 bytes) - - // Output matrix configuration - using ElementOutput = OutElementType; // Element type for output matrix operands - using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands - static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; - - // Auxiliary matrix configuration and other fusion types - using ElementBias = float; - - // Multiply-accumulate blocking/pipelining details - using ElementAccumulator = AccumElementType; // Element type for internal accumulation - using ElementCompute = float; // Element type for compute - using ElementComputeEpilogue = float; - using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature - using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag - using TileShape = CTAShape; // Threadblock-level tile size - using TileScheduler = TileSchedulerType; - - static constexpr bool PONG = false; - static constexpr bool FAST_ACCUM = true; - static constexpr bool USE_BIAS = false; - - using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized - // based on the tile size - using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default - // setting in the Collective Builder - // Implement rowwise scaling epilogue. - using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<0>, cute::Int<0>>>; - - using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementBias, ElementBias, - cute::Stride, cute::Int<1>, cute::Int<0>>>; - - using Accum = cutlass::epilogue::fusion::Sm90AccFetch; - - using Compute0 = cutlass::epilogue::fusion::Sm90Compute; - - using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; - - using Compute1 = cutlass::epilogue::fusion::Sm90Compute; - - using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; - - using ComputeBias = cutlass::epilogue::fusion::Sm90Compute; - - using EVTComputeBias = cutlass::epilogue::fusion::Sm90EVT; - - using EpilogueEVT = EVTCompute1; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; - - using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; - using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; - using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; - using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; - - using SlowAccum = DefaultSchedule; - using FastAccum = FastDefaultSchedule; - using MainLoopSchedule = cute::conditional_t; - - using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder( - sizeof(typename CollectiveEpilogue::SharedStorage))>, - MainLoopSchedule>::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape - CollectiveMainloop, CollectiveEpilogue, TileScheduler>; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; -}; - template -// template + typename WarpShape, int Stages, bool WithBias> struct DeviceGemmFp8RowwiseSm89 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -227,20 +81,17 @@ struct DeviceGemmFp8RowwiseSm89 Stride<_1, _0, _0>>; using EpilogueAScale = cutlass::epilogue::threadblock::Sm80EVT; - // // With bias - // using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; - // using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute; - // using EpilogueAScaleWithBias = cutlass::epilogue::threadblock::Sm80EVT; - + // With bias + using biasSrc = cutlass::epilogue::threadblock::VisitorRowBroadcast>; + using ComputeAScaleWithBias = cutlass::epilogue::threadblock::VisitorCompute; + using EpilogueAScaleWithBias = cutlass::epilogue::threadblock::Sm80EVT; using dTar = cutlass::epilogue::threadblock::VisitorAuxStore>; - using EpilogueStore = cutlass::epilogue::threadblock::Sm80EVT; - // using EpilogueStore = cutlass::platform::conditional, - // cutlass::epilogue::threadblock::Sm80EVT>::type; + using EpilogueStore = typename cutlass::platform::conditional, + cutlass::epilogue::threadblock::Sm80EVT>::type; - using EpilogueOp = EpilogueStore; using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor +template typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) @@ -262,9 +113,6 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: using ElementOutput = typename Gemm::ElementD; using ElementComputeEpilogue = float; - // int const lda = k; - // int const ldb = k; - // int const ldc = n; int32_t m = a.size(0); int32_t n = b.size(1); int32_t k = a.size(1); @@ -275,16 +123,22 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); + typename Gemm::Arguments args(cutlass::gemm::GemmUniversalMode::kGemm, // Mode {m, n, k}, // Problem size 1, // Split-k factor {}, // Epilogue args - ptr_a, // a pointer - ptr_b, // b pointer + ptr_a, // a pointer + ptr_b, // b pointer nullptr, // c pointer (unused) nullptr, // d pointer (unused) m * k, // batch stride a (unused) @@ -295,8 +149,22 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: ldb, // stride b ldc, // stride c (unused) ldc); // stride d (unused) - - args.epilogue = { + if constexpr (WithBias) { + args.epilogue = { + { + { + {}, // Accumulator + {ptr_scales_b, ElementComputeEpilogue(0), + {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_scales_a, ElementComputeEpilogue(0), {_1{}, _0{}, _0{}}}, + {ptr_bias, ElementOutput(0), {_0{}, _1{}, _0{}}}, + {} // Multiplies + }, + {ptr_d, {n, _1{}, _0{}}}}; + } else { + args.epilogue = { { { {}, // Accumulator @@ -308,45 +176,53 @@ typename Gemm::Arguments prepare_sm89_fp8_args(torch::Tensor& out, const torch:: {} // Multiplies }, {ptr_d, {n, _1{}, _0{}}}}; + } + return args; } -template +template void launch_sm89_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) { - using ElementInput = cutlass::float_e4m3_t; - using ElementOutput = OutType; - using AccumElementType = float; - - using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; - - auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); + auto args = prepare_sm89_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; - // CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto can_implement = gemm_op.can_implement(args); TORCH_CHECK(can_implement == cutlass::Status::kSuccess) - // auto status = gemm_op.run(args, workspace.data_ptr(), stream); auto status = gemm_op(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess) - // return typedFp8RowwiseGemmKernelLauncher( - // Gemm{}, args, D, A, B, C_bias, workspace, workspaceBytes, stream, occupancy); +} + +template +void sm89_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm89::Gemm; + return launch_sm89_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } } template -void s89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, +void sm89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) { uint32_t const m = a.size(0); @@ -359,59 +235,170 @@ void s89_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch: if (mp2 <= 16) { // M in [1, 16] if (np2 <= 8192) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else if (np2 <= 24576) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } } else if (mp2 <= 32) { // M in (16, 32] if (np2 <= 8192) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else if (np2 <= 16384) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 4>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<16, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } } else if (mp2 <= 64) { // M in (32, 64] if (np2 <= 8192) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else if (np2 <= 16384) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<32, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } } else if (mp2 <= 128) { // M in (64, 128] if (np2 <= 8192) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); } else if (np2 <= 16384) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); } } else if (mp2 <= 256) { // M in (128, 256] if (np2 <= 4096) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } } else { // M in (256, inf) if (np2 <= 4096) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } else if (np2 <= 8192) { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 3>(out, a, b, scales_a, scales_b, bias); } else { - return launch_sm89_fp8_scaled_mm, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); + return sm89_dispatch_bias, cutlass::gemm::GemmShape<64, 64, 64>, 5>(out, a, b, scales_a, scales_b, bias); } } } -template +template +struct DeviceGemmFp8RowwiseSm90 +{ + static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); + + // A matrix configuration + using ElementA = ElementType; // Element type for A matrix operand + using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand + static constexpr int AlignmentA + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A + // matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using ElementB = ElementType; // Element type for B matrix operand + using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand + static constexpr int AlignmentB + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B + // matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using ElementC = void; // Element type for C matrix operands + using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands + static constexpr int AlignmentC + = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in + // units of elements (up to 16 bytes) + + // Output matrix configuration + using ElementOutput = OutElementType; // Element type for output matrix operands + using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands + static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + + // // Auxiliary matrix configuration and other fusion types + // using ElementBias = float; + + // Multiply-accumulate blocking/pipelining details + using ElementAccumulator = AccumElementType; // Element type for internal accumulation + using ElementCompute = float; // Element type for compute + using ElementComputeEpilogue = float; + using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using TileShape = CTAShape; // Threadblock-level tile size + using TileScheduler = TileSchedulerType; + + static constexpr bool PONG = false; + static constexpr bool FAST_ACCUM = true; + static constexpr bool USE_BIAS = false; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized + // based on the tile size + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default + // setting in the Collective Builder + // Implement rowwise scaling epilogue. + using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<0>, cute::Int<0>>>; + + using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementComputeEpilogue, ElementComputeEpilogue, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<0, TileShape, ElementOutput, ElementOutput, + cute::Stride, cute::Int<1>, cute::Int<0>>>; + + using Accum = cutlass::epilogue::fusion::Sm90AccFetch; + + using Compute0 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT; + + using Compute1 = cutlass::epilogue::fusion::Sm90Compute; + + using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT; + + // With bias + using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute; + using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT; + + using EpilogueEVT = typename cutlass::platform::conditional::type; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp; + + using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized; + using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong; + using FastDefaultSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using FastPongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum; + + using SlowAccum = DefaultSchedule; + using FastAccum = FastDefaultSchedule; + using MainLoopSchedule = cute::conditional_t; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainLoopSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal, // Indicates ProblemShape + CollectiveMainloop, CollectiveEpilogue, TileScheduler>; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) @@ -429,6 +416,11 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: int32_t k = a.size(1); ElementT const* ptr_a = reinterpret_cast(a.data_ptr()); ElementT const* ptr_b = reinterpret_cast(b.data_ptr()); + ElementOutput const* ptr_bias = nullptr; + if constexpr (WithBias) { + TORCH_CHECK(bias.has_value()) + ptr_bias = reinterpret_cast(bias.value().data_ptr()); + } ElementOutput* ptr_d = reinterpret_cast(out.data_ptr()); ElementComputeEpilogue const* ptr_scales_a = reinterpret_cast(scales_a.data_ptr()); ElementComputeEpilogue const* ptr_scales_b = reinterpret_cast(scales_b.data_ptr()); @@ -442,41 +434,42 @@ typename Gemm::Arguments prepare_sm90_fp8_args(torch::Tensor& out, const torch:: = {cutlass::gemm::GemmUniversalMode::kGemm, {m, n, k, 1}, {ptr_a, stride_a, ptr_b, stride_b}, {{}, // epilogue.thread nullptr, stride_c, ptr_d, stride_d}}; - args.epilogue.thread = { - {ptr_scales_a}, - { - {ptr_scales_b}, {}, // Accumulator - {} // Multiplies - }, - {}, // Multiplies - }; + if constexpr (WithBias) { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, {}, // Accumulator + {} // Multiplies + }, + {ptr_bias}, + {}, // Multiplies + }; + } else { + args.epilogue.thread = { + {ptr_scales_a}, + { + {ptr_scales_b}, {}, // Accumulator + {} // Multiplies + }, + {}, // Multiplies + }; + } + return args; } -template +template void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) { - using ElementInput = cutlass::float_e4m3_t; - using ElementOutput = OutType; - using AccumElementType = float; - using MainloopScheduleType = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; - using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; - using TileSchedulerType = void; - using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm; - auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); - - // Launch the CUTLASS GEMM kernel. + auto args = prepare_sm90_fp8_args(out, a, b, scales_a, scales_b, bias); Gemm gemm_op; - // CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); auto const workspace_options = torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); auto workspace = torch::empty(workspace_size, workspace_options); - auto stream = at::cuda::getCurrentCUDAStream(a.get_device()); auto can_implement = gemm_op.can_implement(args); @@ -484,19 +477,32 @@ void launch_sm90_fp8_scaled_mm(torch::Tensor& out, const torch::Tensor& a, const auto status = gemm_op.run(args, workspace.data_ptr(), stream); TORCH_CHECK(status == cutlass::Status::kSuccess) - // cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); - // CUTLASS_CHECK(status); -// return typedFp8RowwiseGemmKernelLauncher( -// Gemm{}, args, D, A, B, C_bias, workspace, workspaceBytes, stream, occupancy); -// #else // COMPILE_HOPPER_TMA_GEMMS -// throw std::runtime_error( -// "[TensorRT-LLm Error][Fp8RowwiseGemmKernelLauncherSm90] Please recompile with support for hopper by passing " -// "90-real as an arch to build_wheel.py."); -// #endif // COMPILE_HOPPER_TMA_GEMMS +} + + +template +void sm90_dispatch_bias(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, + const torch::Tensor& scales_b, + const c10::optional& bias) { + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutType; + using AccumElementType = float; + using MainloopScheduleType = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum; + using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized; + using TileSchedulerType = void; + if (bias) { + using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } else { + using Gemm = typename DeviceGemmFp8RowwiseSm90::Gemm; + return launch_sm90_fp8_scaled_mm(out, a, b, scales_a, scales_b, bias); + } } template -void s90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, +void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const c10::optional& bias) { uint32_t const m = a.size(0); @@ -505,13 +511,13 @@ void s90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch: if (mp2 <= 64) { // m in [1, 64] - return launch_sm90_fp8_scaled_mm, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); + return sm90_dispatch_bias, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); } else if (mp2 <= 128) { // m in (64, 128] - return launch_sm90_fp8_scaled_mm, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); } else { // m in (128, inf) - return launch_sm90_fp8_scaled_mm, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); } } @@ -526,10 +532,8 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); - TORCH_CHECK(mat_a.size(1) % 16 == 0, "mat_a.size(1) must be multiple of 16 for memory alignment"); -// TORCH_CHECK(mat_b.size(0) % 16 == 0, "mat_b.size(0) must be multiple of 16 for memory alignment"); -//TODO: % 8 - TORCH_CHECK(mat_b.size(1) % 16 == 0, "mat_b.size(1) must be multiple of 16 for memory alignment"); // out.stride(0) + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); @@ -548,21 +552,22 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat } torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); auto sm_version = getSMVersion(); if (sm_version >= 90) { - if (out_dtype == torch::kBFloat16) { - s90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); - } else { - s90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); - } + if (out_dtype == torch::kBFloat16) { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm90_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } } else if (sm_version == 89) { - if (out_dtype == torch::kBFloat16) { - s89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); - } else { - s89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); - } + if (out_dtype == torch::kBFloat16) { + sm89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } else { + sm89_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias); + } } else { TORCH_CHECK_NOT_IMPLEMENTED(false, "No implemented int8_scaled_mm for current compute capability: ", sm_version); } diff --git a/sgl-kernel/tests/test_fp8_gemm.py b/sgl-kernel/tests/test_fp8_gemm.py index c303bef1d1e..2a474d7ea17 100644 --- a/sgl-kernel/tests/test_fp8_gemm.py +++ b/sgl-kernel/tests/test_fp8_gemm.py @@ -1,7 +1,6 @@ import unittest import torch -from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant from sgl_kernel import fp8_scaled_mm @@ -13,6 +12,8 @@ def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias): temp1 = o * scale_a.view(-1, 1) temp2 = temp1 * scale_b.view(1, -1) final = temp2.to(out_dtype) + if bias is not None: + final = final + bias.view(1, -1) return final @@ -22,10 +23,9 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): a = torch.randn((M, K), device=device) b = torch.randn((N, K), device=device) - scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) * 0.01 - scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) * 0.01 + scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) * 0.001 + scale_b = torch.randn((N,), device="cuda", dtype=torch.float32) * 0.001 if with_bias: - # bias = torch.ones((N,), device="cuda", dtype=out_dtype) * 10 bias = torch.randn((N,), device="cuda", dtype=out_dtype) else: bias = None @@ -34,11 +34,9 @@ def _test_accuracy_once(self, M, N, K, with_bias, out_dtype, device): b_fp8 = b_fp8.t() a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) o = torch_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) - # o1 = vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) o1 = fp8_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, out_dtype, bias) - max_val = max(o.abs().max().item(), o1.abs().max().item()) - rtol = 4e-3 - atol = max_val * rtol + rtol = 0.01 + atol = 0.1 torch.testing.assert_close(o, o1, rtol=rtol, atol=atol) print(f"M={M}, N={N}, K={K}, with_bias={with_bias}, out_dtype={out_dtype}: OK") @@ -46,13 +44,8 @@ def test_accuracy(self): Ms = [1, 128, 512, 1024, 4096] Ns = [16, 128, 512, 1024, 4096] Ks = [512, 1024, 4096, 8192, 16384] - # Ms = [128] - # Ns = [512] - # Ks = [4096] - # bias_opts = [True, False] - bias_opts = [False] + bias_opts = [True, False] out_dtypes = [torch.bfloat16, torch.float16] - # out_dtypes = [torch.float16] for M in Ms: for N in Ns: for K in Ks: From ecc90a484fb6a150d4a76b760baad4640b2ae064 Mon Sep 17 00:00:00 2001 From: yych0745 <1398089567@qq.com> Date: Fri, 10 Jan 2025 17:22:15 +0800 Subject: [PATCH 04/32] opitmize --- sgl-kernel/benchmark/bench_fp8_gemm.py | 89 +++++++++- .../benchmark/bench_int8_res/results.html | 3 + sgl-kernel/benchmark/best_fp8_configs.json | 42 +++++ sgl-kernel/setup.py | 38 ++++- sgl-kernel/src/sgl-kernel/__init__.py | 2 + .../src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 159 ++++++++++++++++-- .../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 15 +- sgl-kernel/src/sgl-kernel/ops/__init__.py | 22 ++- 8 files changed, 338 insertions(+), 32 deletions(-) create mode 100644 sgl-kernel/benchmark/bench_int8_res/results.html create mode 100644 sgl-kernel/benchmark/best_fp8_configs.json diff --git a/sgl-kernel/benchmark/bench_fp8_gemm.py b/sgl-kernel/benchmark/bench_fp8_gemm.py index d4bc2fdb91a..65efce4417c 100644 --- a/sgl-kernel/benchmark/bench_fp8_gemm.py +++ b/sgl-kernel/benchmark/bench_fp8_gemm.py @@ -4,8 +4,9 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant - - +from sgl_kernel import fp8_scaled_mm as sgl_scaled_mm +from sgl_kernel import fp8_scaled_mm_profile as sgl_scaled_mm_profile +import time def to_int8(tensor: torch.Tensor) -> torch.Tensor: return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8) @@ -16,16 +17,18 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048], x_log=False, line_arg="provider", - line_vals=["vllm-fp8", "torch-fp8"], - line_names=["vllm-fp8", "torch-fp8"], - styles=[("green", "-"), ("blue", "-")], + # line_vals=["vllm-fp8", "torch-fp8", "sglang-fp8"], + # line_names=["vllm-fp8", "torch-fp8", "sglang-fp8"], + line_vals=["vllm-fp8", "sglang-fp8", "sglang-fp8-profile"], + line_names=["vllm-fp8", "sglang-fp8", "sglang-fp8-profile"], + styles=[("green", "-"), ("blue", "-"), ("red", "-")], ylabel="GB/s", plot_name="int8 scaled matmul", args={}, ) ) def benchmark(batch_size, provider): - M, N, K = batch_size, 8192, 21760 + M, N, K = batch_size, 4096, 8192 a = torch.ones((M, K), device="cuda") * 5.0 b = torch.ones((N, K), device="cuda") * 5.0 scale_a = torch.randn((M,), device="cuda", dtype=torch.float32) @@ -64,6 +67,80 @@ def benchmark(batch_size, provider): except RuntimeError as e: print("Error details:", e) raise + if provider == "sglang-fp8": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, torch.bfloat16), + quantiles=quantiles, + ) + if provider == "sglang-fp8-profile": + best_configs = [] + times = [] + valid_configs = [] + best_config_info = {} # 新增:用于存储每个输入规模的最优配置信息 + + try: + sgl_scaled_mm_profile(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, torch.bfloat16, bias=None, config_id=35) + except RuntimeError as e: + print(f"Skip config_id 35 due to error: {e}") + + for config_id in range(1, 7): + try: + torch.cuda.synchronize() + start = time.time() + sgl_scaled_mm_profile(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, + torch.bfloat16, bias=None, config_id=config_id) + torch.cuda.synchronize() + end = time.time() + times.append(end - start) + valid_configs.append(config_id) + print(f"config_id: {config_id}, time: {end - start}") + except RuntimeError as e: + print(f"Skip config_id {config_id} due to error: {e}") + continue + + if not valid_configs: + print("No valid config found") + return 0, 0, 0 + + min_time = float('inf') + best_config = None + for i, config_id in enumerate(valid_configs): + if times[i] < min_time: + min_time = times[i] + best_config = config_id + + # 记录当前输入规模的最优配置 + best_config_info[f"M{M}_N{N}_K{K}"] = { + "best_config": best_config, + "time": min_time, + "batch_size": batch_size + } + + # 将最优配置信息保存到文件 + import json + config_file = "best_fp8_configs.json" + try: + with open(config_file, "r") as f: + existing_configs = json.load(f) + except FileNotFoundError: + existing_configs = {} + + existing_configs.update(best_config_info) + with open(config_file, "w") as f: + json.dump(existing_configs, f, indent=4) + + print(f"Best config for batch_size={batch_size}: config_id={best_config}, time={min_time:.6f}s") + + # 使用最佳配置进行基准测试 + try: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: sgl_scaled_mm_profile(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, torch.bfloat16, bias=None, config_id=best_config), + quantiles=quantiles, + ) + except RuntimeError as e: + print("Error details:", e) + print(f"config_id is not valid {best_config}") + ms, min_ms, max_ms = 1, 1, 1 gbps = lambda ms: (2 * M * N * K + M * N) * a.element_size() * 1e-9 / (ms * 1e-3) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/sgl-kernel/benchmark/bench_int8_res/results.html b/sgl-kernel/benchmark/bench_int8_res/results.html new file mode 100644 index 00000000000..f8f21993bfa --- /dev/null +++ b/sgl-kernel/benchmark/bench_int8_res/results.html @@ -0,0 +1,3 @@ + + + diff --git a/sgl-kernel/benchmark/best_fp8_configs.json b/sgl-kernel/benchmark/best_fp8_configs.json new file mode 100644 index 00000000000..cff052cfd25 --- /dev/null +++ b/sgl-kernel/benchmark/best_fp8_configs.json @@ -0,0 +1,42 @@ +{ + "M1_N4096_K8192": { + "best_config": 6, + "time": 6.532669067382812e-05, + "batch_size": 1 + }, + "M16_N4096_K8192": { + "best_config": 6, + "time": 6.699562072753906e-05, + "batch_size": 16 + }, + "M64_N4096_K8192": { + "best_config": 6, + "time": 6.67572021484375e-05, + "batch_size": 64 + }, + "M128_N4096_K8192": { + "best_config": 6, + "time": 6.699562072753906e-05, + "batch_size": 128 + }, + "M256_N4096_K8192": { + "best_config": 6, + "time": 6.842613220214844e-05, + "batch_size": 256 + }, + "M512_N4096_K8192": { + "best_config": 6, + "time": 0.00012421607971191406, + "batch_size": 512 + }, + "M1024_N4096_K8192": { + "best_config": 6, + "time": 0.00023627281188964844, + "batch_size": 1024 + }, + "M2048_N4096_K8192": { + "best_config": 6, + "time": 0.00045871734619140625, + "batch_size": 2048 + } +} \ No newline at end of file diff --git a/sgl-kernel/setup.py b/sgl-kernel/setup.py index 3a60f6ba0a6..aaa0a53dc89 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -2,6 +2,9 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import os +import sys +import multiprocessing root = Path(__file__).parent.resolve() @@ -23,19 +26,32 @@ def update_wheel_platform_tag(): cutlass = root / "3rdparty" / "cutlass" +nlohmann = root / "3rdparty" / "nlohmann" + include_dirs = [ cutlass.resolve() / "include", cutlass.resolve() / "tools" / "util" / "include", root / "src" / "sgl-kernel" / "csrc", + nlohmann.resolve(), ] + +# nvcc_flags = [ +# "-O3", +# "-Xcompiler", +# "-fPIC", +# "-gencode=arch=compute_75,code=sm_75", +# "-gencode=arch=compute_80,code=sm_80", +# "-gencode=arch=compute_89,code=sm_89", +# "-gencode=arch=compute_90,code=sm_90", +# "-U__CUDA_NO_HALF_OPERATORS__", +# "-U__CUDA_NO_HALF2_OPERATORS__", +# ] nvcc_flags = [ "-O3", "-Xcompiler", "-fPIC", - "-gencode=arch=compute_75,code=sm_75", - "-gencode=arch=compute_80,code=sm_80", + # 只保留需要的架构 "-gencode=arch=compute_89,code=sm_89", - "-gencode=arch=compute_90,code=sm_90", "-U__CUDA_NO_HALF_OPERATORS__", "-U__CUDA_NO_HALF2_OPERATORS__", ] @@ -49,7 +65,7 @@ def update_wheel_platform_tag(): "src/sgl-kernel/csrc/trt_reduce_internal.cu", "src/sgl-kernel/csrc/trt_reduce_kernel.cu", "src/sgl-kernel/csrc/moe_align_kernel.cu", - "src/sgl-kernel/csrc/int8_gemm_kernel.cu", + # "src/sgl-kernel/csrc/int8_gemm_kernel.cu", "src/sgl-kernel/csrc/fp8_gemm_kernel.cu", "src/sgl-kernel/csrc/sgl_kernel_ops.cu", ], @@ -63,6 +79,20 @@ def update_wheel_platform_tag(): ), ] +def set_parallel_jobs(): + if sys.platform == 'win32': + num_cores = int(os.environ.get('NUMBER_OF_PROCESSORS', 4)) + else: + num_cores = len(os.sched_getaffinity(0)) if hasattr(os, 'sched_getaffinity') else os.cpu_count() + + # 限制并行度为核心数的1/4或更少 + num_jobs = max(1, num_cores // 2) + os.environ['MAX_JOBS'] = str(num_jobs) + + # 设置CUDA编译的并行任务数 + os.environ['CUDA_NVCC_THREADS'] = str(num_jobs) + return num_jobs +set_parallel_jobs() setup( name="sgl-kernel", version=get_version(), diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index 2a4a2bd5177..06894c3358e 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -4,6 +4,7 @@ init_custom_reduce, int8_scaled_mm, fp8_scaled_mm, + fp8_scaled_mm_profile, moe_align_block_size, ) @@ -14,4 +15,5 @@ "custom_reduce", "int8_scaled_mm", "fp8_scaled_mm", + "fp8_scaled_mm_profile", ] diff --git a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu index ef88110e925..914d1cb4df8 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/fp8_gemm_kernel.cu @@ -33,7 +33,10 @@ using namespace cute; template + typename WarpShape, int Stages, bool WithBias, + typename FP8MathOperator = cutlass::arch::OpMultiplyAdd, + template typename EpilogueVisitor = cutlass::epilogue::threadblock::Sm80EVT, + typename ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>> struct DeviceGemmFp8RowwiseSm89 { static_assert(std::is_same_v, "ElementType must be FP8(e4m3)"); @@ -97,8 +100,8 @@ struct DeviceGemmFp8RowwiseSm89 using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor::GemmKernel; + WarpShape, InstructionShape, EpilogueOp, ThreadblockSwizzle, + Stages, FP8MathOperator, EVTEpilogueStages>::GemmKernel; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; }; @@ -509,16 +512,16 @@ void sm90_dispatch_shape(torch::Tensor& out, const torch::Tensor& a, const torch uint32_t const mp2 = std::max(static_cast(64), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { - // m in [1, 64] - return sm90_dispatch_bias, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); - } else if (mp2 <= 128) { - // m in (64, 128] - return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); - } else { - // m in (128, inf) - return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); - } + // if (mp2 <= 64) { + // // m in [1, 64] + // return sm90_dispatch_bias, Shape<_1, _8, _1>>(out, a, b, scales_a, scales_b, bias); + // } else if (mp2 <= 128) { + // // m in (64, 128] + // return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + // } else { + // // m in (128, inf) + // return sm90_dispatch_bias, Shape<_2, _1, _1>>(out, a, b, scales_a, scales_b, bias); + // } } torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, @@ -574,3 +577,133 @@ torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat return out; } + + +#define DISPATCH_FP8_GEMM_CONFIG(TB_M, TB_N, TB_K, WP_M, WP_N, WP_K, STAGES) \ + sm89_dispatch_bias, \ + cutlass::gemm::GemmShape, STAGES>(out, mat_a, mat_b, scales_a, scales_b, bias) +// 定义一个宏来生成一组配置的所有stages +#define DISPATCH_FP8_GEMM_GROUP(GROUP_ID, CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, BASE_CASE) \ + case BASE_CASE: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 2); break; \ + case BASE_CASE + 1: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 3); break; \ + case BASE_CASE + 2: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 4); break; \ + case BASE_CASE + 3: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 5); break; \ + case BASE_CASE + 4: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 6); break; \ + case BASE_CASE + 5: DISPATCH_FP8_GEMM_CONFIG(CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, 7); break; + +template +void sm89_dispatch_shape_profile(torch::Tensor& out, const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const c10::optional& bias, + int config_id) { + switch(config_id) { + case 1: + DISPATCH_FP8_GEMM_CONFIG(32, 64, 128, 16, 64, 64, 5); + case 2: + DISPATCH_FP8_GEMM_CONFIG(16, 64, 128, 16, 64, 64, 5); + case 3: + DISPATCH_FP8_GEMM_CONFIG(64, 64, 128, 32, 64, 64, 5); + case 4: + DISPATCH_FP8_GEMM_CONFIG(64, 128, 64, 32, 64, 64, 5); + case 5: + DISPATCH_FP8_GEMM_CONFIG(128, 128, 64, 64, 32, 64, 2); + case 6: + DISPATCH_FP8_GEMM_CONFIG(64, 128, 64, 32, 64, 64, 6); + // // Group 1: CtaShape32x128x64_WarpShape32x32x64 + // DISPATCH_FP8_GEMM_GROUP(1, 32, 128, 64, 32, 32, 64, 1) + + // // Group 2: CtaShape64x128x64_WarpShape32x64x64 + // DISPATCH_FP8_GEMM_GROUP(2, 64, 128, 64, 32, 64, 64, 7) + + // // Group 3: CtaShape64x64x128_WarpShape32x64x64 + // DISPATCH_FP8_GEMM_GROUP(3, 64, 64, 128, 32, 64, 64, 13) + + // // Group 4: CtaShape64x128x64_WarpShape64x32x64 + // DISPATCH_FP8_GEMM_GROUP(4, 64, 128, 64, 64, 32, 64, 19) + + // // Group 5: CtaShape128x64x64_WarpShape64x32x64 + // DISPATCH_FP8_GEMM_GROUP(5, 128, 64, 64, 64, 32, 64, 25) + + // // Group 6: CtaShape128x128x64_WarpShape64x32x64 + // DISPATCH_FP8_GEMM_GROUP(6, 128, 128, 64, 64, 32, 64, 31) + + // // Group 7: CtaShape128x128x64_WarpShape64x64x64 + // DISPATCH_FP8_GEMM_GROUP(7, 128, 128, 64, 64, 64, 64, 37) + + // // Group 8: CtaShape128x128x64_WarpShape128x32x64 + // DISPATCH_FP8_GEMM_GROUP(8, 128, 128, 64, 128, 32, 64, 43) + + // // Group 9: CtaShape128x256x64_WarpShape64x64x64 + // DISPATCH_FP8_GEMM_GROUP(9, 128, 256, 64, 64, 64, 64, 49) + + // // Group 10: CtaShape256x128x64_WarpShape64x64x64 + // DISPATCH_FP8_GEMM_GROUP(10, 256, 128, 64, 64, 64, 64, 55) + + // // Group 11: CtaShape128x64x128_WarpShape64x32x128 + // DISPATCH_FP8_GEMM_GROUP(11, 128, 64, 128, 64, 32, 128, 61) + + // // Group 12: CtaShape16x256x128_WarpShape16x64x128 + // DISPATCH_FP8_GEMM_GROUP(12, 16, 256, 128, 16, 64, 128, 67) + + // // Group 13: CtaShape16x64x128_WarpShape16x64x64 + // DISPATCH_FP8_GEMM_GROUP(13, 16, 64, 128, 16, 64, 64, 73) + + // // Group 14: CtaShape16x128x64_WarpShape16x64x64 + // DISPATCH_FP8_GEMM_GROUP(14, 16, 128, 64, 16, 64, 64, 79) + + // // Group 15: CtaShape32x64x128_WarpShape16x64x64 + // DISPATCH_FP8_GEMM_GROUP(15, 32, 64, 128, 16, 64, 64, 85) + } +} +torch::Tensor fp8_scaled_mm_profile(const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, const c10::optional& bias, + int config_id) { + + // 基本检查 + TORCH_CHECK(mat_a.is_cuda(), "mat_a must be a CUDA tensor"); + TORCH_CHECK(mat_b.is_cuda(), "mat_b must be a CUDA tensor"); + TORCH_CHECK(mat_a.dim() == 2, "mat_a must be a 2D tensor"); + TORCH_CHECK(mat_b.dim() == 2, "mat_b must be a 2D tensor"); + TORCH_CHECK(mat_a.stride(1) == 1, "mat_a must be a row major tensor"); + TORCH_CHECK(mat_b.stride(0) == 1, "mat_a must be a column major tensor"); + TORCH_CHECK(mat_a.size(1) == mat_b.size(0), "mat_a and mat_b shapes cannot be multiplied"); + + TORCH_CHECK((mat_a.size(1) * mat_a.element_size()) % 16 == 0, "mat_a must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK((mat_b.size(0) * mat_b.element_size()) % 16 == 0, "mat_b must be multiple of 16 bytes for memory alignment"); + TORCH_CHECK(mat_a.scalar_type() == torch::kFloat8_e4m3fn, "mat_a must be Float8_e4m3fn"); + TORCH_CHECK(mat_b.scalar_type() == torch::kFloat8_e4m3fn, "mat_b must be Float8_e4m3fn"); + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, "out_dtype must be Half or BFloat16"); + + // 检查scales + TORCH_CHECK(scales_a.numel() == mat_a.size(0), "size of scales_a is not matched"); + TORCH_CHECK(scales_b.numel() == mat_b.size(1), "size of scales_b is not matched"); + TORCH_CHECK(scales_a.is_contiguous(), "scales_a must be contiguous"); + TORCH_CHECK(scales_b.is_contiguous(), "scales_b must be contiguous"); + TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32, "scales_a must be Float32"); + TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32, "scales_b must be Float32"); + + // 检查bias + if (bias) { + TORCH_CHECK(bias->numel() == mat_b.size(1), "size of bias is not matched"); + TORCH_CHECK(bias->is_contiguous(), "bias must be contiguous"); + TORCH_CHECK(bias->dtype() == out_dtype, "bias dtype must match output dtype"); + } + + torch::Tensor out = torch::empty({mat_a.size(0), mat_b.size(1)}, mat_a.options().dtype(out_dtype)); + TORCH_CHECK((out.size(1) * out.element_size()) % 16 == 0, "out must be multiple of 16 bytes for memory alignment"); + + auto sm_version = getSMVersion(); + + if (sm_version == 89) { + if (out_dtype == torch::kBFloat16) { + sm89_dispatch_shape_profile(out, mat_a, mat_b, scales_a, scales_b, bias, config_id); + } else { + sm89_dispatch_shape_profile(out, mat_a, mat_b, scales_a, scales_b, bias, config_id); + } + } else { + TORCH_CHECK_NOT_IMPLEMENTED(false, "FP8 operations require SM89 GPU architecture"); + } + + return out; +} \ No newline at end of file diff --git a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu index b12d324cc62..4673a13271d 100644 --- a/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu +++ b/sgl-kernel/src/sgl-kernel/csrc/sgl_kernel_ops.cu @@ -13,14 +13,19 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer); // int8_scaled_mm -torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, - const torch::Tensor& scales_b, const torch::Dtype& out_dtype, - const c10::optional& bias); +// torch::Tensor int8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, +// const torch::Tensor& scales_b, const torch::Dtype& out_dtype, +// const c10::optional& bias); torch::Tensor fp8_scaled_mm(const torch::Tensor& mat_a, const torch::Tensor& mat_b, const torch::Tensor& scales_a, const torch::Tensor& scales_b, const torch::Dtype& out_dtype, const c10::optional& bias); +torch::Tensor fp8_scaled_mm_profile(const torch::Tensor& mat_a, const torch::Tensor& mat_b, + const torch::Tensor& scales_a, const torch::Tensor& scales_b, + const torch::Dtype& out_dtype, const c10::optional& bias, + int config_id); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // trt_reduce m.def("init_custom_ar", &init_custom_ar, "init custom allreduce meta (CUDA)"); @@ -29,7 +34,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // moe_align_block_size m.def("moe_align_block_size", &moe_align_block_size, "MOE Align Block Size (CUDA)"); // int8_scaled_mm - m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); + // m.def("int8_scaled_mm", &int8_scaled_mm, "INT8 scaled matmul (CUDA)"); // fp8_scaled_mm m.def("fp8_scaled_mm", &fp8_scaled_mm, "FP8 scaled matmul (CUDA)"); + // fp8_scaled_mm_profile + m.def("fp8_scaled_mm_profile", &fp8_scaled_mm_profile, "FP8 scaled matmul profile (CUDA)"); } diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index f339997b027..8b36c1738cd 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,10 +1,10 @@ from sgl_kernel.ops._kernels import all_reduce as _all_reduce from sgl_kernel.ops._kernels import dispose as _dispose from sgl_kernel.ops._kernels import init_custom_ar as _init_custom_ar -from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm +# from sgl_kernel.ops._kernels import int8_scaled_mm as _int8_scaled_mm from sgl_kernel.ops._kernels import fp8_scaled_mm as _fp8_scaled_mm from sgl_kernel.ops._kernels import moe_align_block_size as _moe_align_block_size - +from sgl_kernel.ops._kernels import fp8_scaled_mm_profile as _fp8_scaled_mm_profile def init_custom_reduce(rank_id, num_devices, buffers, barrier_in, barrier_out): return _init_custom_ar(rank_id, num_devices, buffers, barrier_in, barrier_out) @@ -41,7 +41,18 @@ def moe_align_block_size( def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _int8_scaled_mm( + return None + # return _int8_scaled_mm( + # mat_a, + # mat_b, + # scales_a, + # scales_b, + # out_dtype, + # bias, + # ) + +def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): + return _fp8_scaled_mm( mat_a, mat_b, scales_a, @@ -50,12 +61,13 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): bias, ) -def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None): - return _fp8_scaled_mm( +def fp8_scaled_mm_profile(mat_a, mat_b, scales_a, scales_b, out_dtype, bias, config_id): + return _fp8_scaled_mm_profile( mat_a, mat_b, scales_a, scales_b, out_dtype, bias, + config_id, ) From 349795099e106bb9155a62f373f1e7c2bd85bab3 Mon Sep 17 00:00:00 2001 From: yych0745 <1398089567@qq.com> Date: Mon, 13 Jan 2025 19:38:03 +0800 Subject: [PATCH 05/32] add config_profile for sm_89 --- sgl-kernel/3rdparty/nlohmann/json.hpp | 25420 ++++++++++++++++ sgl-kernel/3rdparty/nlohmann/json_fwd.hpp | 187 + sgl-kernel/benchmark/89_fp8_bf16.json | 10 + ...fp8_bf16_256\350\247\243\345\206\263.json" | 10 + ...4096,device=NVIDIA_L40,dtype=bfloat16.json | 11 + ...=4096,device=NVIDIA_L40,dtype=float16.json | 11 + sgl-kernel/benchmark/bench_fp8_gemm.py | 142 +- .../benchmark/bench_fp8_res/results.html | 1 + sgl-kernel/benchmark/best_fp8_configs.json | 42 - sgl-kernel/outp | 0 sgl-kernel/setup.py | 13 +- .../src/sgl-kernel/csrc/fp8_gemm_kernel.cu | 320 +- .../src/sgl-kernel/csrc/sgl_kernel_ops.cu | 7 +- sgl-kernel/src/sgl-kernel/csrc/utils.hpp | 99 +- sgl-kernel/src/sgl-kernel/ops/__init__.py | 8 +- 15 files changed, 26041 insertions(+), 240 deletions(-) create mode 100644 sgl-kernel/3rdparty/nlohmann/json.hpp create mode 100644 sgl-kernel/3rdparty/nlohmann/json_fwd.hpp create mode 100644 sgl-kernel/benchmark/89_fp8_bf16.json create mode 100644 "sgl-kernel/benchmark/89_fp8_bf16_256\350\247\243\345\206\263.json" create mode 100644 sgl-kernel/benchmark/N=8192,K=4096,device=NVIDIA_L40,dtype=bfloat16.json create mode 100644 sgl-kernel/benchmark/N=8192,K=4096,device=NVIDIA_L40,dtype=float16.json create mode 100644 sgl-kernel/benchmark/bench_fp8_res/results.html delete mode 100644 sgl-kernel/benchmark/best_fp8_configs.json create mode 100644 sgl-kernel/outp diff --git a/sgl-kernel/3rdparty/nlohmann/json.hpp b/sgl-kernel/3rdparty/nlohmann/json.hpp new file mode 100644 index 00000000000..9be8b892e3d --- /dev/null +++ b/sgl-kernel/3rdparty/nlohmann/json.hpp @@ -0,0 +1,25420 @@ +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + +/****************************************************************************\ + * Note on documentation: The source files contain links to the online * + * documentation of the public API at https://json.nlohmann.me. This URL * + * contains the most recent documentation and should also be applicable to * + * previous versions; documentation for deprecated functions is not * + * removed, but marked deprecated. See "Generate documentation" section in * + * file docs/README.md. * +\****************************************************************************/ + +#ifndef INCLUDE_NLOHMANN_JSON_HPP_ +#define INCLUDE_NLOHMANN_JSON_HPP_ + +#include // all_of, find, for_each +#include // nullptr_t, ptrdiff_t, size_t +#include // hash, less +#include // initializer_list +#ifndef JSON_NO_IO + #include // istream, ostream +#endif // JSON_NO_IO +#include // random_access_iterator_tag +#include // unique_ptr +#include // string, stoi, to_string +#include // declval, forward, move, pair, swap +#include // vector + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +// This file contains all macro definitions affecting or depending on the ABI + +#ifndef JSON_SKIP_LIBRARY_VERSION_CHECK + #if defined(NLOHMANN_JSON_VERSION_MAJOR) && defined(NLOHMANN_JSON_VERSION_MINOR) && defined(NLOHMANN_JSON_VERSION_PATCH) + #if NLOHMANN_JSON_VERSION_MAJOR != 3 || NLOHMANN_JSON_VERSION_MINOR != 11 || NLOHMANN_JSON_VERSION_PATCH != 3 + #warning "Already included a different version of the library!" + #endif + #endif +#endif + +#define NLOHMANN_JSON_VERSION_MAJOR 3 // NOLINT(modernize-macro-to-enum) +#define NLOHMANN_JSON_VERSION_MINOR 11 // NOLINT(modernize-macro-to-enum) +#define NLOHMANN_JSON_VERSION_PATCH 3 // NOLINT(modernize-macro-to-enum) + +#ifndef JSON_DIAGNOSTICS + #define JSON_DIAGNOSTICS 0 +#endif + +#ifndef JSON_DIAGNOSTIC_POSITIONS + #define JSON_DIAGNOSTIC_POSITIONS 0 +#endif + +#ifndef JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON + #define JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON 0 +#endif + +#if JSON_DIAGNOSTICS + #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS _diag +#else + #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS +#endif + +#if JSON_DIAGNOSTIC_POSITIONS + #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS _dp +#else + #define NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS +#endif + +#if JSON_USE_LEGACY_DISCARDED_VALUE_COMPARISON + #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON _ldvcmp +#else + #define NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_NO_VERSION + #define NLOHMANN_JSON_NAMESPACE_NO_VERSION 0 +#endif + +// Construct the namespace ABI tags component +#define NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b, c) json_abi ## a ## b ## c +#define NLOHMANN_JSON_ABI_TAGS_CONCAT(a, b, c) \ + NLOHMANN_JSON_ABI_TAGS_CONCAT_EX(a, b, c) + +#define NLOHMANN_JSON_ABI_TAGS \ + NLOHMANN_JSON_ABI_TAGS_CONCAT( \ + NLOHMANN_JSON_ABI_TAG_DIAGNOSTICS, \ + NLOHMANN_JSON_ABI_TAG_LEGACY_DISCARDED_VALUE_COMPARISON, \ + NLOHMANN_JSON_ABI_TAG_DIAGNOSTIC_POSITIONS) + +// Construct the namespace version component +#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) \ + _v ## major ## _ ## minor ## _ ## patch +#define NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(major, minor, patch) \ + NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT_EX(major, minor, patch) + +#if NLOHMANN_JSON_NAMESPACE_NO_VERSION +#define NLOHMANN_JSON_NAMESPACE_VERSION +#else +#define NLOHMANN_JSON_NAMESPACE_VERSION \ + NLOHMANN_JSON_NAMESPACE_VERSION_CONCAT(NLOHMANN_JSON_VERSION_MAJOR, \ + NLOHMANN_JSON_VERSION_MINOR, \ + NLOHMANN_JSON_VERSION_PATCH) +#endif + +// Combine namespace components +#define NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) a ## b +#define NLOHMANN_JSON_NAMESPACE_CONCAT(a, b) \ + NLOHMANN_JSON_NAMESPACE_CONCAT_EX(a, b) + +#ifndef NLOHMANN_JSON_NAMESPACE +#define NLOHMANN_JSON_NAMESPACE \ + nlohmann::NLOHMANN_JSON_NAMESPACE_CONCAT( \ + NLOHMANN_JSON_ABI_TAGS, \ + NLOHMANN_JSON_NAMESPACE_VERSION) +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_BEGIN +#define NLOHMANN_JSON_NAMESPACE_BEGIN \ + namespace nlohmann \ + { \ + inline namespace NLOHMANN_JSON_NAMESPACE_CONCAT( \ + NLOHMANN_JSON_ABI_TAGS, \ + NLOHMANN_JSON_NAMESPACE_VERSION) \ + { +#endif + +#ifndef NLOHMANN_JSON_NAMESPACE_END +#define NLOHMANN_JSON_NAMESPACE_END \ + } /* namespace (inline namespace) NOLINT(readability/namespace) */ \ + } // namespace nlohmann +#endif + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // transform +#include // array +#include // forward_list +#include // inserter, front_inserter, end +#include // map +#ifdef JSON_HAS_CPP_17 + #include // optional +#endif +#include // string +#include // tuple, make_tuple +#include // is_arithmetic, is_same, is_enum, underlying_type, is_convertible +#include // unordered_map +#include // pair, declval +#include // valarray + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // nullptr_t +#include // exception +#if JSON_DIAGNOSTICS + #include // accumulate +#endif +#include // runtime_error +#include // to_string +#include // vector + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // array +#include // size_t +#include // uint8_t +#include // string + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // declval, pair +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +template struct make_void +{ + using type = void; +}; +template using void_t = typename make_void::type; + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +// https://en.cppreference.com/w/cpp/experimental/is_detected +struct nonesuch +{ + nonesuch() = delete; + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + nonesuch(nonesuch const&&) = delete; + void operator=(nonesuch const&) = delete; + void operator=(nonesuch&&) = delete; +}; + +template class Op, + class... Args> +struct detector +{ + using value_t = std::false_type; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = std::true_type; + using type = Op; +}; + +template class Op, class... Args> +using is_detected = typename detector::value_t; + +template class Op, class... Args> +struct is_detected_lazy : is_detected { }; + +template class Op, class... Args> +using detected_t = typename detector::type; + +template class Op, class... Args> +using detected_or = detector; + +template class Op, class... Args> +using detected_or_t = typename detected_or::type; + +template class Op, class... Args> +using is_detected_exact = std::is_same>; + +template class Op, class... Args> +using is_detected_convertible = + std::is_convertible, To>; + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include + + +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-FileCopyrightText: 2016 - 2021 Evan Nemerson +// SPDX-License-Identifier: MIT + +/* Hedley - https://nemequ.github.io/hedley + * Created by Evan Nemerson + */ + +#if !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < 15) +#if defined(JSON_HEDLEY_VERSION) + #undef JSON_HEDLEY_VERSION +#endif +#define JSON_HEDLEY_VERSION 15 + +#if defined(JSON_HEDLEY_STRINGIFY_EX) + #undef JSON_HEDLEY_STRINGIFY_EX +#endif +#define JSON_HEDLEY_STRINGIFY_EX(x) #x + +#if defined(JSON_HEDLEY_STRINGIFY) + #undef JSON_HEDLEY_STRINGIFY +#endif +#define JSON_HEDLEY_STRINGIFY(x) JSON_HEDLEY_STRINGIFY_EX(x) + +#if defined(JSON_HEDLEY_CONCAT_EX) + #undef JSON_HEDLEY_CONCAT_EX +#endif +#define JSON_HEDLEY_CONCAT_EX(a,b) a##b + +#if defined(JSON_HEDLEY_CONCAT) + #undef JSON_HEDLEY_CONCAT +#endif +#define JSON_HEDLEY_CONCAT(a,b) JSON_HEDLEY_CONCAT_EX(a,b) + +#if defined(JSON_HEDLEY_CONCAT3_EX) + #undef JSON_HEDLEY_CONCAT3_EX +#endif +#define JSON_HEDLEY_CONCAT3_EX(a,b,c) a##b##c + +#if defined(JSON_HEDLEY_CONCAT3) + #undef JSON_HEDLEY_CONCAT3 +#endif +#define JSON_HEDLEY_CONCAT3(a,b,c) JSON_HEDLEY_CONCAT3_EX(a,b,c) + +#if defined(JSON_HEDLEY_VERSION_ENCODE) + #undef JSON_HEDLEY_VERSION_ENCODE +#endif +#define JSON_HEDLEY_VERSION_ENCODE(major,minor,revision) (((major) * 1000000) + ((minor) * 1000) + (revision)) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MAJOR) + #undef JSON_HEDLEY_VERSION_DECODE_MAJOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MAJOR(version) ((version) / 1000000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_MINOR) + #undef JSON_HEDLEY_VERSION_DECODE_MINOR +#endif +#define JSON_HEDLEY_VERSION_DECODE_MINOR(version) (((version) % 1000000) / 1000) + +#if defined(JSON_HEDLEY_VERSION_DECODE_REVISION) + #undef JSON_HEDLEY_VERSION_DECODE_REVISION +#endif +#define JSON_HEDLEY_VERSION_DECODE_REVISION(version) ((version) % 1000) + +#if defined(JSON_HEDLEY_GNUC_VERSION) + #undef JSON_HEDLEY_GNUC_VERSION +#endif +#if defined(__GNUC__) && defined(__GNUC_PATCHLEVEL__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#elif defined(__GNUC__) + #define JSON_HEDLEY_GNUC_VERSION JSON_HEDLEY_VERSION_ENCODE(__GNUC__, __GNUC_MINOR__, 0) +#endif + +#if defined(JSON_HEDLEY_GNUC_VERSION_CHECK) + #undef JSON_HEDLEY_GNUC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GNUC_VERSION) + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GNUC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION) + #undef JSON_HEDLEY_MSVC_VERSION +#endif +#if defined(_MSC_FULL_VER) && (_MSC_FULL_VER >= 140000000) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 10000000, (_MSC_FULL_VER % 10000000) / 100000, (_MSC_FULL_VER % 100000) / 100) +#elif defined(_MSC_FULL_VER) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_FULL_VER / 1000000, (_MSC_FULL_VER % 1000000) / 10000, (_MSC_FULL_VER % 10000) / 10) +#elif defined(_MSC_VER) && !defined(__ICL) + #define JSON_HEDLEY_MSVC_VERSION JSON_HEDLEY_VERSION_ENCODE(_MSC_VER / 100, _MSC_VER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_MSVC_VERSION_CHECK) + #undef JSON_HEDLEY_MSVC_VERSION_CHECK +#endif +#if !defined(JSON_HEDLEY_MSVC_VERSION) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (0) +#elif defined(_MSC_VER) && (_MSC_VER >= 1400) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 10000000) + (minor * 100000) + (patch))) +#elif defined(_MSC_VER) && (_MSC_VER >= 1200) + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_FULL_VER >= ((major * 1000000) + (minor * 10000) + (patch))) +#else + #define JSON_HEDLEY_MSVC_VERSION_CHECK(major,minor,patch) (_MSC_VER >= ((major * 100) + (minor))) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION) + #undef JSON_HEDLEY_INTEL_VERSION +#endif +#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && !defined(__ICL) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, __INTEL_COMPILER_UPDATE) +#elif defined(__INTEL_COMPILER) && !defined(__ICL) + #define JSON_HEDLEY_INTEL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER / 100, __INTEL_COMPILER % 100, 0) +#endif + +#if defined(JSON_HEDLEY_INTEL_VERSION_CHECK) + #undef JSON_HEDLEY_INTEL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_INTEL_VERSION) + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_INTEL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_INTEL_CL_VERSION) + #undef JSON_HEDLEY_INTEL_CL_VERSION +#endif +#if defined(__INTEL_COMPILER) && defined(__INTEL_COMPILER_UPDATE) && defined(__ICL) + #define JSON_HEDLEY_INTEL_CL_VERSION JSON_HEDLEY_VERSION_ENCODE(__INTEL_COMPILER, __INTEL_COMPILER_UPDATE, 0) +#endif + +#if defined(JSON_HEDLEY_INTEL_CL_VERSION_CHECK) + #undef JSON_HEDLEY_INTEL_CL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_INTEL_CL_VERSION) + #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_INTEL_CL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_INTEL_CL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION) + #undef JSON_HEDLEY_PGI_VERSION +#endif +#if defined(__PGI) && defined(__PGIC__) && defined(__PGIC_MINOR__) && defined(__PGIC_PATCHLEVEL__) + #define JSON_HEDLEY_PGI_VERSION JSON_HEDLEY_VERSION_ENCODE(__PGIC__, __PGIC_MINOR__, __PGIC_PATCHLEVEL__) +#endif + +#if defined(JSON_HEDLEY_PGI_VERSION_CHECK) + #undef JSON_HEDLEY_PGI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PGI_VERSION) + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PGI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PGI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #undef JSON_HEDLEY_SUNPRO_VERSION +#endif +#if defined(__SUNPRO_C) && (__SUNPRO_C > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_C >> 16) & 0xf) * 10) + ((__SUNPRO_C >> 12) & 0xf), (((__SUNPRO_C >> 8) & 0xf) * 10) + ((__SUNPRO_C >> 4) & 0xf), (__SUNPRO_C & 0xf) * 10) +#elif defined(__SUNPRO_C) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_C >> 8) & 0xf, (__SUNPRO_C >> 4) & 0xf, (__SUNPRO_C) & 0xf) +#elif defined(__SUNPRO_CC) && (__SUNPRO_CC > 0x1000) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((((__SUNPRO_CC >> 16) & 0xf) * 10) + ((__SUNPRO_CC >> 12) & 0xf), (((__SUNPRO_CC >> 8) & 0xf) * 10) + ((__SUNPRO_CC >> 4) & 0xf), (__SUNPRO_CC & 0xf) * 10) +#elif defined(__SUNPRO_CC) + #define JSON_HEDLEY_SUNPRO_VERSION JSON_HEDLEY_VERSION_ENCODE((__SUNPRO_CC >> 8) & 0xf, (__SUNPRO_CC >> 4) & 0xf, (__SUNPRO_CC) & 0xf) +#endif + +#if defined(JSON_HEDLEY_SUNPRO_VERSION_CHECK) + #undef JSON_HEDLEY_SUNPRO_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_SUNPRO_VERSION) + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_SUNPRO_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_SUNPRO_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION +#endif +#if defined(__EMSCRIPTEN__) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION JSON_HEDLEY_VERSION_ENCODE(__EMSCRIPTEN_major__, __EMSCRIPTEN_minor__, __EMSCRIPTEN_tiny__) +#endif + +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK) + #undef JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_EMSCRIPTEN_VERSION) + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_EMSCRIPTEN_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_EMSCRIPTEN_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION) + #undef JSON_HEDLEY_ARM_VERSION +#endif +#if defined(__CC_ARM) && defined(__ARMCOMPILER_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCOMPILER_VERSION / 1000000, (__ARMCOMPILER_VERSION % 1000000) / 10000, (__ARMCOMPILER_VERSION % 10000) / 100) +#elif defined(__CC_ARM) && defined(__ARMCC_VERSION) + #define JSON_HEDLEY_ARM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ARMCC_VERSION / 1000000, (__ARMCC_VERSION % 1000000) / 10000, (__ARMCC_VERSION % 10000) / 100) +#endif + +#if defined(JSON_HEDLEY_ARM_VERSION_CHECK) + #undef JSON_HEDLEY_ARM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_ARM_VERSION) + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_ARM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_ARM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION) + #undef JSON_HEDLEY_IBM_VERSION +#endif +#if defined(__ibmxl__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__ibmxl_version__, __ibmxl_release__, __ibmxl_modification__) +#elif defined(__xlC__) && defined(__xlC_ver__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, (__xlC_ver__ >> 8) & 0xff) +#elif defined(__xlC__) + #define JSON_HEDLEY_IBM_VERSION JSON_HEDLEY_VERSION_ENCODE(__xlC__ >> 8, __xlC__ & 0xff, 0) +#endif + +#if defined(JSON_HEDLEY_IBM_VERSION_CHECK) + #undef JSON_HEDLEY_IBM_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IBM_VERSION) + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IBM_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IBM_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_VERSION) + #undef JSON_HEDLEY_TI_VERSION +#endif +#if \ + defined(__TI_COMPILER_VERSION__) && \ + ( \ + defined(__TMS470__) || defined(__TI_ARM__) || \ + defined(__MSP430__) || \ + defined(__TMS320C2000__) \ + ) +#if (__TI_COMPILER_VERSION__ >= 16000000) + #define JSON_HEDLEY_TI_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif +#endif + +#if defined(JSON_HEDLEY_TI_VERSION_CHECK) + #undef JSON_HEDLEY_TI_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_VERSION) + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #undef JSON_HEDLEY_TI_CL2000_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C2000__) + #define JSON_HEDLEY_TI_CL2000_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL2000_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL2000_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL2000_VERSION) + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL2000_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL2000_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #undef JSON_HEDLEY_TI_CL430_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__MSP430__) + #define JSON_HEDLEY_TI_CL430_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL430_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL430_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL430_VERSION) + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL430_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL430_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #undef JSON_HEDLEY_TI_ARMCL_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && (defined(__TMS470__) || defined(__TI_ARM__)) + #define JSON_HEDLEY_TI_ARMCL_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION_CHECK) + #undef JSON_HEDLEY_TI_ARMCL_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_ARMCL_VERSION) + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_ARMCL_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #undef JSON_HEDLEY_TI_CL6X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__TMS320C6X__) + #define JSON_HEDLEY_TI_CL6X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL6X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL6X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL6X_VERSION) + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL6X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL6X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #undef JSON_HEDLEY_TI_CL7X_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__C7000__) + #define JSON_HEDLEY_TI_CL7X_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CL7X_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CL7X_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CL7X_VERSION) + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CL7X_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CL7X_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #undef JSON_HEDLEY_TI_CLPRU_VERSION +#endif +#if defined(__TI_COMPILER_VERSION__) && defined(__PRU__) + #define JSON_HEDLEY_TI_CLPRU_VERSION JSON_HEDLEY_VERSION_ENCODE(__TI_COMPILER_VERSION__ / 1000000, (__TI_COMPILER_VERSION__ % 1000000) / 1000, (__TI_COMPILER_VERSION__ % 1000)) +#endif + +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION_CHECK) + #undef JSON_HEDLEY_TI_CLPRU_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TI_CLPRU_VERSION) + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TI_CLPRU_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION) + #undef JSON_HEDLEY_CRAY_VERSION +#endif +#if defined(_CRAYC) + #if defined(_RELEASE_PATCHLEVEL) + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, _RELEASE_PATCHLEVEL) + #else + #define JSON_HEDLEY_CRAY_VERSION JSON_HEDLEY_VERSION_ENCODE(_RELEASE_MAJOR, _RELEASE_MINOR, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_CRAY_VERSION_CHECK) + #undef JSON_HEDLEY_CRAY_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_CRAY_VERSION) + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_CRAY_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_CRAY_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION) + #undef JSON_HEDLEY_IAR_VERSION +#endif +#if defined(__IAR_SYSTEMS_ICC__) + #if __VER__ > 1000 + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE((__VER__ / 1000000), ((__VER__ / 1000) % 1000), (__VER__ % 1000)) + #else + #define JSON_HEDLEY_IAR_VERSION JSON_HEDLEY_VERSION_ENCODE(__VER__ / 100, __VER__ % 100, 0) + #endif +#endif + +#if defined(JSON_HEDLEY_IAR_VERSION_CHECK) + #undef JSON_HEDLEY_IAR_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_IAR_VERSION) + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_IAR_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_IAR_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION) + #undef JSON_HEDLEY_TINYC_VERSION +#endif +#if defined(__TINYC__) + #define JSON_HEDLEY_TINYC_VERSION JSON_HEDLEY_VERSION_ENCODE(__TINYC__ / 1000, (__TINYC__ / 100) % 10, __TINYC__ % 100) +#endif + +#if defined(JSON_HEDLEY_TINYC_VERSION_CHECK) + #undef JSON_HEDLEY_TINYC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_TINYC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_TINYC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION) + #undef JSON_HEDLEY_DMC_VERSION +#endif +#if defined(__DMC__) + #define JSON_HEDLEY_DMC_VERSION JSON_HEDLEY_VERSION_ENCODE(__DMC__ >> 8, (__DMC__ >> 4) & 0xf, __DMC__ & 0xf) +#endif + +#if defined(JSON_HEDLEY_DMC_VERSION_CHECK) + #undef JSON_HEDLEY_DMC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_DMC_VERSION) + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_DMC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_DMC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #undef JSON_HEDLEY_COMPCERT_VERSION +#endif +#if defined(__COMPCERT_VERSION__) + #define JSON_HEDLEY_COMPCERT_VERSION JSON_HEDLEY_VERSION_ENCODE(__COMPCERT_VERSION__ / 10000, (__COMPCERT_VERSION__ / 100) % 100, __COMPCERT_VERSION__ % 100) +#endif + +#if defined(JSON_HEDLEY_COMPCERT_VERSION_CHECK) + #undef JSON_HEDLEY_COMPCERT_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_COMPCERT_VERSION) + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_COMPCERT_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_COMPCERT_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION) + #undef JSON_HEDLEY_PELLES_VERSION +#endif +#if defined(__POCC__) + #define JSON_HEDLEY_PELLES_VERSION JSON_HEDLEY_VERSION_ENCODE(__POCC__ / 100, __POCC__ % 100, 0) +#endif + +#if defined(JSON_HEDLEY_PELLES_VERSION_CHECK) + #undef JSON_HEDLEY_PELLES_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_PELLES_VERSION) + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_PELLES_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_PELLES_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_MCST_LCC_VERSION) + #undef JSON_HEDLEY_MCST_LCC_VERSION +#endif +#if defined(__LCC__) && defined(__LCC_MINOR__) + #define JSON_HEDLEY_MCST_LCC_VERSION JSON_HEDLEY_VERSION_ENCODE(__LCC__ / 100, __LCC__ % 100, __LCC_MINOR__) +#endif + +#if defined(JSON_HEDLEY_MCST_LCC_VERSION_CHECK) + #undef JSON_HEDLEY_MCST_LCC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_MCST_LCC_VERSION) + #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_MCST_LCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_MCST_LCC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION) + #undef JSON_HEDLEY_GCC_VERSION +#endif +#if \ + defined(JSON_HEDLEY_GNUC_VERSION) && \ + !defined(__clang__) && \ + !defined(JSON_HEDLEY_INTEL_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_ARM_VERSION) && \ + !defined(JSON_HEDLEY_CRAY_VERSION) && \ + !defined(JSON_HEDLEY_TI_VERSION) && \ + !defined(JSON_HEDLEY_TI_ARMCL_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL430_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL2000_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL6X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CL7X_VERSION) && \ + !defined(JSON_HEDLEY_TI_CLPRU_VERSION) && \ + !defined(__COMPCERT__) && \ + !defined(JSON_HEDLEY_MCST_LCC_VERSION) + #define JSON_HEDLEY_GCC_VERSION JSON_HEDLEY_GNUC_VERSION +#endif + +#if defined(JSON_HEDLEY_GCC_VERSION_CHECK) + #undef JSON_HEDLEY_GCC_VERSION_CHECK +#endif +#if defined(JSON_HEDLEY_GCC_VERSION) + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (JSON_HEDLEY_GCC_VERSION >= JSON_HEDLEY_VERSION_ENCODE(major, minor, patch)) +#else + #define JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_ATTRIBUTE +#endif +#if \ + defined(__has_attribute) && \ + ( \ + (!defined(JSON_HEDLEY_IAR_VERSION) || JSON_HEDLEY_IAR_VERSION_CHECK(8,5,9)) \ + ) +# define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) __has_attribute(attribute) +#else +# define JSON_HEDLEY_HAS_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_ATTRIBUTE +#endif +#if defined(__has_attribute) + #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE +#endif +#if \ + defined(__has_cpp_attribute) && \ + defined(__cplusplus) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS) + #undef JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS +#endif +#if !defined(__cplusplus) || !defined(__has_cpp_attribute) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#elif \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION) && \ + (!defined(JSON_HEDLEY_SUNPRO_VERSION) || JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0)) && \ + (!defined(JSON_HEDLEY_MSVC_VERSION) || JSON_HEDLEY_MSVC_VERSION_CHECK(19,20,0)) + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(ns::attribute) +#else + #define JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(ns,attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE +#endif +#if defined(__has_cpp_attribute) && defined(__cplusplus) + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) __has_cpp_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_CPP_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_BUILTIN) + #undef JSON_HEDLEY_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else + #define JSON_HEDLEY_HAS_BUILTIN(builtin) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_BUILTIN) + #undef JSON_HEDLEY_GNUC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GNUC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_BUILTIN) + #undef JSON_HEDLEY_GCC_HAS_BUILTIN +#endif +#if defined(__has_builtin) + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) __has_builtin(builtin) +#else + #define JSON_HEDLEY_GCC_HAS_BUILTIN(builtin,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_FEATURE) + #undef JSON_HEDLEY_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_HAS_FEATURE(feature) __has_feature(feature) +#else + #define JSON_HEDLEY_HAS_FEATURE(feature) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_FEATURE) + #undef JSON_HEDLEY_GNUC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GNUC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_FEATURE) + #undef JSON_HEDLEY_GCC_HAS_FEATURE +#endif +#if defined(__has_feature) + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) __has_feature(feature) +#else + #define JSON_HEDLEY_GCC_HAS_FEATURE(feature,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_EXTENSION) + #undef JSON_HEDLEY_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_HAS_EXTENSION(extension) __has_extension(extension) +#else + #define JSON_HEDLEY_HAS_EXTENSION(extension) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_EXTENSION) + #undef JSON_HEDLEY_GNUC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GNUC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_EXTENSION) + #undef JSON_HEDLEY_GCC_HAS_EXTENSION +#endif +#if defined(__has_extension) + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) __has_extension(extension) +#else + #define JSON_HEDLEY_GCC_HAS_EXTENSION(extension,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GNUC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE +#endif +#if defined(__has_declspec_attribute) + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) __has_declspec_attribute(attribute) +#else + #define JSON_HEDLEY_GCC_HAS_DECLSPEC_ATTRIBUTE(attribute,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_HAS_WARNING) + #undef JSON_HEDLEY_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_HAS_WARNING(warning) __has_warning(warning) +#else + #define JSON_HEDLEY_HAS_WARNING(warning) (0) +#endif + +#if defined(JSON_HEDLEY_GNUC_HAS_WARNING) + #undef JSON_HEDLEY_GNUC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GNUC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GNUC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_GCC_HAS_WARNING) + #undef JSON_HEDLEY_GCC_HAS_WARNING +#endif +#if defined(__has_warning) + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) __has_warning(warning) +#else + #define JSON_HEDLEY_GCC_HAS_WARNING(warning,major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + defined(__clang__) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,17) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(8,0,0) || \ + (JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) && defined(__C99_PRAGMA_OPERATOR)) + #define JSON_HEDLEY_PRAGMA(value) _Pragma(#value) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_PRAGMA(value) __pragma(value) +#else + #define JSON_HEDLEY_PRAGMA(value) +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_PUSH) + #undef JSON_HEDLEY_DIAGNOSTIC_PUSH +#endif +#if defined(JSON_HEDLEY_DIAGNOSTIC_POP) + #undef JSON_HEDLEY_DIAGNOSTIC_POP +#endif +#if defined(__clang__) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("clang diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("clang diagnostic pop") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("GCC diagnostic push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("GCC diagnostic pop") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH __pragma(warning(push)) + #define JSON_HEDLEY_DIAGNOSTIC_POP __pragma(warning(pop)) +#elif JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("pop") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("diag_push") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("diag_pop") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_PUSH _Pragma("warning(push)") + #define JSON_HEDLEY_DIAGNOSTIC_POP _Pragma("warning(pop)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_PUSH + #define JSON_HEDLEY_DIAGNOSTIC_POP +#endif + +/* JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_ +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat") +# if JSON_HEDLEY_HAS_WARNING("-Wc++17-extensions") +# if JSON_HEDLEY_HAS_WARNING("-Wc++1z-extensions") +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ + _Pragma("clang diagnostic ignored \"-Wc++1z-extensions\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + _Pragma("clang diagnostic ignored \"-Wc++17-extensions\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# endif +# else +# define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(xpr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wc++98-compat\"") \ + xpr \ + JSON_HEDLEY_DIAGNOSTIC_POP +# endif +# endif +#endif +#if !defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(x) x +#endif + +#if defined(JSON_HEDLEY_CONST_CAST) + #undef JSON_HEDLEY_CONST_CAST +#endif +#if defined(__cplusplus) +# define JSON_HEDLEY_CONST_CAST(T, expr) (const_cast(expr)) +#elif \ + JSON_HEDLEY_HAS_WARNING("-Wcast-qual") || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_CONST_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_CONST_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_REINTERPRET_CAST) + #undef JSON_HEDLEY_REINTERPRET_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) (reinterpret_cast(expr)) +#else + #define JSON_HEDLEY_REINTERPRET_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_STATIC_CAST) + #undef JSON_HEDLEY_STATIC_CAST +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_STATIC_CAST(T, expr) (static_cast(expr)) +#else + #define JSON_HEDLEY_STATIC_CAST(T, expr) ((T) (expr)) +#endif + +#if defined(JSON_HEDLEY_CPP_CAST) + #undef JSON_HEDLEY_CPP_CAST +#endif +#if defined(__cplusplus) +# if JSON_HEDLEY_HAS_WARNING("-Wold-style-cast") +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wold-style-cast\"") \ + ((T) (expr)) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# elif JSON_HEDLEY_IAR_VERSION_CHECK(8,3,0) +# define JSON_HEDLEY_CPP_CAST(T, expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("diag_suppress=Pe137") \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_CPP_CAST(T, expr) ((T) (expr)) +# endif +#else +# define JSON_HEDLEY_CPP_CAST(T, expr) (expr) +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wdeprecated-declarations") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("clang diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warning(disable:1478 1786)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:1478 1786)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1216,1444,1445") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED __pragma(warning(disable:4996)) +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1215,1444") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress 1291,1718") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && !defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,E_DEPRECATED_ATT,E_DEPRECATED_ATT_MESS)") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("error_messages(off,symdeprecated,symdeprecated2)") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("diag_suppress=Pe1444,Pe1215") +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,90,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED _Pragma("warn(disable:2241)") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_DEPRECATED +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("clang diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("warning(disable:161)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:161)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 1675") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("GCC diagnostic ignored \"-Wunknown-pragmas\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS __pragma(warning(disable:4068)) +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(16,9,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 163") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress=Pe161") +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS _Pragma("diag_suppress 161") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-attributes") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("clang diagnostic ignored \"-Wunknown-attributes\"") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(4,6,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("GCC diagnostic ignored \"-Wdeprecated-declarations\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("warning(disable:1292)") +#elif JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:1292)) +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES __pragma(warning(disable:5030)) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(20,7,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097,1098") +#elif JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("error_messages(off,attrskipunsup)") +#elif \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1173") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress=Pe1097") +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES _Pragma("diag_suppress 1097") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_CPP_ATTRIBUTES +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wcast-qual") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("clang diagnostic ignored \"-Wcast-qual\"") +#elif JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("warning(disable:2203 2331)") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL _Pragma("GCC diagnostic ignored \"-Wcast-qual\"") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_CAST_QUAL +#endif + +#if defined(JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION) + #undef JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunused-function") + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("clang diagnostic ignored \"-Wunused-function\"") +#elif JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("GCC diagnostic ignored \"-Wunused-function\"") +#elif JSON_HEDLEY_MSVC_VERSION_CHECK(1,0,0) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION __pragma(warning(disable:4505)) +#elif JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION _Pragma("diag_suppress 3142") +#else + #define JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNUSED_FUNCTION +#endif + +#if defined(JSON_HEDLEY_DEPRECATED) + #undef JSON_HEDLEY_DEPRECATED +#endif +#if defined(JSON_HEDLEY_DEPRECATED_FOR) + #undef JSON_HEDLEY_DEPRECATED_FOR +#endif +#if \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated("Since " # since)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated("Since " #since "; use " #replacement)) +#elif \ + (JSON_HEDLEY_HAS_EXTENSION(attribute_deprecated_with_message) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,13,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(18,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,3,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,3,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__("Since " #since))) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__("Since " #since "; use " #replacement))) +#elif defined(__cplusplus) && (__cplusplus >= 201402L) + #define JSON_HEDLEY_DEPRECATED(since) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since)]]) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[deprecated("Since " #since "; use " #replacement)]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(deprecated) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_DEPRECATED(since) __attribute__((__deprecated__)) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __attribute__((__deprecated__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_PELLES_VERSION_CHECK(6,50,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_DEPRECATED(since) __declspec(deprecated) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) __declspec(deprecated) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_DEPRECATED(since) _Pragma("deprecated") + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) _Pragma("deprecated") +#else + #define JSON_HEDLEY_DEPRECATED(since) + #define JSON_HEDLEY_DEPRECATED_FOR(since, replacement) +#endif + +#if defined(JSON_HEDLEY_UNAVAILABLE) + #undef JSON_HEDLEY_UNAVAILABLE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(warning) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_UNAVAILABLE(available_since) __attribute__((__warning__("Not available until " #available_since))) +#else + #define JSON_HEDLEY_UNAVAILABLE(available_since) +#endif + +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT +#endif +#if defined(JSON_HEDLEY_WARN_UNUSED_RESULT_MSG) + #undef JSON_HEDLEY_WARN_UNUSED_RESULT_MSG +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(warn_unused_result) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_WARN_UNUSED_RESULT __attribute__((__warn_unused_result__)) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) __attribute__((__warn_unused_result__)) +#elif (JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) >= 201907L) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard(msg)]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(nodiscard) + #define JSON_HEDLEY_WARN_UNUSED_RESULT JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[nodiscard]]) +#elif defined(_Check_return_) /* SAL */ + #define JSON_HEDLEY_WARN_UNUSED_RESULT _Check_return_ + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) _Check_return_ +#else + #define JSON_HEDLEY_WARN_UNUSED_RESULT + #define JSON_HEDLEY_WARN_UNUSED_RESULT_MSG(msg) +#endif + +#if defined(JSON_HEDLEY_SENTINEL) + #undef JSON_HEDLEY_SENTINEL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(sentinel) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_SENTINEL(position) __attribute__((__sentinel__(position))) +#else + #define JSON_HEDLEY_SENTINEL(position) +#endif + +#if defined(JSON_HEDLEY_NO_RETURN) + #undef JSON_HEDLEY_NO_RETURN +#endif +#if JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NO_RETURN __noreturn +#elif \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) +#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 201112L + #define JSON_HEDLEY_NO_RETURN _Noreturn +#elif defined(__cplusplus) && (__cplusplus >= 201103L) + #define JSON_HEDLEY_NO_RETURN JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[noreturn]]) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(noreturn) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,2,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_NO_RETURN __attribute__((__noreturn__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_NO_RETURN _Pragma("does_not_return") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) + #define JSON_HEDLEY_NO_RETURN _Pragma("FUNC_NEVER_RETURNS;") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NO_RETURN __attribute((noreturn)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NO_RETURN __declspec(noreturn) +#else + #define JSON_HEDLEY_NO_RETURN +#endif + +#if defined(JSON_HEDLEY_NO_ESCAPE) + #undef JSON_HEDLEY_NO_ESCAPE +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(noescape) + #define JSON_HEDLEY_NO_ESCAPE __attribute__((__noescape__)) +#else + #define JSON_HEDLEY_NO_ESCAPE +#endif + +#if defined(JSON_HEDLEY_UNREACHABLE) + #undef JSON_HEDLEY_UNREACHABLE +#endif +#if defined(JSON_HEDLEY_UNREACHABLE_RETURN) + #undef JSON_HEDLEY_UNREACHABLE_RETURN +#endif +#if defined(JSON_HEDLEY_ASSUME) + #undef JSON_HEDLEY_ASSUME +#endif +#if \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_ASSUME(expr) __assume(expr) +#elif JSON_HEDLEY_HAS_BUILTIN(__builtin_assume) + #define JSON_HEDLEY_ASSUME(expr) __builtin_assume(expr) +#elif \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #if defined(__cplusplus) + #define JSON_HEDLEY_ASSUME(expr) std::_nassert(expr) + #else + #define JSON_HEDLEY_ASSUME(expr) _nassert(expr) + #endif +#endif +#if \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_unreachable) && (!defined(JSON_HEDLEY_ARM_VERSION))) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,5,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,10,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,5) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(10,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_UNREACHABLE() __builtin_unreachable() +#elif defined(JSON_HEDLEY_ASSUME) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif +#if !defined(JSON_HEDLEY_ASSUME) + #if defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, ((expr) ? 1 : (JSON_HEDLEY_UNREACHABLE(), 1))) + #else + #define JSON_HEDLEY_ASSUME(expr) JSON_HEDLEY_STATIC_CAST(void, expr) + #endif +#endif +#if defined(JSON_HEDLEY_UNREACHABLE) + #if \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (JSON_HEDLEY_STATIC_CAST(void, JSON_HEDLEY_ASSUME(0)), (value)) + #else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) JSON_HEDLEY_UNREACHABLE() + #endif +#else + #define JSON_HEDLEY_UNREACHABLE_RETURN(value) return (value) +#endif +#if !defined(JSON_HEDLEY_UNREACHABLE) + #define JSON_HEDLEY_UNREACHABLE() JSON_HEDLEY_ASSUME(0) +#endif + +JSON_HEDLEY_DIAGNOSTIC_PUSH +#if JSON_HEDLEY_HAS_WARNING("-Wpedantic") + #pragma clang diagnostic ignored "-Wpedantic" +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wc++98-compat-pedantic") && defined(__cplusplus) + #pragma clang diagnostic ignored "-Wc++98-compat-pedantic" +#endif +#if JSON_HEDLEY_GCC_HAS_WARNING("-Wvariadic-macros",4,0,0) + #if defined(__clang__) + #pragma clang diagnostic ignored "-Wvariadic-macros" + #elif defined(JSON_HEDLEY_GCC_VERSION) + #pragma GCC diagnostic ignored "-Wvariadic-macros" + #endif +#endif +#if defined(JSON_HEDLEY_NON_NULL) + #undef JSON_HEDLEY_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NON_NULL(...) __attribute__((__nonnull__(__VA_ARGS__))) +#else + #define JSON_HEDLEY_NON_NULL(...) +#endif +JSON_HEDLEY_DIAGNOSTIC_POP + +#if defined(JSON_HEDLEY_PRINTF_FORMAT) + #undef JSON_HEDLEY_PRINTF_FORMAT +#endif +#if defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && !defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(ms_printf, string_idx, first_to_check))) +#elif defined(__MINGW32__) && JSON_HEDLEY_GCC_HAS_ATTRIBUTE(format,4,4,0) && defined(__USE_MINGW_ANSI_STDIO) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(gnu_printf, string_idx, first_to_check))) +#elif \ + JSON_HEDLEY_HAS_ATTRIBUTE(format) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,6,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __attribute__((__format__(__printf__, string_idx, first_to_check))) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(6,0,0) + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) __declspec(vaformat(printf,string_idx,first_to_check)) +#else + #define JSON_HEDLEY_PRINTF_FORMAT(string_idx,first_to_check) +#endif + +#if defined(JSON_HEDLEY_CONSTEXPR) + #undef JSON_HEDLEY_CONSTEXPR +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_CONSTEXPR JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(constexpr) + #endif +#endif +#if !defined(JSON_HEDLEY_CONSTEXPR) + #define JSON_HEDLEY_CONSTEXPR +#endif + +#if defined(JSON_HEDLEY_PREDICT) + #undef JSON_HEDLEY_PREDICT +#endif +#if defined(JSON_HEDLEY_LIKELY) + #undef JSON_HEDLEY_LIKELY +#endif +#if defined(JSON_HEDLEY_UNLIKELY) + #undef JSON_HEDLEY_UNLIKELY +#endif +#if defined(JSON_HEDLEY_UNPREDICTABLE) + #undef JSON_HEDLEY_UNPREDICTABLE +#endif +#if JSON_HEDLEY_HAS_BUILTIN(__builtin_unpredictable) + #define JSON_HEDLEY_UNPREDICTABLE(expr) __builtin_unpredictable((expr)) +#endif +#if \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect_with_probability) && !defined(JSON_HEDLEY_PGI_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(9,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PREDICT(expr, value, probability) __builtin_expect_with_probability( (expr), (value), (probability)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) __builtin_expect_with_probability(!!(expr), 1 , (probability)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) __builtin_expect_with_probability(!!(expr), 0 , (probability)) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect (!!(expr), 1 ) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect (!!(expr), 0 ) +#elif \ + (JSON_HEDLEY_HAS_BUILTIN(__builtin_expect) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,15,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,7,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,27) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PREDICT(expr, expected, probability) \ + (((probability) >= 0.9) ? __builtin_expect((expr), (expected)) : (JSON_HEDLEY_STATIC_CAST(void, expected), (expr))) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 1) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 0) : !!(expr))); \ + })) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) \ + (__extension__ ({ \ + double hedley_probability_ = (probability); \ + ((hedley_probability_ >= 0.9) ? __builtin_expect(!!(expr), 0) : ((hedley_probability_ <= 0.1) ? __builtin_expect(!!(expr), 1) : !!(expr))); \ + })) +# define JSON_HEDLEY_LIKELY(expr) __builtin_expect(!!(expr), 1) +# define JSON_HEDLEY_UNLIKELY(expr) __builtin_expect(!!(expr), 0) +#else +# define JSON_HEDLEY_PREDICT(expr, expected, probability) (JSON_HEDLEY_STATIC_CAST(void, expected), (expr)) +# define JSON_HEDLEY_PREDICT_TRUE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_PREDICT_FALSE(expr, probability) (!!(expr)) +# define JSON_HEDLEY_LIKELY(expr) (!!(expr)) +# define JSON_HEDLEY_UNLIKELY(expr) (!!(expr)) +#endif +#if !defined(JSON_HEDLEY_UNPREDICTABLE) + #define JSON_HEDLEY_UNPREDICTABLE(expr) JSON_HEDLEY_PREDICT(expr, 1, 0.5) +#endif + +#if defined(JSON_HEDLEY_MALLOC) + #undef JSON_HEDLEY_MALLOC +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(malloc) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_MALLOC __attribute__((__malloc__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_MALLOC _Pragma("returns_new_memory") +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_MALLOC __declspec(restrict) +#else + #define JSON_HEDLEY_MALLOC +#endif + +#if defined(JSON_HEDLEY_PURE) + #undef JSON_HEDLEY_PURE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(pure) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,96,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PURE __attribute__((__pure__)) +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) +# define JSON_HEDLEY_PURE _Pragma("does_not_write_global_data") +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(2,0,1) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) \ + ) +# define JSON_HEDLEY_PURE _Pragma("FUNC_IS_PURE;") +#else +# define JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_CONST) + #undef JSON_HEDLEY_CONST +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(const) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(2,5,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_CONST __attribute__((__const__)) +#elif \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) + #define JSON_HEDLEY_CONST _Pragma("no_side_effect") +#else + #define JSON_HEDLEY_CONST JSON_HEDLEY_PURE +#endif + +#if defined(JSON_HEDLEY_RESTRICT) + #undef JSON_HEDLEY_RESTRICT +#endif +#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT restrict +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(14,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(17,10,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,4) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,14,0) && defined(__cplusplus)) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) || \ + defined(__clang__) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_RESTRICT __restrict +#elif JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,3,0) && !defined(__cplusplus) + #define JSON_HEDLEY_RESTRICT _Restrict +#else + #define JSON_HEDLEY_RESTRICT +#endif + +#if defined(JSON_HEDLEY_INLINE) + #undef JSON_HEDLEY_INLINE +#endif +#if \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L)) || \ + (defined(__cplusplus) && (__cplusplus >= 199711L)) + #define JSON_HEDLEY_INLINE inline +#elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(6,2,0) + #define JSON_HEDLEY_INLINE __inline__ +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,1,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(3,1,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,2,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(8,0,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_INLINE __inline +#else + #define JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_ALWAYS_INLINE) + #undef JSON_HEDLEY_ALWAYS_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(always_inline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) +# define JSON_HEDLEY_ALWAYS_INLINE __attribute__((__always_inline__)) JSON_HEDLEY_INLINE +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(12,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_ALWAYS_INLINE __forceinline +#elif defined(__cplusplus) && \ + ( \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) \ + ) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("FUNC_ALWAYS_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_ALWAYS_INLINE _Pragma("inline=forced") +#else +# define JSON_HEDLEY_ALWAYS_INLINE JSON_HEDLEY_INLINE +#endif + +#if defined(JSON_HEDLEY_NEVER_INLINE) + #undef JSON_HEDLEY_NEVER_INLINE +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(noinline) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(10,1,0) || \ + JSON_HEDLEY_TI_VERSION_CHECK(15,12,0) || \ + (JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(4,8,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_ARMCL_VERSION_CHECK(5,2,0) || \ + (JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL2000_VERSION_CHECK(6,4,0) || \ + (JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,0,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(4,3,0) || \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) || \ + JSON_HEDLEY_TI_CL7X_VERSION_CHECK(1,2,0) || \ + JSON_HEDLEY_TI_CLPRU_VERSION_CHECK(2,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) || \ + JSON_HEDLEY_IAR_VERSION_CHECK(8,10,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute__((__noinline__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,10,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#elif JSON_HEDLEY_PGI_VERSION_CHECK(10,2,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("noinline") +#elif JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,0,0) && defined(__cplusplus) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("FUNC_CANNOT_INLINE;") +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) + #define JSON_HEDLEY_NEVER_INLINE _Pragma("inline=never") +#elif JSON_HEDLEY_COMPCERT_VERSION_CHECK(3,2,0) + #define JSON_HEDLEY_NEVER_INLINE __attribute((noinline)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(9,0,0) + #define JSON_HEDLEY_NEVER_INLINE __declspec(noinline) +#else + #define JSON_HEDLEY_NEVER_INLINE +#endif + +#if defined(JSON_HEDLEY_PRIVATE) + #undef JSON_HEDLEY_PRIVATE +#endif +#if defined(JSON_HEDLEY_PUBLIC) + #undef JSON_HEDLEY_PUBLIC +#endif +#if defined(JSON_HEDLEY_IMPORT) + #undef JSON_HEDLEY_IMPORT +#endif +#if defined(_WIN32) || defined(__CYGWIN__) +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC __declspec(dllexport) +# define JSON_HEDLEY_IMPORT __declspec(dllimport) +#else +# if \ + JSON_HEDLEY_HAS_ATTRIBUTE(visibility) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,11,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + ( \ + defined(__TI_EABI__) && \ + ( \ + (JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,2,0) && defined(__TI_GNU_ATTRIBUTE_SUPPORT__)) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(7,5,0) \ + ) \ + ) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) +# define JSON_HEDLEY_PRIVATE __attribute__((__visibility__("hidden"))) +# define JSON_HEDLEY_PUBLIC __attribute__((__visibility__("default"))) +# else +# define JSON_HEDLEY_PRIVATE +# define JSON_HEDLEY_PUBLIC +# endif +# define JSON_HEDLEY_IMPORT extern +#endif + +#if defined(JSON_HEDLEY_NO_THROW) + #undef JSON_HEDLEY_NO_THROW +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(nothrow) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,3,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_NO_THROW __attribute__((__nothrow__)) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) + #define JSON_HEDLEY_NO_THROW __declspec(nothrow) +#else + #define JSON_HEDLEY_NO_THROW +#endif + +#if defined(JSON_HEDLEY_FALL_THROUGH) + #undef JSON_HEDLEY_FALL_THROUGH +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(fallthrough) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(7,0,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_FALL_THROUGH __attribute__((__fallthrough__)) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE_NS(clang,fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[clang::fallthrough]]) +#elif JSON_HEDLEY_HAS_CPP_ATTRIBUTE(fallthrough) + #define JSON_HEDLEY_FALL_THROUGH JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_([[fallthrough]]) +#elif defined(__fallthrough) /* SAL */ + #define JSON_HEDLEY_FALL_THROUGH __fallthrough +#else + #define JSON_HEDLEY_FALL_THROUGH +#endif + +#if defined(JSON_HEDLEY_RETURNS_NON_NULL) + #undef JSON_HEDLEY_RETURNS_NON_NULL +#endif +#if \ + JSON_HEDLEY_HAS_ATTRIBUTE(returns_nonnull) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_RETURNS_NON_NULL __attribute__((__returns_nonnull__)) +#elif defined(_Ret_notnull_) /* SAL */ + #define JSON_HEDLEY_RETURNS_NON_NULL _Ret_notnull_ +#else + #define JSON_HEDLEY_RETURNS_NON_NULL +#endif + +#if defined(JSON_HEDLEY_ARRAY_PARAM) + #undef JSON_HEDLEY_ARRAY_PARAM +#endif +#if \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) && \ + !defined(__STDC_NO_VLA__) && \ + !defined(__cplusplus) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_TINYC_VERSION) + #define JSON_HEDLEY_ARRAY_PARAM(name) (name) +#else + #define JSON_HEDLEY_ARRAY_PARAM(name) +#endif + +#if defined(JSON_HEDLEY_IS_CONSTANT) + #undef JSON_HEDLEY_IS_CONSTANT +#endif +#if defined(JSON_HEDLEY_REQUIRE_CONSTEXPR) + #undef JSON_HEDLEY_REQUIRE_CONSTEXPR +#endif +/* JSON_HEDLEY_IS_CONSTEXPR_ is for + HEDLEY INTERNAL USE ONLY. API subject to change without notice. */ +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #undef JSON_HEDLEY_IS_CONSTEXPR_ +#endif +#if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_constant_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,19) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(4,1,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_TI_CL6X_VERSION_CHECK(6,1,0) || \ + (JSON_HEDLEY_SUNPRO_VERSION_CHECK(5,10,0) && !defined(__cplusplus)) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_MCST_LCC_VERSION_CHECK(1,25,10) + #define JSON_HEDLEY_IS_CONSTANT(expr) __builtin_constant_p(expr) +#endif +#if !defined(__cplusplus) +# if \ + JSON_HEDLEY_HAS_BUILTIN(__builtin_types_compatible_p) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(3,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(13,1,0) || \ + JSON_HEDLEY_CRAY_VERSION_CHECK(8,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,4,0) || \ + JSON_HEDLEY_TINYC_VERSION_CHECK(0,9,24) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0)), int*) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) __builtin_types_compatible_p(__typeof__((1 ? (void*) ((intptr_t) ((expr) * 0)) : (int*) 0)), int*) +#endif +# elif \ + ( \ + defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) && \ + !defined(JSON_HEDLEY_SUNPRO_VERSION) && \ + !defined(JSON_HEDLEY_PGI_VERSION) && \ + !defined(JSON_HEDLEY_IAR_VERSION)) || \ + (JSON_HEDLEY_HAS_EXTENSION(c_generic_selections) && !defined(JSON_HEDLEY_IAR_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,9,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(17,0,0) || \ + JSON_HEDLEY_IBM_VERSION_CHECK(12,1,0) || \ + JSON_HEDLEY_ARM_VERSION_CHECK(5,3,0) +#if defined(__INTPTR_TYPE__) + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((__INTPTR_TYPE__) ((expr) * 0)) : (int*) 0), int*: 1, void*: 0) +#else + #include + #define JSON_HEDLEY_IS_CONSTEXPR_(expr) _Generic((1 ? (void*) ((intptr_t) * 0) : (int*) 0), int*: 1, void*: 0) +#endif +# elif \ + defined(JSON_HEDLEY_GCC_VERSION) || \ + defined(JSON_HEDLEY_INTEL_VERSION) || \ + defined(JSON_HEDLEY_TINYC_VERSION) || \ + defined(JSON_HEDLEY_TI_ARMCL_VERSION) || \ + JSON_HEDLEY_TI_CL430_VERSION_CHECK(18,12,0) || \ + defined(JSON_HEDLEY_TI_CL2000_VERSION) || \ + defined(JSON_HEDLEY_TI_CL6X_VERSION) || \ + defined(JSON_HEDLEY_TI_CL7X_VERSION) || \ + defined(JSON_HEDLEY_TI_CLPRU_VERSION) || \ + defined(__clang__) +# define JSON_HEDLEY_IS_CONSTEXPR_(expr) ( \ + sizeof(void) != \ + sizeof(*( \ + 1 ? \ + ((void*) ((expr) * 0L) ) : \ +((struct { char v[sizeof(void) * 2]; } *) 1) \ + ) \ + ) \ + ) +# endif +#endif +#if defined(JSON_HEDLEY_IS_CONSTEXPR_) + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) JSON_HEDLEY_IS_CONSTEXPR_(expr) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (JSON_HEDLEY_IS_CONSTEXPR_(expr) ? (expr) : (-1)) +#else + #if !defined(JSON_HEDLEY_IS_CONSTANT) + #define JSON_HEDLEY_IS_CONSTANT(expr) (0) + #endif + #define JSON_HEDLEY_REQUIRE_CONSTEXPR(expr) (expr) +#endif + +#if defined(JSON_HEDLEY_BEGIN_C_DECLS) + #undef JSON_HEDLEY_BEGIN_C_DECLS +#endif +#if defined(JSON_HEDLEY_END_C_DECLS) + #undef JSON_HEDLEY_END_C_DECLS +#endif +#if defined(JSON_HEDLEY_C_DECL) + #undef JSON_HEDLEY_C_DECL +#endif +#if defined(__cplusplus) + #define JSON_HEDLEY_BEGIN_C_DECLS extern "C" { + #define JSON_HEDLEY_END_C_DECLS } + #define JSON_HEDLEY_C_DECL extern "C" +#else + #define JSON_HEDLEY_BEGIN_C_DECLS + #define JSON_HEDLEY_END_C_DECLS + #define JSON_HEDLEY_C_DECL +#endif + +#if defined(JSON_HEDLEY_STATIC_ASSERT) + #undef JSON_HEDLEY_STATIC_ASSERT +#endif +#if \ + !defined(__cplusplus) && ( \ + (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L)) || \ + (JSON_HEDLEY_HAS_FEATURE(c_static_assert) && !defined(JSON_HEDLEY_INTEL_CL_VERSION)) || \ + JSON_HEDLEY_GCC_VERSION_CHECK(6,0,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) || \ + defined(_Static_assert) \ + ) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) _Static_assert(expr, message) +#elif \ + (defined(__cplusplus) && (__cplusplus >= 201103L)) || \ + JSON_HEDLEY_MSVC_VERSION_CHECK(16,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(static_assert(expr, message)) +#else +# define JSON_HEDLEY_STATIC_ASSERT(expr, message) +#endif + +#if defined(JSON_HEDLEY_NULL) + #undef JSON_HEDLEY_NULL +#endif +#if defined(__cplusplus) + #if __cplusplus >= 201103L + #define JSON_HEDLEY_NULL JSON_HEDLEY_DIAGNOSTIC_DISABLE_CPP98_COMPAT_WRAP_(nullptr) + #elif defined(NULL) + #define JSON_HEDLEY_NULL NULL + #else + #define JSON_HEDLEY_NULL JSON_HEDLEY_STATIC_CAST(void*, 0) + #endif +#elif defined(NULL) + #define JSON_HEDLEY_NULL NULL +#else + #define JSON_HEDLEY_NULL ((void*) 0) +#endif + +#if defined(JSON_HEDLEY_MESSAGE) + #undef JSON_HEDLEY_MESSAGE +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_MESSAGE(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(message msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message msg) +#elif JSON_HEDLEY_CRAY_VERSION_CHECK(5,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(_CRI message msg) +#elif JSON_HEDLEY_IAR_VERSION_CHECK(8,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#elif JSON_HEDLEY_PELLES_VERSION_CHECK(2,0,0) +# define JSON_HEDLEY_MESSAGE(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_WARNING) + #undef JSON_HEDLEY_WARNING +#endif +#if JSON_HEDLEY_HAS_WARNING("-Wunknown-pragmas") +# define JSON_HEDLEY_WARNING(msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + JSON_HEDLEY_DIAGNOSTIC_DISABLE_UNKNOWN_PRAGMAS \ + JSON_HEDLEY_PRAGMA(clang warning msg) \ + JSON_HEDLEY_DIAGNOSTIC_POP +#elif \ + JSON_HEDLEY_GCC_VERSION_CHECK(4,8,0) || \ + JSON_HEDLEY_PGI_VERSION_CHECK(18,4,0) || \ + JSON_HEDLEY_INTEL_VERSION_CHECK(13,0,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(GCC warning msg) +#elif \ + JSON_HEDLEY_MSVC_VERSION_CHECK(15,0,0) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_PRAGMA(message(msg)) +#else +# define JSON_HEDLEY_WARNING(msg) JSON_HEDLEY_MESSAGE(msg) +#endif + +#if defined(JSON_HEDLEY_REQUIRE) + #undef JSON_HEDLEY_REQUIRE +#endif +#if defined(JSON_HEDLEY_REQUIRE_MSG) + #undef JSON_HEDLEY_REQUIRE_MSG +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(diagnose_if) +# if JSON_HEDLEY_HAS_WARNING("-Wgcc-compat") +# define JSON_HEDLEY_REQUIRE(expr) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), #expr, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("clang diagnostic ignored \"-Wgcc-compat\"") \ + __attribute__((diagnose_if(!(expr), msg, "error"))) \ + JSON_HEDLEY_DIAGNOSTIC_POP +# else +# define JSON_HEDLEY_REQUIRE(expr) __attribute__((diagnose_if(!(expr), #expr, "error"))) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) __attribute__((diagnose_if(!(expr), msg, "error"))) +# endif +#else +# define JSON_HEDLEY_REQUIRE(expr) +# define JSON_HEDLEY_REQUIRE_MSG(expr,msg) +#endif + +#if defined(JSON_HEDLEY_FLAGS) + #undef JSON_HEDLEY_FLAGS +#endif +#if JSON_HEDLEY_HAS_ATTRIBUTE(flag_enum) && (!defined(__cplusplus) || JSON_HEDLEY_HAS_WARNING("-Wbitfield-enum-conversion")) + #define JSON_HEDLEY_FLAGS __attribute__((__flag_enum__)) +#else + #define JSON_HEDLEY_FLAGS +#endif + +#if defined(JSON_HEDLEY_FLAGS_CAST) + #undef JSON_HEDLEY_FLAGS_CAST +#endif +#if JSON_HEDLEY_INTEL_VERSION_CHECK(19,0,0) +# define JSON_HEDLEY_FLAGS_CAST(T, expr) (__extension__ ({ \ + JSON_HEDLEY_DIAGNOSTIC_PUSH \ + _Pragma("warning(disable:188)") \ + ((T) (expr)); \ + JSON_HEDLEY_DIAGNOSTIC_POP \ + })) +#else +# define JSON_HEDLEY_FLAGS_CAST(T, expr) JSON_HEDLEY_STATIC_CAST(T, expr) +#endif + +#if defined(JSON_HEDLEY_EMPTY_BASES) + #undef JSON_HEDLEY_EMPTY_BASES +#endif +#if \ + (JSON_HEDLEY_MSVC_VERSION_CHECK(19,0,23918) && !JSON_HEDLEY_MSVC_VERSION_CHECK(20,0,0)) || \ + JSON_HEDLEY_INTEL_CL_VERSION_CHECK(2021,1,0) + #define JSON_HEDLEY_EMPTY_BASES __declspec(empty_bases) +#else + #define JSON_HEDLEY_EMPTY_BASES +#endif + +/* Remaining macros are deprecated. */ + +#if defined(JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK) + #undef JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK +#endif +#if defined(__clang__) + #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) (0) +#else + #define JSON_HEDLEY_GCC_NOT_CLANG_VERSION_CHECK(major,minor,patch) JSON_HEDLEY_GCC_VERSION_CHECK(major,minor,patch) +#endif + +#if defined(JSON_HEDLEY_CLANG_HAS_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_CPP_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_CPP_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_BUILTIN) + #undef JSON_HEDLEY_CLANG_HAS_BUILTIN +#endif +#define JSON_HEDLEY_CLANG_HAS_BUILTIN(builtin) JSON_HEDLEY_HAS_BUILTIN(builtin) + +#if defined(JSON_HEDLEY_CLANG_HAS_FEATURE) + #undef JSON_HEDLEY_CLANG_HAS_FEATURE +#endif +#define JSON_HEDLEY_CLANG_HAS_FEATURE(feature) JSON_HEDLEY_HAS_FEATURE(feature) + +#if defined(JSON_HEDLEY_CLANG_HAS_EXTENSION) + #undef JSON_HEDLEY_CLANG_HAS_EXTENSION +#endif +#define JSON_HEDLEY_CLANG_HAS_EXTENSION(extension) JSON_HEDLEY_HAS_EXTENSION(extension) + +#if defined(JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE) + #undef JSON_HEDLEY_CLANG_HAS_DECLSPEC_DECLSPEC_ATTRIBUTE +#endif +#define JSON_HEDLEY_CLANG_HAS_DECLSPEC_ATTRIBUTE(attribute) JSON_HEDLEY_HAS_DECLSPEC_ATTRIBUTE(attribute) + +#if defined(JSON_HEDLEY_CLANG_HAS_WARNING) + #undef JSON_HEDLEY_CLANG_HAS_WARNING +#endif +#define JSON_HEDLEY_CLANG_HAS_WARNING(warning) JSON_HEDLEY_HAS_WARNING(warning) + +#endif /* !defined(JSON_HEDLEY_VERSION) || (JSON_HEDLEY_VERSION < X) */ + + +// This file contains all internal macro definitions (except those affecting ABI) +// You MUST include macro_unscope.hpp at the end of json.hpp to undef all of them + +// #include + + +// exclude unsupported compilers +#if !defined(JSON_SKIP_UNSUPPORTED_COMPILER_CHECK) + #if defined(__clang__) + #if (__clang_major__ * 10000 + __clang_minor__ * 100 + __clang_patchlevel__) < 30400 + #error "unsupported Clang version - see https://github.com/nlohmann/json#supported-compilers" + #endif + #elif defined(__GNUC__) && !(defined(__ICC) || defined(__INTEL_COMPILER)) + #if (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__) < 40800 + #error "unsupported GCC version - see https://github.com/nlohmann/json#supported-compilers" + #endif + #endif +#endif + +// C++ language standard detection +// if the user manually specified the used c++ version this is skipped +#if !defined(JSON_HAS_CPP_20) && !defined(JSON_HAS_CPP_17) && !defined(JSON_HAS_CPP_14) && !defined(JSON_HAS_CPP_11) + #if (defined(__cplusplus) && __cplusplus >= 202002L) || (defined(_MSVC_LANG) && _MSVC_LANG >= 202002L) + #define JSON_HAS_CPP_20 + #define JSON_HAS_CPP_17 + #define JSON_HAS_CPP_14 + #elif (defined(__cplusplus) && __cplusplus >= 201703L) || (defined(_HAS_CXX17) && _HAS_CXX17 == 1) // fix for issue #464 + #define JSON_HAS_CPP_17 + #define JSON_HAS_CPP_14 + #elif (defined(__cplusplus) && __cplusplus >= 201402L) || (defined(_HAS_CXX14) && _HAS_CXX14 == 1) + #define JSON_HAS_CPP_14 + #endif + // the cpp 11 flag is always specified because it is the minimal required version + #define JSON_HAS_CPP_11 +#endif + +#ifdef __has_include + #if __has_include() + #include + #endif +#endif + +#if !defined(JSON_HAS_FILESYSTEM) && !defined(JSON_HAS_EXPERIMENTAL_FILESYSTEM) + #ifdef JSON_HAS_CPP_17 + #if defined(__cpp_lib_filesystem) + #define JSON_HAS_FILESYSTEM 1 + #elif defined(__cpp_lib_experimental_filesystem) + #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 + #elif !defined(__has_include) + #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 + #elif __has_include() + #define JSON_HAS_FILESYSTEM 1 + #elif __has_include() + #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 1 + #endif + + // std::filesystem does not work on MinGW GCC 8: https://sourceforge.net/p/mingw-w64/bugs/737/ + #if defined(__MINGW32__) && defined(__GNUC__) && __GNUC__ == 8 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + + // no filesystem support before GCC 8: https://en.cppreference.com/w/cpp/compiler_support + #if defined(__GNUC__) && !defined(__clang__) && __GNUC__ < 8 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + + // no filesystem support before Clang 7: https://en.cppreference.com/w/cpp/compiler_support + #if defined(__clang_major__) && __clang_major__ < 7 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + + // no filesystem support before MSVC 19.14: https://en.cppreference.com/w/cpp/compiler_support + #if defined(_MSC_VER) && _MSC_VER < 1914 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + + // no filesystem support before iOS 13 + #if defined(__IPHONE_OS_VERSION_MIN_REQUIRED) && __IPHONE_OS_VERSION_MIN_REQUIRED < 130000 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + + // no filesystem support before macOS Catalina + #if defined(__MAC_OS_X_VERSION_MIN_REQUIRED) && __MAC_OS_X_VERSION_MIN_REQUIRED < 101500 + #undef JSON_HAS_FILESYSTEM + #undef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #endif + #endif +#endif + +#ifndef JSON_HAS_EXPERIMENTAL_FILESYSTEM + #define JSON_HAS_EXPERIMENTAL_FILESYSTEM 0 +#endif + +#ifndef JSON_HAS_FILESYSTEM + #define JSON_HAS_FILESYSTEM 0 +#endif + +#ifndef JSON_HAS_THREE_WAY_COMPARISON + #if defined(__cpp_impl_three_way_comparison) && __cpp_impl_three_way_comparison >= 201907L \ + && defined(__cpp_lib_three_way_comparison) && __cpp_lib_three_way_comparison >= 201907L + #define JSON_HAS_THREE_WAY_COMPARISON 1 + #else + #define JSON_HAS_THREE_WAY_COMPARISON 0 + #endif +#endif + +#ifndef JSON_HAS_RANGES + // ranges header shipping in GCC 11.1.0 (released 2021-04-27) has syntax error + #if defined(__GLIBCXX__) && __GLIBCXX__ == 20210427 + #define JSON_HAS_RANGES 0 + #elif defined(__cpp_lib_ranges) + #define JSON_HAS_RANGES 1 + #else + #define JSON_HAS_RANGES 0 + #endif +#endif + +#ifndef JSON_HAS_STATIC_RTTI + #if !defined(_HAS_STATIC_RTTI) || _HAS_STATIC_RTTI != 0 + #define JSON_HAS_STATIC_RTTI 1 + #else + #define JSON_HAS_STATIC_RTTI 0 + #endif +#endif + +#ifdef JSON_HAS_CPP_17 + #define JSON_INLINE_VARIABLE inline +#else + #define JSON_INLINE_VARIABLE +#endif + +#if JSON_HEDLEY_HAS_ATTRIBUTE(no_unique_address) + #define JSON_NO_UNIQUE_ADDRESS [[no_unique_address]] +#else + #define JSON_NO_UNIQUE_ADDRESS +#endif + +// disable documentation warnings on clang +#if defined(__clang__) + #pragma clang diagnostic push + #pragma clang diagnostic ignored "-Wdocumentation" + #pragma clang diagnostic ignored "-Wdocumentation-unknown-command" +#endif + +// allow disabling exceptions +#if (defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND)) && !defined(JSON_NOEXCEPTION) + #define JSON_THROW(exception) throw exception + #define JSON_TRY try + #define JSON_CATCH(exception) catch(exception) + #define JSON_INTERNAL_CATCH(exception) catch(exception) +#else + #include + #define JSON_THROW(exception) std::abort() + #define JSON_TRY if(true) + #define JSON_CATCH(exception) if(false) + #define JSON_INTERNAL_CATCH(exception) if(false) +#endif + +// override exception macros +#if defined(JSON_THROW_USER) + #undef JSON_THROW + #define JSON_THROW JSON_THROW_USER +#endif +#if defined(JSON_TRY_USER) + #undef JSON_TRY + #define JSON_TRY JSON_TRY_USER +#endif +#if defined(JSON_CATCH_USER) + #undef JSON_CATCH + #define JSON_CATCH JSON_CATCH_USER + #undef JSON_INTERNAL_CATCH + #define JSON_INTERNAL_CATCH JSON_CATCH_USER +#endif +#if defined(JSON_INTERNAL_CATCH_USER) + #undef JSON_INTERNAL_CATCH + #define JSON_INTERNAL_CATCH JSON_INTERNAL_CATCH_USER +#endif + +// allow overriding assert +#if !defined(JSON_ASSERT) + #include // assert + #define JSON_ASSERT(x) assert(x) +#endif + +// allow to access some private functions (needed by the test suite) +#if defined(JSON_TESTS_PRIVATE) + #define JSON_PRIVATE_UNLESS_TESTED public +#else + #define JSON_PRIVATE_UNLESS_TESTED private +#endif + +/*! +@brief macro to briefly define a mapping between an enum and JSON +@def NLOHMANN_JSON_SERIALIZE_ENUM +@since version 3.4.0 +*/ +#define NLOHMANN_JSON_SERIALIZE_ENUM(ENUM_TYPE, ...) \ + template \ + inline void to_json(BasicJsonType& j, const ENUM_TYPE& e) \ + { \ + /* NOLINTNEXTLINE(modernize-type-traits) we use C++11 */ \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + /* NOLINTNEXTLINE(modernize-avoid-c-arrays) we don't want to depend on */ \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [e](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.first == e; \ + }); \ + j = ((it != std::end(m)) ? it : std::begin(m))->second; \ + } \ + template \ + inline void from_json(const BasicJsonType& j, ENUM_TYPE& e) \ + { \ + /* NOLINTNEXTLINE(modernize-type-traits) we use C++11 */ \ + static_assert(std::is_enum::value, #ENUM_TYPE " must be an enum!"); \ + /* NOLINTNEXTLINE(modernize-avoid-c-arrays) we don't want to depend on */ \ + static const std::pair m[] = __VA_ARGS__; \ + auto it = std::find_if(std::begin(m), std::end(m), \ + [&j](const std::pair& ej_pair) -> bool \ + { \ + return ej_pair.second == j; \ + }); \ + e = ((it != std::end(m)) ? it : std::begin(m))->first; \ + } + +// Ugly macros to avoid uglier copy-paste when specializing basic_json. They +// may be removed in the future once the class is split. + +#define NLOHMANN_BASIC_JSON_TPL_DECLARATION \ + template class ObjectType, \ + template class ArrayType, \ + class StringType, class BooleanType, class NumberIntegerType, \ + class NumberUnsignedType, class NumberFloatType, \ + template class AllocatorType, \ + template class JSONSerializer, \ + class BinaryType, \ + class CustomBaseClass> + +#define NLOHMANN_BASIC_JSON_TPL \ + basic_json + +// Macros to simplify conversion from/to types + +#define NLOHMANN_JSON_EXPAND( x ) x +#define NLOHMANN_JSON_GET_MACRO(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, _61, _62, _63, _64, NAME,...) NAME +#define NLOHMANN_JSON_PASTE(...) NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_GET_MACRO(__VA_ARGS__, \ + NLOHMANN_JSON_PASTE64, \ + NLOHMANN_JSON_PASTE63, \ + NLOHMANN_JSON_PASTE62, \ + NLOHMANN_JSON_PASTE61, \ + NLOHMANN_JSON_PASTE60, \ + NLOHMANN_JSON_PASTE59, \ + NLOHMANN_JSON_PASTE58, \ + NLOHMANN_JSON_PASTE57, \ + NLOHMANN_JSON_PASTE56, \ + NLOHMANN_JSON_PASTE55, \ + NLOHMANN_JSON_PASTE54, \ + NLOHMANN_JSON_PASTE53, \ + NLOHMANN_JSON_PASTE52, \ + NLOHMANN_JSON_PASTE51, \ + NLOHMANN_JSON_PASTE50, \ + NLOHMANN_JSON_PASTE49, \ + NLOHMANN_JSON_PASTE48, \ + NLOHMANN_JSON_PASTE47, \ + NLOHMANN_JSON_PASTE46, \ + NLOHMANN_JSON_PASTE45, \ + NLOHMANN_JSON_PASTE44, \ + NLOHMANN_JSON_PASTE43, \ + NLOHMANN_JSON_PASTE42, \ + NLOHMANN_JSON_PASTE41, \ + NLOHMANN_JSON_PASTE40, \ + NLOHMANN_JSON_PASTE39, \ + NLOHMANN_JSON_PASTE38, \ + NLOHMANN_JSON_PASTE37, \ + NLOHMANN_JSON_PASTE36, \ + NLOHMANN_JSON_PASTE35, \ + NLOHMANN_JSON_PASTE34, \ + NLOHMANN_JSON_PASTE33, \ + NLOHMANN_JSON_PASTE32, \ + NLOHMANN_JSON_PASTE31, \ + NLOHMANN_JSON_PASTE30, \ + NLOHMANN_JSON_PASTE29, \ + NLOHMANN_JSON_PASTE28, \ + NLOHMANN_JSON_PASTE27, \ + NLOHMANN_JSON_PASTE26, \ + NLOHMANN_JSON_PASTE25, \ + NLOHMANN_JSON_PASTE24, \ + NLOHMANN_JSON_PASTE23, \ + NLOHMANN_JSON_PASTE22, \ + NLOHMANN_JSON_PASTE21, \ + NLOHMANN_JSON_PASTE20, \ + NLOHMANN_JSON_PASTE19, \ + NLOHMANN_JSON_PASTE18, \ + NLOHMANN_JSON_PASTE17, \ + NLOHMANN_JSON_PASTE16, \ + NLOHMANN_JSON_PASTE15, \ + NLOHMANN_JSON_PASTE14, \ + NLOHMANN_JSON_PASTE13, \ + NLOHMANN_JSON_PASTE12, \ + NLOHMANN_JSON_PASTE11, \ + NLOHMANN_JSON_PASTE10, \ + NLOHMANN_JSON_PASTE9, \ + NLOHMANN_JSON_PASTE8, \ + NLOHMANN_JSON_PASTE7, \ + NLOHMANN_JSON_PASTE6, \ + NLOHMANN_JSON_PASTE5, \ + NLOHMANN_JSON_PASTE4, \ + NLOHMANN_JSON_PASTE3, \ + NLOHMANN_JSON_PASTE2, \ + NLOHMANN_JSON_PASTE1)(__VA_ARGS__)) +#define NLOHMANN_JSON_PASTE2(func, v1) func(v1) +#define NLOHMANN_JSON_PASTE3(func, v1, v2) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE2(func, v2) +#define NLOHMANN_JSON_PASTE4(func, v1, v2, v3) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE3(func, v2, v3) +#define NLOHMANN_JSON_PASTE5(func, v1, v2, v3, v4) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE4(func, v2, v3, v4) +#define NLOHMANN_JSON_PASTE6(func, v1, v2, v3, v4, v5) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE5(func, v2, v3, v4, v5) +#define NLOHMANN_JSON_PASTE7(func, v1, v2, v3, v4, v5, v6) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE6(func, v2, v3, v4, v5, v6) +#define NLOHMANN_JSON_PASTE8(func, v1, v2, v3, v4, v5, v6, v7) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE7(func, v2, v3, v4, v5, v6, v7) +#define NLOHMANN_JSON_PASTE9(func, v1, v2, v3, v4, v5, v6, v7, v8) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE8(func, v2, v3, v4, v5, v6, v7, v8) +#define NLOHMANN_JSON_PASTE10(func, v1, v2, v3, v4, v5, v6, v7, v8, v9) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE9(func, v2, v3, v4, v5, v6, v7, v8, v9) +#define NLOHMANN_JSON_PASTE11(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE10(func, v2, v3, v4, v5, v6, v7, v8, v9, v10) +#define NLOHMANN_JSON_PASTE12(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE11(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11) +#define NLOHMANN_JSON_PASTE13(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE12(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12) +#define NLOHMANN_JSON_PASTE14(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE13(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13) +#define NLOHMANN_JSON_PASTE15(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE14(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14) +#define NLOHMANN_JSON_PASTE16(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE15(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15) +#define NLOHMANN_JSON_PASTE17(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE16(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16) +#define NLOHMANN_JSON_PASTE18(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE17(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17) +#define NLOHMANN_JSON_PASTE19(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE18(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18) +#define NLOHMANN_JSON_PASTE20(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE19(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19) +#define NLOHMANN_JSON_PASTE21(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE20(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20) +#define NLOHMANN_JSON_PASTE22(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE21(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21) +#define NLOHMANN_JSON_PASTE23(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE22(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22) +#define NLOHMANN_JSON_PASTE24(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE23(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23) +#define NLOHMANN_JSON_PASTE25(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE24(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24) +#define NLOHMANN_JSON_PASTE26(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE25(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25) +#define NLOHMANN_JSON_PASTE27(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE26(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26) +#define NLOHMANN_JSON_PASTE28(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE27(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27) +#define NLOHMANN_JSON_PASTE29(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE28(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28) +#define NLOHMANN_JSON_PASTE30(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE29(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29) +#define NLOHMANN_JSON_PASTE31(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE30(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30) +#define NLOHMANN_JSON_PASTE32(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE31(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31) +#define NLOHMANN_JSON_PASTE33(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE32(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32) +#define NLOHMANN_JSON_PASTE34(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE33(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33) +#define NLOHMANN_JSON_PASTE35(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE34(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34) +#define NLOHMANN_JSON_PASTE36(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE35(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35) +#define NLOHMANN_JSON_PASTE37(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE36(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36) +#define NLOHMANN_JSON_PASTE38(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE37(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37) +#define NLOHMANN_JSON_PASTE39(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE38(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38) +#define NLOHMANN_JSON_PASTE40(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE39(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39) +#define NLOHMANN_JSON_PASTE41(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE40(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40) +#define NLOHMANN_JSON_PASTE42(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE41(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41) +#define NLOHMANN_JSON_PASTE43(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE42(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42) +#define NLOHMANN_JSON_PASTE44(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE43(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43) +#define NLOHMANN_JSON_PASTE45(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE44(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44) +#define NLOHMANN_JSON_PASTE46(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE45(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45) +#define NLOHMANN_JSON_PASTE47(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE46(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46) +#define NLOHMANN_JSON_PASTE48(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE47(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47) +#define NLOHMANN_JSON_PASTE49(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE48(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48) +#define NLOHMANN_JSON_PASTE50(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE49(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49) +#define NLOHMANN_JSON_PASTE51(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE50(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50) +#define NLOHMANN_JSON_PASTE52(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE51(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51) +#define NLOHMANN_JSON_PASTE53(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE52(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52) +#define NLOHMANN_JSON_PASTE54(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE53(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53) +#define NLOHMANN_JSON_PASTE55(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE54(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54) +#define NLOHMANN_JSON_PASTE56(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE55(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55) +#define NLOHMANN_JSON_PASTE57(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE56(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56) +#define NLOHMANN_JSON_PASTE58(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE57(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57) +#define NLOHMANN_JSON_PASTE59(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE58(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58) +#define NLOHMANN_JSON_PASTE60(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE59(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59) +#define NLOHMANN_JSON_PASTE61(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE60(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60) +#define NLOHMANN_JSON_PASTE62(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE61(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61) +#define NLOHMANN_JSON_PASTE63(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE62(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62) +#define NLOHMANN_JSON_PASTE64(func, v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) NLOHMANN_JSON_PASTE2(func, v1) NLOHMANN_JSON_PASTE63(func, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49, v50, v51, v52, v53, v54, v55, v56, v57, v58, v59, v60, v61, v62, v63) + +#define NLOHMANN_JSON_TO(v1) nlohmann_json_j[#v1] = nlohmann_json_t.v1; +#define NLOHMANN_JSON_FROM(v1) nlohmann_json_j.at(#v1).get_to(nlohmann_json_t.v1); +#define NLOHMANN_JSON_FROM_WITH_DEFAULT(v1) nlohmann_json_t.v1 = !nlohmann_json_j.is_null() ? nlohmann_json_j.value(#v1, nlohmann_json_default_obj.v1) : nlohmann_json_default_obj.v1; + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT +@since version 3.11.0 +*/ +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_WITH_DEFAULT(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE +@since version 3.11.x +*/ +#define NLOHMANN_DEFINE_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE +@since version 3.9.0 +*/ +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT +@since version 3.11.0 +*/ +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE +@since version 3.11.x +*/ +#define NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } + +/*! +@brief macro +@def NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE +@since version 3.11.x +*/ +#define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE(Type, BaseType, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +#define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + friend void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } + +#define NLOHMANN_DEFINE_DERIVED_TYPE_INTRUSIVE_ONLY_SERIALIZE(Type, BaseType, ...) \ + friend void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + +/*! +@brief macro +@def NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE +@since version 3.11.x +*/ +#define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE(Type, BaseType, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM, __VA_ARGS__)) } + +#define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE_WITH_DEFAULT(Type, BaseType, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + inline void from_json(const nlohmann::json& nlohmann_json_j, Type& nlohmann_json_t) { nlohmann::from_json(nlohmann_json_j, static_cast(nlohmann_json_t)); const Type nlohmann_json_default_obj{}; NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_FROM_WITH_DEFAULT, __VA_ARGS__)) } + +#define NLOHMANN_DEFINE_DERIVED_TYPE_NON_INTRUSIVE_ONLY_SERIALIZE(Type, BaseType, ...) \ + inline void to_json(nlohmann::json& nlohmann_json_j, const Type& nlohmann_json_t) { nlohmann::to_json(nlohmann_json_j, static_cast(nlohmann_json_t)); NLOHMANN_JSON_EXPAND(NLOHMANN_JSON_PASTE(NLOHMANN_JSON_TO, __VA_ARGS__)) } \ + +// inspired from https://stackoverflow.com/a/26745591 +// allows to call any std function as if (e.g. with begin): +// using std::begin; begin(x); +// +// it allows using the detected idiom to retrieve the return type +// of such an expression +#define NLOHMANN_CAN_CALL_STD_FUNC_IMPL(std_name) \ + namespace detail { \ + using std::std_name; \ + \ + template \ + using result_of_##std_name = decltype(std_name(std::declval()...)); \ + } \ + \ + namespace detail2 { \ + struct std_name##_tag \ + { \ + }; \ + \ + template \ + std_name##_tag std_name(T&&...); \ + \ + template \ + using result_of_##std_name = decltype(std_name(std::declval()...)); \ + \ + template \ + struct would_call_std_##std_name \ + { \ + static constexpr auto const value = ::nlohmann::detail:: \ + is_detected_exact::value; \ + }; \ + } /* namespace detail2 */ \ + \ + template \ + struct would_call_std_##std_name : detail2::would_call_std_##std_name \ + { \ + } + +#ifndef JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_USE_IMPLICIT_CONVERSIONS 1 +#endif + +#if JSON_USE_IMPLICIT_CONVERSIONS + #define JSON_EXPLICIT +#else + #define JSON_EXPLICIT explicit +#endif + +#ifndef JSON_DISABLE_ENUM_SERIALIZATION + #define JSON_DISABLE_ENUM_SERIALIZATION 0 +#endif + +#ifndef JSON_USE_GLOBAL_UDLS + #define JSON_USE_GLOBAL_UDLS 1 +#endif + +#if JSON_HAS_THREE_WAY_COMPARISON + #include // partial_ordering +#endif + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +/////////////////////////// +// JSON type enumeration // +/////////////////////////// + +/*! +@brief the JSON type enumeration + +This enumeration collects the different JSON types. It is internally used to +distinguish the stored values, and the functions @ref basic_json::is_null(), +@ref basic_json::is_object(), @ref basic_json::is_array(), +@ref basic_json::is_string(), @ref basic_json::is_boolean(), +@ref basic_json::is_number() (with @ref basic_json::is_number_integer(), +@ref basic_json::is_number_unsigned(), and @ref basic_json::is_number_float()), +@ref basic_json::is_discarded(), @ref basic_json::is_primitive(), and +@ref basic_json::is_structured() rely on it. + +@note There are three enumeration entries (number_integer, number_unsigned, and +number_float), because the library distinguishes these three types for numbers: +@ref basic_json::number_unsigned_t is used for unsigned integers, +@ref basic_json::number_integer_t is used for signed integers, and +@ref basic_json::number_float_t is used for floating-point numbers or to +approximate integers which do not fit in the limits of their respective type. + +@sa see @ref basic_json::basic_json(const value_t value_type) -- create a JSON +value with the default value for a given type + +@since version 1.0.0 +*/ +enum class value_t : std::uint8_t +{ + null, ///< null value + object, ///< object (unordered set of name/value pairs) + array, ///< array (ordered collection of values) + string, ///< string value + boolean, ///< boolean value + number_integer, ///< number value (signed integer) + number_unsigned, ///< number value (unsigned integer) + number_float, ///< number value (floating-point) + binary, ///< binary array (ordered collection of bytes) + discarded ///< discarded by the parser callback function +}; + +/*! +@brief comparison operator for JSON types + +Returns an ordering that is similar to Python: +- order: null < boolean < number < object < array < string < binary +- furthermore, each type is not smaller than itself +- discarded values are not comparable +- binary is represented as a b"" string in python and directly comparable to a + string; however, making a binary array directly comparable with a string would + be surprising behavior in a JSON file. + +@since version 1.0.0 +*/ +#if JSON_HAS_THREE_WAY_COMPARISON + inline std::partial_ordering operator<=>(const value_t lhs, const value_t rhs) noexcept // *NOPAD* +#else + inline bool operator<(const value_t lhs, const value_t rhs) noexcept +#endif +{ + static constexpr std::array order = {{ + 0 /* null */, 3 /* object */, 4 /* array */, 5 /* string */, + 1 /* boolean */, 2 /* integer */, 2 /* unsigned */, 2 /* float */, + 6 /* binary */ + } + }; + + const auto l_index = static_cast(lhs); + const auto r_index = static_cast(rhs); +#if JSON_HAS_THREE_WAY_COMPARISON + if (l_index < order.size() && r_index < order.size()) + { + return order[l_index] <=> order[r_index]; // *NOPAD* + } + return std::partial_ordering::unordered; +#else + return l_index < order.size() && r_index < order.size() && order[l_index] < order[r_index]; +#endif +} + +// GCC selects the built-in operator< over an operator rewritten from +// a user-defined spaceship operator +// Clang, MSVC, and ICC select the rewritten candidate +// (see GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=105200) +#if JSON_HAS_THREE_WAY_COMPARISON && defined(__GNUC__) +inline bool operator<(const value_t lhs, const value_t rhs) noexcept +{ + return std::is_lt(lhs <=> rhs); // *NOPAD* +} +#endif + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +/*! +@brief replace all occurrences of a substring by another string + +@param[in,out] s the string to manipulate; changed so that all + occurrences of @a f are replaced with @a t +@param[in] f the substring to replace with @a t +@param[in] t the string to replace @a f + +@pre The search string @a f must not be empty. **This precondition is +enforced with an assertion.** + +@since version 2.0.0 +*/ +template +inline void replace_substring(StringType& s, const StringType& f, + const StringType& t) +{ + JSON_ASSERT(!f.empty()); + for (auto pos = s.find(f); // find first occurrence of f + pos != StringType::npos; // make sure f was found + s.replace(pos, f.size(), t), // replace with t, and + pos = s.find(f, pos + t.size())) // find next occurrence of f + {} +} + +/*! + * @brief string escaping as described in RFC 6901 (Sect. 4) + * @param[in] s string to escape + * @return escaped string + * + * Note the order of escaping "~" to "~0" and "/" to "~1" is important. + */ +template +inline StringType escape(StringType s) +{ + replace_substring(s, StringType{"~"}, StringType{"~0"}); + replace_substring(s, StringType{"/"}, StringType{"~1"}); + return s; +} + +/*! + * @brief string unescaping as described in RFC 6901 (Sect. 4) + * @param[in] s string to unescape + * @return unescaped string + * + * Note the order of escaping "~1" to "/" and "~0" to "~" is important. + */ +template +static void unescape(StringType& s) +{ + replace_substring(s, StringType{"~1"}, StringType{"/"}); + replace_substring(s, StringType{"~0"}, StringType{"~"}); +} + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // size_t + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +/// struct to capture the start position of the current token +struct position_t +{ + /// the total number of characters read + std::size_t chars_read_total = 0; + /// the number of characters read in the current line + std::size_t chars_read_current_line = 0; + /// the number of lines read + std::size_t lines_read = 0; + + /// conversion to size_t to preserve SAX interface + constexpr operator size_t() const + { + return chars_read_total; + } +}; + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-FileCopyrightText: 2018 The Abseil Authors +// SPDX-License-Identifier: MIT + + + +#include // array +#include // size_t +#include // conditional, enable_if, false_type, integral_constant, is_constructible, is_integral, is_same, remove_cv, remove_reference, true_type +#include // index_sequence, make_index_sequence, index_sequence_for + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +template +using uncvref_t = typename std::remove_cv::type>::type; + +#ifdef JSON_HAS_CPP_14 + +// the following utilities are natively available in C++14 +using std::enable_if_t; +using std::index_sequence; +using std::make_index_sequence; +using std::index_sequence_for; + +#else + +// alias templates to reduce boilerplate +template +using enable_if_t = typename std::enable_if::type; + +// The following code is taken from https://github.com/abseil/abseil-cpp/blob/10cb35e459f5ecca5b2ff107635da0bfa41011b4/absl/utility/utility.h +// which is part of Google Abseil (https://github.com/abseil/abseil-cpp), licensed under the Apache License 2.0. + +//// START OF CODE FROM GOOGLE ABSEIL + +// integer_sequence +// +// Class template representing a compile-time integer sequence. An instantiation +// of `integer_sequence` has a sequence of integers encoded in its +// type through its template arguments (which is a common need when +// working with C++11 variadic templates). `absl::integer_sequence` is designed +// to be a drop-in replacement for C++14's `std::integer_sequence`. +// +// Example: +// +// template< class T, T... Ints > +// void user_function(integer_sequence); +// +// int main() +// { +// // user_function's `T` will be deduced to `int` and `Ints...` +// // will be deduced to `0, 1, 2, 3, 4`. +// user_function(make_integer_sequence()); +// } +template +struct integer_sequence +{ + using value_type = T; + static constexpr std::size_t size() noexcept + { + return sizeof...(Ints); + } +}; + +// index_sequence +// +// A helper template for an `integer_sequence` of `size_t`, +// `absl::index_sequence` is designed to be a drop-in replacement for C++14's +// `std::index_sequence`. +template +using index_sequence = integer_sequence; + +namespace utility_internal +{ + +template +struct Extend; + +// Note that SeqSize == sizeof...(Ints). It's passed explicitly for efficiency. +template +struct Extend, SeqSize, 0> +{ + using type = integer_sequence < T, Ints..., (Ints + SeqSize)... >; +}; + +template +struct Extend, SeqSize, 1> +{ + using type = integer_sequence < T, Ints..., (Ints + SeqSize)..., 2 * SeqSize >; +}; + +// Recursion helper for 'make_integer_sequence'. +// 'Gen::type' is an alias for 'integer_sequence'. +template +struct Gen +{ + using type = + typename Extend < typename Gen < T, N / 2 >::type, N / 2, N % 2 >::type; +}; + +template +struct Gen +{ + using type = integer_sequence; +}; + +} // namespace utility_internal + +// Compile-time sequences of integers + +// make_integer_sequence +// +// This template alias is equivalent to +// `integer_sequence`, and is designed to be a drop-in +// replacement for C++14's `std::make_integer_sequence`. +template +using make_integer_sequence = typename utility_internal::Gen::type; + +// make_index_sequence +// +// This template alias is equivalent to `index_sequence<0, 1, ..., N-1>`, +// and is designed to be a drop-in replacement for C++14's +// `std::make_index_sequence`. +template +using make_index_sequence = make_integer_sequence; + +// index_sequence_for +// +// Converts a typename pack into an index sequence of the same length, and +// is designed to be a drop-in replacement for C++14's +// `std::index_sequence_for()` +template +using index_sequence_for = make_index_sequence; + +//// END OF CODE FROM GOOGLE ABSEIL + +#endif + +// dispatch utility (taken from ranges-v3) +template struct priority_tag : priority_tag < N - 1 > {}; +template<> struct priority_tag<0> {}; + +// taken from ranges-v3 +template +struct static_const +{ + static JSON_INLINE_VARIABLE constexpr T value{}; +}; + +#ifndef JSON_HAS_CPP_17 + template + constexpr T static_const::value; +#endif + +template +constexpr std::array make_array(Args&& ... args) +{ + return std::array {{static_cast(std::forward(args))...}}; +} + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // numeric_limits +#include // char_traits +#include // tuple +#include // false_type, is_constructible, is_integral, is_same, true_type +#include // declval + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +#include // random_access_iterator_tag + +// #include + +// #include + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN +namespace detail +{ + +template +struct iterator_types {}; + +template +struct iterator_types < + It, + void_t> +{ + using difference_type = typename It::difference_type; + using value_type = typename It::value_type; + using pointer = typename It::pointer; + using reference = typename It::reference; + using iterator_category = typename It::iterator_category; +}; + +// This is required as some compilers implement std::iterator_traits in a way that +// doesn't work with SFINAE. See https://github.com/nlohmann/json/issues/1341. +template +struct iterator_traits +{ +}; + +template +struct iterator_traits < T, enable_if_t < !std::is_pointer::value >> + : iterator_types +{ +}; + +template +struct iterator_traits::value>> +{ + using iterator_category = std::random_access_iterator_tag; + using value_type = T; + using difference_type = ptrdiff_t; + using pointer = T*; + using reference = T&; +}; + +} // namespace detail +NLOHMANN_JSON_NAMESPACE_END + +// #include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN + +NLOHMANN_CAN_CALL_STD_FUNC_IMPL(begin); + +NLOHMANN_JSON_NAMESPACE_END + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + + + +// #include + + +NLOHMANN_JSON_NAMESPACE_BEGIN + +NLOHMANN_CAN_CALL_STD_FUNC_IMPL(end); + +NLOHMANN_JSON_NAMESPACE_END + +// #include + +// #include + +// #include +// __ _____ _____ _____ +// __| | __| | | | JSON for Modern C++ +// | | |__ | | | | | | version 3.11.3 +// |_____|_____|_____|_|___| https://github.com/nlohmann/json +// +// SPDX-FileCopyrightText: 2013 - 2024 Niels Lohmann +// SPDX-License-Identifier: MIT + +#ifndef INCLUDE_NLOHMANN_JSON_FWD_HPP_ + #define INCLUDE_NLOHMANN_JSON_FWD_HPP_ + + #include // int64_t, uint64_t + #include // map + #include // allocator + #include // string + #include // vector + + // #include + + + /*! + @brief namespace for Niels Lohmann + @see https://github.com/nlohmann + @since version 1.0.0 + */ + NLOHMANN_JSON_NAMESPACE_BEGIN + + /*! + @brief default JSONSerializer template argument + + This serializer ignores the template arguments and uses ADL + ([argument-dependent lookup](https://en.cppreference.com/w/cpp/language/adl)) + for serialization. + */ + template + struct adl_serializer; + + /// a class to store JSON values + /// @sa https://json.nlohmann.me/api/basic_json/ + template class ObjectType = + std::map, + template class ArrayType = std::vector, + class StringType = std::string, class BooleanType = bool, + class NumberIntegerType = std::int64_t, + class NumberUnsignedType = std::uint64_t, + class NumberFloatType = double, + template class AllocatorType = std::allocator, + template class JSONSerializer = + adl_serializer, + class BinaryType = std::vector, // cppcheck-suppress syntaxError + class CustomBaseClass = void> + class basic_json; + + /// @brief JSON Pointer defines a string syntax for identifying a specific value within a JSON document + /// @sa https://json.nlohmann.me/api/json_pointer/ + template + class json_pointer; + + /*! + @brief default specialization + @sa https://json.nlohmann.me/api/json/ + */ + using json = basic_json<>; + + /// @brief a minimal map-like container that preserves insertion order + /// @sa https://json.nlohmann.me/api/ordered_map/ + template + struct ordered_map; + + /// @brief specialization that maintains the insertion order of object keys + /// @sa https://json.nlohmann.me/api/ordered_json/ + using ordered_json = basic_json; + + NLOHMANN_JSON_NAMESPACE_END + +#endif // INCLUDE_NLOHMANN_JSON_FWD_HPP_ + + +NLOHMANN_JSON_NAMESPACE_BEGIN +/*! +@brief detail namespace with internal helper functions + +This namespace collects functions that should not be exposed, +implementations of some @ref basic_json methods, and meta-programming helpers. + +@since version 2.1.0 +*/ +namespace detail +{ + +///////////// +// helpers // +///////////// + +// Note to maintainers: +// +// Every trait in this file expects a non CV-qualified type. +// The only exceptions are in the 'aliases for detected' section +// (i.e. those of the form: decltype(T::member_function(std::declval()))) +// +// In this case, T has to be properly CV-qualified to constraint the function arguments +// (e.g. to_json(BasicJsonType&, const T&)) + +template struct is_basic_json : std::false_type {}; + +NLOHMANN_BASIC_JSON_TPL_DECLARATION +struct is_basic_json : std::true_type {}; + +// used by exceptions create() member functions +// true_type for pointer to possibly cv-qualified basic_json or std::nullptr_t +// false_type otherwise +template +struct is_basic_json_context : + std::integral_constant < bool, + is_basic_json::type>::type>::value + || std::is_same::value > +{}; + +////////////////////// +// json_ref helpers // +////////////////////// + +template +class json_ref; + +template +struct is_json_ref : std::false_type {}; + +template +struct is_json_ref> : std::true_type {}; + +////////////////////////// +// aliases for detected // +////////////////////////// + +template +using mapped_type_t = typename T::mapped_type; + +template +using key_type_t = typename T::key_type; + +template +using value_type_t = typename T::value_type; + +template +using difference_type_t = typename T::difference_type; + +template +using pointer_t = typename T::pointer; + +template +using reference_t = typename T::reference; + +template +using iterator_category_t = typename T::iterator_category; + +template +using to_json_function = decltype(T::to_json(std::declval()...)); + +template +using from_json_function = decltype(T::from_json(std::declval()...)); + +template +using get_template_function = decltype(std::declval().template get()); + +// trait checking if JSONSerializer::from_json(json const&, udt&) exists +template +struct has_from_json : std::false_type {}; + +// trait checking if j.get is valid +// use this trait instead of std::is_constructible or std::is_convertible, +// both rely on, or make use of implicit conversions, and thus fail when T +// has several constructors/operator= (see https://github.com/nlohmann/json/issues/958) +template +struct is_getable +{ + static constexpr bool value = is_detected::value; +}; + +template +struct has_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if JSONSerializer::from_json(json const&) exists +// this overload is used for non-default-constructible user-defined-types +template +struct has_non_default_from_json : std::false_type {}; + +template +struct has_non_default_from_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +// This trait checks if BasicJsonType::json_serializer::to_json exists +// Do not evaluate the trait when T is a basic_json type, to avoid template instantiation infinite recursion. +template +struct has_to_json : std::false_type {}; + +template +struct has_to_json < BasicJsonType, T, enable_if_t < !is_basic_json::value >> +{ + using serializer = typename BasicJsonType::template json_serializer; + + static constexpr bool value = + is_detected_exact::value; +}; + +template +using detect_key_compare = typename T::key_compare; + +template +struct has_key_compare : std::integral_constant::value> {}; + +// obtains the actual object key comparator +template +struct actual_object_comparator +{ + using object_t = typename BasicJsonType::object_t; + using object_comparator_t = typename BasicJsonType::default_object_comparator_t; + using type = typename std::conditional < has_key_compare::value, + typename object_t::key_compare, object_comparator_t>::type; +}; + +template +using actual_object_comparator_t = typename actual_object_comparator::type; + +///////////////// +// char_traits // +///////////////// + +// Primary template of char_traits calls std char_traits +template +struct char_traits : std::char_traits +{}; + +// Explicitly define char traits for unsigned char since it is not standard +template<> +struct char_traits : std::char_traits +{ + using char_type = unsigned char; + using int_type = uint64_t; + + // Redefine to_int_type function + static int_type to_int_type(char_type c) noexcept + { + return static_cast(c); + } + + static char_type to_char_type(int_type i) noexcept + { + return static_cast(i); + } + + static constexpr int_type eof() noexcept + { + return static_cast(std::char_traits::eof()); + } +}; + +// Explicitly define char traits for signed char since it is not standard +template<> +struct char_traits : std::char_traits +{ + using char_type = signed char; + using int_type = uint64_t; + + // Redefine to_int_type function + static int_type to_int_type(char_type c) noexcept + { + return static_cast(c); + } + + static char_type to_char_type(int_type i) noexcept + { + return static_cast(i); + } + + static constexpr int_type eof() noexcept + { + return static_cast(std::char_traits::eof()); + } +}; + +/////////////////// +// is_ functions // +/////////////////// + +// https://en.cppreference.com/w/cpp/types/conjunction +template struct conjunction : std::true_type { }; +template struct conjunction : B { }; +template +struct conjunction +: std::conditional(B::value), conjunction, B>::type {}; + +// https://en.cppreference.com/w/cpp/types/negation +template struct negation : std::integral_constant < bool, !B::value > { }; + +// Reimplementation of is_constructible and is_default_constructible, due to them being broken for +// std::pair and std::tuple until LWG 2367 fix (see https://cplusplus.github.io/LWG/lwg-defects.html#2367). +// This causes compile errors in e.g. clang 3.5 or gcc 4.9. +template +struct is_default_constructible : std::is_default_constructible {}; + +template +struct is_default_constructible> + : conjunction, is_default_constructible> {}; + +template +struct is_default_constructible> + : conjunction, is_default_constructible> {}; + +template +struct is_default_constructible> + : conjunction...> {}; + +template +struct is_default_constructible> + : conjunction...> {}; + +template +struct is_constructible : std::is_constructible {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_constructible> : is_default_constructible> {}; + +template +struct is_iterator_traits : std::false_type {}; + +template +struct is_iterator_traits> +{ + private: + using traits = iterator_traits; + + public: + static constexpr auto value = + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value && + is_detected::value; +}; + +template +struct is_range +{ + private: + using t_ref = typename std::add_lvalue_reference::type; + + using iterator = detected_t; + using sentinel = detected_t; + + // to be 100% correct, it should use https://en.cppreference.com/w/cpp/iterator/input_or_output_iterator + // and https://en.cppreference.com/w/cpp/iterator/sentinel_for + // but reimplementing these would be too much work, as a lot of other concepts are used underneath + static constexpr auto is_iterator_begin = + is_iterator_traits>::value; + + public: + static constexpr bool value = !std::is_same::value && !std::is_same::value && is_iterator_begin; +}; + +template +using iterator_t = enable_if_t::value, result_of_begin())>>; + +template +using range_value_t = value_type_t>>; + +// The following implementation of is_complete_type is taken from +// https://blogs.msdn.microsoft.com/vcblog/2015/12/02/partial-support-for-expression-sfinae-in-vs-2015-update-1/ +// and is written by Xiang Fan who agreed to using it in this library. + +template +struct is_complete_type : std::false_type {}; + +template +struct is_complete_type : std::true_type {}; + +template +struct is_compatible_object_type_impl : std::false_type {}; + +template +struct is_compatible_object_type_impl < + BasicJsonType, CompatibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + using object_t = typename BasicJsonType::object_t; + + // macOS's is_constructible does not play well with nonesuch... + static constexpr bool value = + is_constructible::value && + is_constructible::value; +}; + +template +struct is_compatible_object_type + : is_compatible_object_type_impl {}; + +template +struct is_constructible_object_type_impl : std::false_type {}; + +template +struct is_constructible_object_type_impl < + BasicJsonType, ConstructibleObjectType, + enable_if_t < is_detected::value&& + is_detected::value >> +{ + using object_t = typename BasicJsonType::object_t; + + static constexpr bool value = + (is_default_constructible::value && + (std::is_move_assignable::value || + std::is_copy_assignable::value) && + (is_constructible::value && + std::is_same < + typename object_t::mapped_type, + typename ConstructibleObjectType::mapped_type >::value)) || + (has_from_json::value || + has_non_default_from_json < + BasicJsonType, + typename ConstructibleObjectType::mapped_type >::value); +}; + +template +struct is_constructible_object_type + : is_constructible_object_type_impl {}; + +template +struct is_compatible_string_type +{ + static constexpr auto value = + is_constructible::value; +}; + +template +struct is_constructible_string_type +{ + // launder type through decltype() to fix compilation failure on ICPC +#ifdef __INTEL_COMPILER + using laundered_type = decltype(std::declval()); +#else + using laundered_type = ConstructibleStringType; +#endif + + static constexpr auto value = + conjunction < + is_constructible, + is_detected_exact>::value; +}; + +template +struct is_compatible_array_type_impl : std::false_type {}; + +template +struct is_compatible_array_type_impl < + BasicJsonType, CompatibleArrayType, + enable_if_t < + is_detected::value&& + is_iterator_traits>>::value&& +// special case for types like std::filesystem::path whose iterator's value_type are themselves +// c.f. https://github.com/nlohmann/json/pull/3073 + !std::is_same>::value >> +{ + static constexpr bool value = + is_constructible>::value; +}; + +template +struct is_compatible_array_type + : is_compatible_array_type_impl {}; + +template +struct is_constructible_array_type_impl : std::false_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t::value >> + : std::true_type {}; + +template +struct is_constructible_array_type_impl < + BasicJsonType, ConstructibleArrayType, + enable_if_t < !std::is_same::value&& + !is_compatible_string_type::value&& + is_default_constructible::value&& +(std::is_move_assignable::value || + std::is_copy_assignable::value)&& +is_detected::value&& +is_iterator_traits>>::value&& +is_detected::value&& +// special case for types like std::filesystem::path whose iterator's value_type are themselves +// c.f. https://github.com/nlohmann/json/pull/3073 +!std::is_same>::value&& +is_complete_type < +detected_t>::value >> +{ + using value_type = range_value_t; + + static constexpr bool value = + std::is_same::value || + has_from_json::value || + has_non_default_from_json < + BasicJsonType, + value_type >::value; +}; + +template +struct is_constructible_array_type + : is_constructible_array_type_impl {}; + +template +struct is_compatible_integer_type_impl : std::false_type {}; + +template +struct is_compatible_integer_type_impl < + RealIntegerType, CompatibleNumberIntegerType, + enable_if_t < std::is_integral::value&& + std::is_integral::value&& + !std::is_same::value >> +{ + // is there an assert somewhere on overflows? + using RealLimits = std::numeric_limits; + using CompatibleLimits = std::numeric_limits; + + static constexpr auto value = + is_constructible::value && + CompatibleLimits::is_integer && + RealLimits::is_signed == CompatibleLimits::is_signed; +}; + +template +struct is_compatible_integer_type + : is_compatible_integer_type_impl {}; + +template +struct is_compatible_type_impl: std::false_type {}; + +template +struct is_compatible_type_impl < + BasicJsonType, CompatibleType, + enable_if_t::value >> +{ + static constexpr bool value = + has_to_json::value; +}; + +template +struct is_compatible_type + : is_compatible_type_impl {}; + +template +struct is_constructible_tuple : std::false_type {}; + +template +struct is_constructible_tuple> : conjunction...> {}; + +template +struct is_json_iterator_of : std::false_type {}; + +template +struct is_json_iterator_of : std::true_type {}; + +template +struct is_json_iterator_of : std::true_type +{}; + +// checks if a given type T is a template specialization of Primary +template