Skip to content

Commit

Permalink
refactor(bb): use RefArray where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
ludamad0 committed Feb 20, 2024
1 parent 21ec23d commit e03dbe9
Show file tree
Hide file tree
Showing 17 changed files with 1,127 additions and 1,042 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include "barretenberg/commitment_schemes/commitment_key.hpp"
#include "barretenberg/common/ref_span.hpp"
#include "barretenberg/common/ref_vector.hpp"
#include "barretenberg/common/zip_view.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
Expand Down Expand Up @@ -316,15 +317,15 @@ template <typename Curve> class ZeroMorphProver_ {
* @param commitment_key
* @param transcript
*/
static void prove(const std::vector<Polynomial>& f_polynomials,
const std::vector<Polynomial>& g_polynomials,
const std::vector<FF>& f_evaluations,
const std::vector<FF>& g_shift_evaluations,
const std::vector<FF>& multilinear_challenge,
static void prove(RefSpan<Polynomial> f_polynomials,
RefSpan<Polynomial> g_polynomials,
RefSpan<FF> f_evaluations,
RefSpan<FF> g_shift_evaluations,
std::span<FF> multilinear_challenge,
const std::shared_ptr<CommitmentKey<Curve>>& commitment_key,
const std::shared_ptr<NativeTranscript>& transcript,
const std::vector<Polynomial>& concatenated_polynomials = {},
const std::vector<FF>& concatenated_evaluations = {},
RefSpan<Polynomial> concatenated_polynomials = {},
RefSpan<FF> concatenated_evaluations = {},
const std::vector<RefVector<Polynomial>>& concatenation_groups = {})
{
// Generate batching challenge \rho and powers 1,...,\rho^{m-1}
Expand Down Expand Up @@ -516,13 +517,13 @@ template <typename Curve> class ZeroMorphVerifier_ {
* @param concatenation_groups_commitments
* @return Commitment
*/
static Commitment compute_C_Z_x(const std::vector<Commitment>& f_commitments,
const std::vector<Commitment>& g_commitments,
std::vector<Commitment>& C_q_k,
static Commitment compute_C_Z_x(RefSpan<Commitment> f_commitments,
RefSpan<Commitment> g_commitments,
std::span<Commitment> C_q_k,
FF rho,
FF batched_evaluation,
FF x_challenge,
std::vector<FF> u_challenge,
std::span<FF> u_challenge,
const std::vector<RefVector<Commitment>>& concatenation_groups_commitments = {})
{
size_t log_N = C_q_k.size();
Expand Down Expand Up @@ -634,14 +635,14 @@ template <typename Curve> class ZeroMorphVerifier_ {
* @return std::array<Commitment, 2> Inputs to the final pairing check
*/
static std::array<Commitment, 2> verify(
auto&& unshifted_commitments,
auto&& to_be_shifted_commitments,
auto&& unshifted_evaluations,
auto&& shifted_evaluations,
auto& multivariate_challenge,
RefSpan<Commitment> unshifted_commitments,
RefSpan<Commitment> to_be_shifted_commitments,
RefSpan<FF> unshifted_evaluations,
RefSpan<FF> shifted_evaluations,
std::span<FF> multivariate_challenge,
auto& transcript,
const std::vector<RefVector<Commitment>>& concatenation_group_commitments = {},
const std::vector<FF>& concatenated_evaluations = {})
RefSpan<FF> concatenated_evaluations = {})
{
size_t log_N = multivariate_challenge.size();
FF rho = transcript->template get_challenge<FF>("rho");
Expand Down
16 changes: 12 additions & 4 deletions barretenberg/cpp/src/barretenberg/common/ref_array.hpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
#pragma once

#include "barretenberg/common/assert.hpp"
#include <array>
#include <cstddef>
#include <initializer_list>
#include <iterator>
#include <span>
#include <stdexcept>

namespace bb {
Expand All @@ -18,6 +21,7 @@ namespace bb {
*/
template <typename T, std::size_t N> class RefArray {
public:
RefArray() = default;
RefArray(const std::array<T*, N>& ptr_array)
{
std::size_t i = 0;
Expand Down Expand Up @@ -92,6 +96,9 @@ template <typename T, std::size_t N> class RefArray {
*/
iterator end() const { return iterator(this, N); }

T** get_storage() { return storage; }
T* const* get_storage() const { return storage; }

private:
// We are making a high-level array, for simplicity having a C array as backing makes sense.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays)
Expand All @@ -115,24 +122,25 @@ template <typename T, typename... Ts> RefArray(T&, Ts&...) -> RefArray<T, 1 + si
* @param ref_arrays The RefArray objects to be concatenated.
* @return RefArray object containing all elements from the input arrays.
*/
template <typename T, std::size_t... Ns> RefArray<T, (Ns + ...)> concatenate(const RefArray<T, Ns>&... ref_arrays)
template <typename T, std::size_t... Ns>
RefArray<T, (Ns + ...)> constexpr concatenate(const RefArray<T, Ns>&... ref_arrays)
{
// Fold expression to calculate the total size of the new array using fold expression
constexpr std::size_t TotalSize = (Ns + ...);
std::array<T*, TotalSize> concatenated;
RefArray<T, TotalSize> concatenated;

std::size_t offset = 0;
// Copies elements from a given RefArray to the concatenated array
auto copy_into = [&](const auto& ref_array, std::size_t& offset) {
for (std::size_t i = 0; i < ref_array.size(); ++i) {
concatenated[offset + i] = &ref_array[i];
concatenated.get_storage()[offset + i] = &ref_array[i];
}
offset += ref_array.size();
};

// Fold expression to copy elements from each input RefArray to the concatenated array
(..., copy_into(ref_arrays, offset));

return RefArray<T, TotalSize>{ concatenated };
return concatenated;
}
} // namespace bb
95 changes: 95 additions & 0 deletions barretenberg/cpp/src/barretenberg/common/ref_span.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#pragma once

#include <cstddef>
#include <iterator>

#include "ref_array.hpp"

namespace bb {

template <typename T> class RefSpan {
public:
// Default constructor
RefSpan()
: storage(nullptr)
, array_size(0)
{}

template <std::size_t Size>
RefSpan(const RefArray<T, Size>& ref_array)
: storage(ref_array.get_storage())
, array_size(Size)
{}

// Constructor from an array of pointers and size
RefSpan(T** ptr_array, std::size_t size)
: storage(ptr_array)
, array_size(size)
{}

// Copy constructor
RefSpan(const RefSpan& other) = default;

// Move constructor
RefSpan(RefSpan&& other) noexcept = default;

// Destructor
~RefSpan() = default;

// Copy assignment operator
RefSpan& operator=(const RefSpan& other) = default;

// Move assignment operator
RefSpan& operator=(RefSpan&& other) noexcept = default;

// Access element at index
T& operator[](std::size_t idx) const
{
// Assuming the caller ensures idx is within bounds.
return *storage[idx];
}

// Get size of the RefSpan
constexpr std::size_t size() const { return array_size; }

// Iterator implementation
class iterator {
public:
iterator(T* const* array, std::size_t pos)
: array(array)
, pos(pos)
{}

T& operator*() const { return *(array[pos]); }

iterator& operator++()
{
++pos;
return *this;
}

iterator operator++(int)
{
iterator temp = *this;
++(*this);
return temp;
}

bool operator==(const iterator& other) const { return pos == other.pos; }
bool operator!=(const iterator& other) const { return pos != other.pos; }

private:
T* const* array;
std::size_t pos;
};

// Begin and end for iterator support
iterator begin() const { return iterator(storage, 0); }
iterator end() const { return iterator(storage, array_size); }

private:
T* const* storage;
std::size_t array_size;
};

} // namespace bb
76 changes: 37 additions & 39 deletions barretenberg/cpp/src/barretenberg/flavor/ecc_vm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,8 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa

using FF = typename G1::subgroup_field;
using Polynomial = bb::Polynomial<FF>;
using PolynomialHandle = std::span<FF>;
using GroupElement = typename G1::element;
using Commitment = typename G1::affine_element;
using CommitmentHandle = typename G1::affine_element;
using CommitmentKey = bb::CommitmentKey<Curve>;
using VerifierCommitmentKey = bb::VerifierCommitmentKey<Curve>;
using RelationSeparator = FF;
Expand Down Expand Up @@ -96,9 +94,9 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
lagrange_last); // column 2

DataType get_selectors() { return get_all(); };
RefVector<DataType> get_sigma_polynomials() { return {}; };
RefVector<DataType> get_id_polynomials() { return {}; };
RefVector<DataType> get_table_polynomials() { return {}; };
auto get_sigma_polynomials() { return RefArray<DataType, 0>{}; };
auto get_id_polynomials() { return RefArray<DataType, 0>{}; };
auto get_table_polynomials() { return RefArray<DataType, 0>{}; };
};

/**
Expand Down Expand Up @@ -202,9 +200,9 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
class WitnessEntities : public WireEntities<DataType>, public DerivedWitnessEntities<DataType> {
public:
DEFINE_COMPOUND_GET_ALL(WireEntities<DataType>, DerivedWitnessEntities<DataType>)
RefVector<DataType> get_wires() { return WireEntities<DataType>::get_all(); };
auto get_wires() { return WireEntities<DataType>::get_all(); };
// The sorted concatenations of table and witness data needed for plookup.
RefVector<DataType> get_sorted_polynomials() { return {}; };
auto get_sorted_polynomials() { return RefArray<DataType, 0>{}; };
};

/**
Expand Down Expand Up @@ -242,35 +240,35 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
};

template <typename DataType, typename PrecomputedAndWitnessEntitiesSuperset>
static RefVector<DataType> get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities)
static auto get_to_be_shifted(PrecomputedAndWitnessEntitiesSuperset& entities)
{
// NOTE: must match order of ShiftedEntities above!
return { entities.transcript_mul,
entities.transcript_msm_count,
entities.transcript_accumulator_x,
entities.transcript_accumulator_y,
entities.precompute_scalar_sum,
entities.precompute_s1hi,
entities.precompute_dx,
entities.precompute_dy,
entities.precompute_tx,
entities.precompute_ty,
entities.msm_transition,
entities.msm_add,
entities.msm_double,
entities.msm_skew,
entities.msm_accumulator_x,
entities.msm_accumulator_y,
entities.msm_count,
entities.msm_round,
entities.msm_add1,
entities.msm_pc,
entities.precompute_pc,
entities.transcript_pc,
entities.precompute_round,
entities.transcript_accumulator_empty,
entities.precompute_select,
entities.z_perm };
return RefArray{ entities.transcript_mul,
entities.transcript_msm_count,
entities.transcript_accumulator_x,
entities.transcript_accumulator_y,
entities.precompute_scalar_sum,
entities.precompute_s1hi,
entities.precompute_dx,
entities.precompute_dy,
entities.precompute_tx,
entities.precompute_ty,
entities.msm_transition,
entities.msm_add,
entities.msm_double,
entities.msm_skew,
entities.msm_accumulator_x,
entities.msm_accumulator_y,
entities.msm_count,
entities.msm_round,
entities.msm_add1,
entities.msm_pc,
entities.precompute_pc,
entities.transcript_pc,
entities.precompute_round,
entities.transcript_accumulator_empty,
entities.precompute_select,
entities.z_perm };
}
/**
* @brief A base class labelling all entities (for instance, all of the polynomials used by the prover during
Expand All @@ -297,13 +295,13 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa

DEFINE_COMPOUND_GET_ALL(PrecomputedEntities<DataType>, WitnessEntities<DataType>, ShiftedEntities<DataType>)
// Gemini-specific getters.
RefVector<DataType> get_unshifted()
auto get_unshifted()
{
return concatenate(PrecomputedEntities<DataType>::get_all(), WitnessEntities<DataType>::get_all());
};

RefVector<DataType> get_to_be_shifted() { return ECCVMBase::get_to_be_shifted<DataType>(*this); }
RefVector<DataType> get_shifted() { return ShiftedEntities<DataType>::get_all(); };
auto get_to_be_shifted() { return ECCVMBase::get_to_be_shifted<DataType>(*this); }
auto get_shifted() { return ShiftedEntities<DataType>::get_all(); };
};

public:
Expand All @@ -318,9 +316,9 @@ template <typename CycleGroup_T, typename Curve_T, typename PCS_T> class ECCVMBa
using Base = ProvingKey_<PrecomputedEntities<Polynomial>, WitnessEntities<Polynomial>>;
using Base::Base;

RefVector<Polynomial> get_to_be_shifted() { return ECCVMBase::get_to_be_shifted<Polynomial>(*this); }
auto get_to_be_shifted() { return ECCVMBase::get_to_be_shifted<Polynomial>(*this); }
// The plookup wires that store plookup read data.
std::array<PolynomialHandle, 3> get_table_column_wires() { return {}; };
RefArray<Polynomial, 0> get_table_column_wires() { return {}; };
};

/**
Expand Down
6 changes: 3 additions & 3 deletions barretenberg/cpp/src/barretenberg/flavor/flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ class ProvingKey_ : public PrecomputedPolynomials, public WitnessPolynomials {
return concatenate(PrecomputedPolynomials::get_labels(), WitnessPolynomials::get_labels());
}
// This order matters! must match get_unshifted in entity classes
RefVector<Polynomial> get_all() { return concatenate(get_precomputed_polynomials(), get_witness_polynomials()); }
RefVector<Polynomial> get_witness_polynomials() { return WitnessPolynomials::get_all(); }
RefVector<Polynomial> get_precomputed_polynomials() { return PrecomputedPolynomials::get_all(); }
auto get_all() { return concatenate(get_precomputed_polynomials(), get_witness_polynomials()); }
auto get_witness_polynomials() { return WitnessPolynomials::get_all(); }
auto get_precomputed_polynomials() { return PrecomputedPolynomials::get_all(); }
ProvingKey_() = default;
ProvingKey_(const size_t circuit_size, const size_t num_public_inputs)
{
Expand Down
6 changes: 3 additions & 3 deletions barretenberg/cpp/src/barretenberg/flavor/flavor_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
// while DEFINE_COMPOUND_GET_ALL lets you combine the iterators of substructures or base
// classes.

#include "barretenberg/common/ref_vector.hpp"
#include "barretenberg/common/ref_array.hpp"
#include "barretenberg/common/std_array.hpp"
#include "barretenberg/common/std_string.hpp"
#include "barretenberg/common/std_vector.hpp"
Expand Down Expand Up @@ -40,11 +40,11 @@ template <typename T, typename... BaseClass> auto _concatenate_base_class_get_la
#define DEFINE_REF_VIEW(...) \
[[nodiscard]] auto get_all() \
{ \
return RefVector{ __VA_ARGS__ }; \
return RefArray{ __VA_ARGS__ }; \
} \
[[nodiscard]] auto get_all() const \
{ \
return RefVector{ __VA_ARGS__ }; \
return RefArray{ __VA_ARGS__ }; \
}

/**
Expand Down
Loading

0 comments on commit e03dbe9

Please sign in to comment.