Skip to content

Commit

Permalink
implement lazy root computation (#18)
Browse files Browse the repository at this point in the history
* implement lazy root computation

* review feedback use width instead of 2

* review feedback: reverse output of and change name of checkForNil to noMissingData

* review feedback add test for public root API
  • Loading branch information
evan-forbes authored Mar 26, 2021
1 parent 178b2b6 commit b641792
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 20 deletions.
43 changes: 32 additions & 11 deletions datasquare.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,8 @@ func (ds *dataSquare) computeRoots() {
rowRoots := make([][]byte, ds.width)
columnRoots := make([][]byte, ds.width)
for i := uint(0); i < ds.width; i++ {
rowTree := ds.createTreeFn()
columnTree := ds.createTreeFn()
rowData := ds.Row(i)
columnData := ds.Column(i)
for j := uint(0); j < ds.width; j++ {
rowTree.Push(rowData[j])
columnTree.Push(columnData[j])
}

rowRoots[i] = rowTree.Root()
columnRoots[i] = columnTree.Root()
rowRoots[i] = ds.RowRoot(i)
columnRoots[i] = ds.ColRoot(i)
}

ds.rowRoots = rowRoots
Expand All @@ -167,6 +158,21 @@ func (ds *dataSquare) RowRoots() [][]byte {
return ds.rowRoots
}

// RowRoot calculates and returns the root of the selected row. Note: unlike the
// RowRoots method, RowRoot uses the built-in cache when available.
func (ds *dataSquare) RowRoot(x uint) []byte {
if ds.rowRoots != nil {
return ds.rowRoots[x]
}

tree := ds.createTreeFn()
for _, d := range ds.Row(x) {
tree.Push(d)
}

return tree.Root()
}

// ColumnRoots returns the Merkle roots of all the columns in the square.
func (ds *dataSquare) ColumnRoots() [][]byte {
if ds.columnRoots == nil {
Expand All @@ -176,6 +182,21 @@ func (ds *dataSquare) ColumnRoots() [][]byte {
return ds.columnRoots
}

// ColRoot calculates and returns the root of the selected row. Note: unlike the
// ColRoots method, ColRoot does not use the built in cache
func (ds *dataSquare) ColRoot(y uint) []byte {
if ds.columnRoots != nil {
return ds.columnRoots[y]
}

tree := ds.createTreeFn()
for _, d := range ds.Column(y) {
tree.Push(d)
}

return tree.Root()
}

func (ds *dataSquare) computeRowProof(x uint, y uint) ([]byte, [][]byte, uint, uint, error) {
tree := ds.createTreeFn()
data := ds.Row(x)
Expand Down
45 changes: 45 additions & 0 deletions datasquare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,51 @@ func TestRoots(t *testing.T) {
}
}

func TestLazyRootGeneration(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

var rowRoots [][]byte
var colRoots [][]byte

for i := uint(0); i < square.width; i++ {
rowRoots = append(rowRoots, square.RowRoot(i))
colRoots = append(rowRoots, square.ColRoot(i))
}

square.computeRoots()

if !reflect.DeepEqual(square.rowRoots, rowRoots) && !reflect.DeepEqual(square.columnRoots, colRoots) {
t.Error("RowRoot or ColumnRoot did not produce identical roots to computeRoots")
}
}

func TestRootAPI(t *testing.T) {
square, err := newDataSquare([][]byte{{1}, {2}, {3}, {4}}, NewDefaultTree)
if err != nil {
panic(err)
}

for i := uint(0); i < square.width; i++ {
if !reflect.DeepEqual(square.RowRoots()[i], square.RowRoot(i)) {
t.Errorf(
"Row root API results in different roots, expected %v go %v",
square.RowRoots()[i],
square.RowRoot(i),
)
}
if !reflect.DeepEqual(square.ColumnRoots()[i], square.ColRoot(i)) {
t.Errorf(
"Column root API results in different roots, expected %v go %v",
square.ColumnRoots()[i],
square.ColRoot(i),
)
}
}
}

func TestProofs(t *testing.T) {
result, err := newDataSquare([][]byte{{1, 2}, {3, 4}, {5, 6}, {7, 8}}, NewDefaultTree)
if err != nil {
Expand Down
41 changes: 32 additions & 9 deletions extendeddatacrossword.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package rsmt2d

import (
"bytes"
"errors"
"fmt"

"gonum.org/v1/gonum/mat"
Expand Down Expand Up @@ -170,11 +169,11 @@ func (eds *ExtendedDataSquare) solveCrossword(rowRoots [][]byte, columnRoots [][

// Check that rebuilt vector matches given merkle root
if mode == row {
if !bytes.Equal(eds.RowRoots()[i], rowRoots[i]) {
if !bytes.Equal(eds.RowRoot(i), rowRoots[i]) {
return &ByzantineRowError{i, edsBackup}
}
} else if mode == column {
if !bytes.Equal(eds.ColumnRoots()[i], columnRoots[i]) {
if !bytes.Equal(eds.ColRoot(i), columnRoots[i]) {
return &ByzantineColumnError{i, edsBackup}
}
}
Expand All @@ -184,12 +183,12 @@ func (eds *ExtendedDataSquare) solveCrossword(rowRoots [][]byte, columnRoots [][
if vectorMask.AtVec(int(j)) == 0 {
if mode == row {
adjMask := mask.ColView(int(j))
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.ColumnRoots()[j], columnRoots[j]) {
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.ColRoot(j), columnRoots[j]) {
return &ByzantineColumnError{j, edsBackup}
}
} else if mode == column {
adjMask := mask.RowView(int(j))
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.RowRoots()[j], rowRoots[j]) {
if vecNumTrue(adjMask) == adjMask.Len()-1 && !bytes.Equal(eds.RowRoot(j), rowRoots[j]) {
return &ByzantineRowError{j, edsBackup}
}
}
Expand Down Expand Up @@ -229,11 +228,26 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
for i := uint(0); i < eds.width; i++ {
rowMask := mask.RowView(int(i))
columnMask := mask.ColView(int(i))
if (vecIsTrue(rowMask) && !bytes.Equal(rowRoots[i], eds.RowRoots()[i])) || (vecIsTrue(columnMask) && !bytes.Equal(columnRoots[i], eds.ColumnRoots()[i])) {
return errors.New("bad roots input")
rowMaskIsVec := vecIsTrue(rowMask)
columnMaskIsVec := vecIsTrue(columnMask)

// if there's no missing data in the this row
if noMissingData(eds.Row(i)) {
// ensure that the roots are equal and that rowMask is a vector
if rowMaskIsVec && !bytes.Equal(rowRoots[i], eds.RowRoot(i)) {
return fmt.Errorf("bad root input: row %d expected %v got %v", i, rowRoots[i], eds.RowRoot(i))
}
}

if vecIsTrue(rowMask) {
// if there's no missing data in the this col
if noMissingData(eds.Column(i)) {
// ensure that the roots are equal and that rowMask is a vector
if columnMaskIsVec && !bytes.Equal(columnRoots[i], eds.ColRoot(i)) {
return fmt.Errorf("bad root input: col %d expected %v got %v", i, columnRoots[i], eds.ColRoot(i))
}
}

if rowMaskIsVec {
shares, err = Encode(eds.rowSlice(i, 0, eds.originalDataWidth), eds.codec)
if err != nil {
return err
Expand All @@ -243,7 +257,7 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
}
}

if vecIsTrue(columnMask) {
if columnMaskIsVec {
shares, err = Encode(eds.columnSlice(0, i, eds.originalDataWidth), eds.codec)
if err != nil {
return err
Expand All @@ -257,6 +271,15 @@ func (eds *ExtendedDataSquare) prerepairSanityCheck(rowRoots [][]byte, columnRoo
return nil
}

func noMissingData(input [][]byte) bool {
for _, d := range input {
if d == nil {
return false
}
}
return true
}

func vecIsTrue(vec mat.Vector) bool {
for i := 0; i < vec.Len(); i++ {
if vec.AtVec(i) == 0 {
Expand Down

0 comments on commit b641792

Please sign in to comment.