Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subtests to test.Assert #191

Merged
merged 10 commits into from
Dec 22, 2021
56 changes: 30 additions & 26 deletions circuitstats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,43 @@ package gnark
import (
"encoding/gob"
"os"
"sync"
"testing"

"github.com/consensys/gnark-crypto/ecc"
"github.com/consensys/gnark/backend"
"github.com/consensys/gnark/frontend"
"github.com/consensys/gnark/internal/backend/circuits"
"github.com/stretchr/testify/require"
"github.com/consensys/gnark/test"
)

const (
fileStats = "init.stats"
generateNewStats = false
)

func TestCircuitStatistics(t *testing.T) {

assert := require.New(t)

curves := ecc.Implemented()
for name, tData := range circuits.Circuits {
var statsM sync.Mutex

for _, curve := range curves {
check := func(backendID backend.ID) {
t.Log(name, curve.String(), backendID.String())

ccs, err := frontend.Compile(curve, backendID, tData.Circuit)
assert.NoError(err)

// ensure we didn't introduce regressions that make circuits less efficient
nbConstraints := ccs.GetNbConstraints()
internal, secret, public := ccs.GetNbVariables()
checkStats(t, name, nbConstraints, internal, secret, public, curve, backendID)
func TestCircuitStatistics(t *testing.T) {
assert := test.NewAssert(t)
for k := range circuits.Circuits {
for _, curve := range ecc.Implemented() {
for _, b := range backend.Implemented() {
curve := curve
b := b
name := k
// copy the circuit now in case assert calls t.Parallel()
tData := circuits.Circuits[k]
assert.Run(func(assert *test.Assert) {
ccs, err := frontend.Compile(curve, b, tData.Circuit)
assert.NoError(err)

// ensure we didn't introduce regressions that make circuits less efficient
nbConstraints := ccs.GetNbConstraints()
internal, secret, public := ccs.GetNbVariables()
checkStats(assert, name, nbConstraints, internal, secret, public, curve, b)
}, name, curve.String(), b.String())
}
check(backend.GROTH16)
check(backend.PLONK)
}

}
Expand All @@ -59,28 +61,30 @@ type circuitStats struct {

var mStats map[string][backend.PLONK + 1][ecc.BW6_633 + 1]circuitStats

func checkStats(t *testing.T, circuitName string, nbConstraints, internal, secret, public int, curve ecc.ID, backendID backend.ID) {
func checkStats(assert *test.Assert, circuitName string, nbConstraints, internal, secret, public int, curve ecc.ID, backendID backend.ID) {
statsM.Lock()
defer statsM.Unlock()
if generateNewStats {
rs := mStats[circuitName]
rs[backendID][curve] = circuitStats{nbConstraints, internal, secret, public}
mStats[circuitName] = rs
return
}
if referenceStats, ok := mStats[circuitName]; !ok {
t.Log("warning: no stats for circuit", circuitName)
assert.Log("warning: no stats for circuit", circuitName)
} else {
ref := referenceStats[backendID][curve]
if ref.NbConstraints != nbConstraints {
t.Errorf("expected %d nbConstraints (reference), got %d. %s, %s, %s", ref.NbConstraints, nbConstraints, circuitName, backendID.String(), curve.String())
assert.Failf("unexpected constraint count", "expected %d nbConstraints (reference), got %d. %s, %s, %s", ref.NbConstraints, nbConstraints, circuitName, backendID.String(), curve.String())
}
if ref.Internal != internal {
t.Errorf("expected %d internal (reference), got %d. %s, %s, %s", ref.Internal, internal, circuitName, backendID.String(), curve.String())
assert.Failf("unexpected internal variable count", "expected %d internal (reference), got %d. %s, %s, %s", ref.Internal, internal, circuitName, backendID.String(), curve.String())
}
if ref.Secret != secret {
t.Errorf("expected %d secret (reference), got %d. %s, %s, %s", ref.Secret, secret, circuitName, backendID.String(), curve.String())
assert.Failf("unexpected secret variable count", "expected %d secret (reference), got %d. %s, %s, %s", ref.Secret, secret, circuitName, backendID.String(), curve.String())
}
if ref.Public != public {
t.Errorf("expected %d public (reference), got %d. %s, %s, %s", ref.Public, public, circuitName, backendID.String(), curve.String())
assert.Failf("unexpected public variable count", "expected %d public (reference), got %d. %s, %s, %s", ref.Public, public, circuitName, backendID.String(), curve.String())
}
}
}
Expand Down
41 changes: 17 additions & 24 deletions integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package gnark

import (
"fmt"
"sort"
"testing"

Expand All @@ -35,30 +36,22 @@ func TestIntegrationAPI(t *testing.T) {
}
sort.Strings(keys)

for _, k := range keys {

tData := circuits.Circuits[k]
t.Log(k)
for _, w := range tData.ValidWitnesses {
// assert.ProverSucceeded(tData.Circuit, w, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)))
assert.ProverSucceeded(
tData.Circuit,
w,
test.WithProverOpts(backend.WithHints(tData.HintFunctions...)),
test.WithCurves(tData.Curves[0], tData.Curves[1:]...))
}

for _, w := range tData.InvalidWitnesses {
assert.ProverFailed(
tData.Circuit,
w,
test.WithProverOpts(backend.WithHints(tData.HintFunctions...)),
test.WithCurves(tData.Curves[0], tData.Curves[1:]...))
}

// we put that here now, but will be into a proper fuzz target with go1.18
const fuzzCount = 30
assert.Fuzz(tData.Circuit, fuzzCount, test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), test.WithBackends(backend.GROTH16))
for i := range keys {
name := keys[i]
tData := circuits.Circuits[name]
assert.Run(func(assert *test.Assert) {
for i := range tData.ValidWitnesses {
assert.Run(func(assert *test.Assert) {
assert.ProverSucceeded(tData.Circuit, tData.ValidWitnesses[i], test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...))
}, fmt.Sprintf("valid-%d", i))
}

for i := range tData.InvalidWitnesses {
assert.Run(func(assert *test.Assert) {
assert.ProverFailed(tData.Circuit, tData.InvalidWitnesses[i], test.WithProverOpts(backend.WithHints(tData.HintFunctions...)), test.WithCurves(tData.Curves[0], tData.Curves[1:]...))
}, fmt.Sprintf("invalid-%d", i))
}
}, name)
}

}
132 changes: 67 additions & 65 deletions internal/backend/bls12-377/cs/r1cs_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading