Skip to content

Commit

Permalink
Refactor and a few optimizations (#671)
Browse files Browse the repository at this point in the history
* Benchmarked against rustvm

* final refactor and optimizations

* nit

* refactor
  • Loading branch information
Sh0g0-1758 authored Oct 15, 2024
1 parent 0d1b966 commit 76b422d
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 154 deletions.
7 changes: 1 addition & 6 deletions cmd/cli/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/core"
"github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/hinter"
hintrunner "github.com/NethermindEth/cairo-vm-go/pkg/hintrunner/zero"
cairoversion "github.com/NethermindEth/cairo-vm-go/pkg/parsers/cairo_version"
"github.com/NethermindEth/cairo-vm-go/pkg/parsers/starknet"
zero "github.com/NethermindEth/cairo-vm-go/pkg/parsers/zero"
"github.com/NethermindEth/cairo-vm-go/pkg/runner"
Expand Down Expand Up @@ -110,18 +109,14 @@ func main() {
if pathToFile == "" {
return fmt.Errorf("path to cairo file not set")
}
cairoVersion, err := cairoversion.GetCairoVersion(pathToFile)
if err != nil {
return fmt.Errorf("cannot get cairo version: %w", err)
}
fmt.Printf("Loading program at %s\n", pathToFile)
zeroProgram, err := zero.ZeroProgramFromFile(pathToFile)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
}

var hints map[uint64][]hinter.Hinter
if cairoVersion > 0 {
if zeroProgram.CompilerVersion[0] == '1' {
cairoProgram, err := starknet.StarknetProgramFromFile(pathToFile)
if err != nil {
return fmt.Errorf("cannot load program: %w", err)
Expand Down
119 changes: 102 additions & 17 deletions integration_tests/cairozero_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"os"
"os/exec"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -53,7 +52,7 @@ func (f *Filter) filtered(testFile string) bool {
return false
}

func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][2]int, benchmark bool, errorExpected bool) {
func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[string][3]int, benchmark bool, errorExpected bool) {
t.Logf("testing: %s\n", path)

compiledOutput, err := compileZeroCode(path)
Expand All @@ -73,6 +72,17 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
}
}

elapsedRs, rsTraceFile, rsMemoryFile, err := runRustVm(name, compiledOutput)
if errorExpected {
// we let the code go on so that we can check if the go vm also raises an error
assert.Error(t, err, path)
} else {
if err != nil {
t.Error(err)
return
}
}

elapsedGo, traceFile, memoryFile, _, err := runVm(compiledOutput)
if errorExpected {
assert.Error(t, err, path)
Expand All @@ -85,7 +95,7 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
}

if benchmark {
benchmarkMap[name] = [2]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds())}
benchmarkMap[name] = [3]int{int(elapsedPy.Milliseconds()), int(elapsedGo.Milliseconds()), int(elapsedRs.Milliseconds())}
}

pyTrace, pyMemory, err := decodeProof(pyTraceFile, pyMemoryFile)
Expand All @@ -100,6 +110,20 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
return
}

rsTrace, rsMemory, err := decodeProof(rsTraceFile, rsMemoryFile)
if err != nil {
t.Error(err)
return
}

if !assert.Equal(t, pyTrace, rsTrace) {
t.Logf("pytrace:\n%s\n", traceRepr(pyTrace))
t.Logf("rstrace:\n%s\n", traceRepr(rsTrace))
}
if !assert.Equal(t, pyMemory, rsMemory) {
t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory))
t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory))
}
if !assert.Equal(t, pyTrace, trace) {
t.Logf("pytrace:\n%s\n", traceRepr(pyTrace))
t.Logf("trace:\n%s\n", traceRepr(trace))
Expand All @@ -108,6 +132,14 @@ func runAndTestFile(t *testing.T, path string, name string, benchmarkMap map[str
t.Logf("pymemory;\n%s\n", memoryRepr(pyMemory))
t.Logf("memory;\n%s\n", memoryRepr(memory))
}
if !assert.Equal(t, rsTrace, trace) {
t.Logf("rstrace:\n%s\n", traceRepr(rsTrace))
t.Logf("trace:\n%s\n", traceRepr(trace))
}
if !assert.Equal(t, rsMemory, memory) {
t.Logf("rsmemory;\n%s\n", memoryRepr(rsMemory))
t.Logf("memory;\n%s\n", memoryRepr(memory))
}
}

var zerobench = flag.Bool("zerobench", false, "run integration tests and generate benchmarks file")
Expand All @@ -123,7 +155,7 @@ func TestCairoZeroFiles(t *testing.T) {
filter := Filter{}
filter.init()

benchmarkMap := make(map[string][2]int)
benchmarkMap := make(map[string][3]int)

sem := make(chan struct{}, 5) // semaphore to limit concurrency
var wg sync.WaitGroup // WaitGroup to wait for all goroutines to finish
Expand Down Expand Up @@ -176,35 +208,31 @@ func TestCairoZeroFiles(t *testing.T) {
}
}

// Save the Benchmarks for the integration tests in `BenchMarks.txt`
func WriteBenchMarksToFile(benchmarkMap map[string][2]int) {
totalWidth := 123
func WriteBenchMarksToFile(benchmarkMap map[string][3]int) {
totalWidth := 113 // Reduced width to adjust for long file names

border := strings.Repeat("=", totalWidth)
separator := strings.Repeat("-", totalWidth)

var sb strings.Builder
w := tabwriter.NewWriter(&sb, 40, 0, 0, ' ', tabwriter.Debug)
w := tabwriter.NewWriter(&sb, 0, 0, 1, ' ', tabwriter.AlignRight)

sb.WriteString(border + "\n")
fmt.Fprintln(w, "| File \t PythonVM (ms) \t GoVM (ms) \t")
fmt.Fprintf(w, "| %-40s | %-20s | %-20s | %-20s |\n", "File", "PythonVM (ms)", "GoVM (ms)", "RustVM (ms)")
w.Flush()
sb.WriteString(border + "\n")

iterator := 0
totalFiles := len(benchmarkMap)

for key, values := range benchmarkMap {
row := "| " + key + "\t "

for iter, value := range values {
row = row + strconv.Itoa(value) + "\t"
if iter == 0 {
row = row + " "
}
// Adjust the key length if it's too long
displayKey := key
if len(displayKey) > 40 {
displayKey = displayKey[:37] + "..."
}

fmt.Fprintln(w, row)
fmt.Fprintf(w, "| %-40s | %-20d | %-20d | %-20d |\n", displayKey, values[0], values[1], values[2])
w.Flush()

if iterator < totalFiles-1 {
Expand Down Expand Up @@ -236,6 +264,8 @@ const (
compiledSuffix = "_compiled.json"
pyTraceSuffix = "_py_trace"
pyMemorySuffix = "_py_memory"
rsTraceSuffix = "_rs_trace"
rsMemorySuffix = "_rs_memory"
traceSuffix = "_trace"
memorySuffix = "_memory"
)
Expand Down Expand Up @@ -323,6 +353,61 @@ func runPythonVm(testFilename, path string) (time.Duration, string, string, erro
return elapsed, traceOutput, memoryOutput, nil
}

// given a path to a compiled cairo zero file, execute it using the
// rust vm and return the trace and memory files location
func runRustVm(testFilename, path string) (time.Duration, string, string, error) {
traceOutput := swapExtenstion(path, rsTraceSuffix)
memoryOutput := swapExtenstion(path, rsMemorySuffix)

args := []string{
path,
"--proof_mode",
"--trace_file",
traceOutput,
"--memory_file",
memoryOutput,
}

// If any other layouts are needed, add the suffix checks here.
// The convention would be: ".$layout.cairo"
// A file without this suffix will use the default ("plain") layout.
if strings.HasSuffix(testFilename, ".small.cairo") {
args = append(args, "--layout", "small")
} else if strings.HasSuffix(testFilename, ".dex.cairo") {
args = append(args, "--layout", "dex")
} else if strings.HasSuffix(testFilename, ".recursive.cairo") {
args = append(args, "--layout", "recursive")
} else if strings.HasSuffix(testFilename, ".starknet_with_keccak.cairo") {
args = append(args, "--layout", "starknet_with_keccak")
} else if strings.HasSuffix(testFilename, ".starknet.cairo") {
args = append(args, "--layout", "starknet")
} else if strings.HasSuffix(testFilename, ".recursive_large_output.cairo") {
args = append(args, "--layout", "recursive_large_output")
} else if strings.HasSuffix(testFilename, ".recursive_with_poseidon.cairo") {
args = append(args, "--layout", "recursive_with_poseidon")
} else if strings.HasSuffix(testFilename, ".all_solidity.cairo") {
args = append(args, "--layout", "all_solidity")
} else if strings.HasSuffix(testFilename, ".all_cairo.cairo") {
args = append(args, "--layout", "all_cairo")
}

cmd := exec.Command("./../rust_vm_bin/cairo-vm-cli", args...)

start := time.Now()

res, err := cmd.CombinedOutput()

elapsed := time.Since(start)

if err != nil {
return 0, "", "", fmt.Errorf(
"./../rust_vm_bin/cairo-vm-cli %s: %w\n%s", path, err, string(res),
)
}

return elapsed, traceOutput, memoryOutput, nil
}

// given a path to a compiled cairo zero file, execute
// it using our vm
func runVm(path string) (time.Duration, string, string, string, error) {
Expand Down
6 changes: 5 additions & 1 deletion pkg/hintrunner/zero/zerohint.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ func (hint *GenericZeroHinter) Execute(vm *VM.VirtualMachine, ctx *hinter.HintRu
}

func GetZeroHints(cairoZeroJson *zero.ZeroProgram) (map[uint64][]hinter.Hinter, error) {
hints := make(map[uint64][]hinter.Hinter)
numHints := 0
for _, rawHints := range cairoZeroJson.Hints {
numHints += len(rawHints)
}
hints := make(map[uint64][]hinter.Hinter, numHints)
for counter, rawHints := range cairoZeroJson.Hints {
pc, err := strconv.ParseUint(counter, 10, 64)
if err != nil {
Expand Down
30 changes: 0 additions & 30 deletions pkg/parsers/cairo_version/cairo_version.go

This file was deleted.

3 changes: 2 additions & 1 deletion pkg/parsers/zero/zero.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type ZeroProgram struct {
Data []string `json:"data"`
Builtins []builtins.BuiltinType `json:"builtins"`
Hints map[string][]Hint `json:"hints"`
CompilerVersion string `json:"version"`
CompilerVersion string `json:"compiler_version"`
MainScope string `json:"main_scope"`
Identifiers map[string]*Identifier `json:"identifiers"`
ReferenceManager ReferenceManager `json:"reference_manager"`
Expand All @@ -95,6 +95,7 @@ type Identifier struct {
Value any `json:"value"`
}

// TODO: Do we really need this ?
func (z ZeroProgram) MarshalToFile(filepath string) error {
// Marshal Output struct into JSON bytes
data, err := json.MarshalIndent(z, "", " ")
Expand Down
61 changes: 13 additions & 48 deletions pkg/runner/program.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,7 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error)
bytecode[i] = felt
}

entrypoints, err := extractEntrypoints(cairoZeroJson)
if err != nil {
return nil, err
}

labels, err := extractLabels(cairoZeroJson)
if err != nil {
return nil, err
}
entrypoints, labels := extractEntrypointsAndLabels(cairoZeroJson)

return &ZeroProgram{
Bytecode: bytecode,
Expand All @@ -53,49 +45,22 @@ func LoadCairoZeroProgram(cairoZeroJson *zero.ZeroProgram) (*ZeroProgram, error)
}, nil
}

func extractEntrypoints(json *zero.ZeroProgram) (map[string]uint64, error) {
result := make(map[string]uint64)
err := scanIdentifiers(
json,
func(key string, ident *zero.Identifier) error {
if ident.IdentifierType == "function" {
name := key[len(json.MainScope)+1:]
result[name] = uint64(ident.Pc)
}
return nil
},
)

if err != nil {
return nil, fmt.Errorf("extracting entrypoints: %w", err)
func extractEntrypointsAndLabels(json *zero.ZeroProgram) (map[string]uint64, map[string]uint64) {
entrypoints := map[string]uint64{}
for key, ident := range json.Identifiers {
if ident.IdentifierType == "function" {
name := key[len(json.MainScope)+1:]
entrypoints[name] = uint64(ident.Pc)
}
}
return result, nil
}

func extractLabels(json *zero.ZeroProgram) (map[string]uint64, error) {
labels := make(map[string]uint64, 2)
err := scanIdentifiers(
json,
func(key string, ident *zero.Identifier) error {
if ident.IdentifierType == "label" {
name := key[len(json.MainScope)+1:]
labels[name] = uint64(ident.Pc)
}
return nil
},
)
if err != nil {
return nil, fmt.Errorf("extracting labels: %w", err)
}

return labels, nil
}

func scanIdentifiers(json *zero.ZeroProgram, f func(key string, ident *zero.Identifier) error) error {
for key, ident := range json.Identifiers {
if err := f(key, ident); err != nil {
return err
if ident.IdentifierType == "label" {
name := key[len(json.MainScope)+1:]
labels[name] = uint64(ident.Pc)
}
}
return nil

return entrypoints, labels
}
Loading

0 comments on commit 76b422d

Please sign in to comment.