diff --git a/dagutils/diff.go b/dagutils/diff.go index 5015238..9fef3f9 100644 --- a/dagutils/diff.go +++ b/dagutils/diff.go @@ -102,64 +102,62 @@ func ApplyChange(ctx context.Context, ds ipld.DAGService, nd *dag.ProtoNode, cs // 2. both of two nodes are ProtoNode. // Otherwise, it compares the cid and emits a Mod change object. func Diff(ctx context.Context, ds ipld.DAGService, a, b ipld.Node) ([]*Change, error) { - // Base case where both nodes are leaves, just compare - // their CIDs. - if len(a.Links()) == 0 && len(b.Links()) == 0 { - return getChange(a, b) + if a.Cid() == b.Cid() { + return []*Change{}, nil } - var out []*Change cleanA, okA := a.Copy().(*dag.ProtoNode) cleanB, okB := b.Copy().(*dag.ProtoNode) - if !okA || !okB { - return getChange(a, b) + + linksA := a.Links() + linksB := b.Links() + + if !okA || !okB || (len(linksA) == 0 && len(linksB) == 0) { + return []*Change{{Type: Mod, Before: a.Cid(), After: b.Cid()}}, nil } - // strip out unchanged stuff - for _, lnk := range a.Links() { - l, _, err := b.ResolveLink([]string{lnk.Name}) - if err == nil { - if l.Cid.Equals(lnk.Cid) { - // no change... ignore it - } else { - anode, err := lnk.GetNode(ctx, ds) - if err != nil { - return nil, err - } - - bnode, err := l.GetNode(ctx, ds) - if err != nil { - return nil, err - } - - sub, err := Diff(ctx, ds, anode, bnode) - if err != nil { - return nil, err - } - - for _, subc := range sub { - subc.Path = path.Join(lnk.Name, subc.Path) - out = append(out, subc) - } - } - _ = cleanA.RemoveNodeLink(l.Name) - _ = cleanB.RemoveNodeLink(l.Name) + var out []*Change + for _, linkA := range linksA { + linkB, _, err := b.ResolveLink([]string{linkA.Name}) + if err != nil { + continue + } + + cleanA.RemoveNodeLink(linkA.Name) + cleanB.RemoveNodeLink(linkA.Name) + + if linkA.Cid == linkB.Cid { + continue + } + + nodeA, err := linkA.GetNode(ctx, ds) + if err != nil { + return nil, err } + + nodeB, err := linkB.GetNode(ctx, ds) + if err != nil { + return nil, err + } + + sub, err := Diff(ctx, ds, nodeA, nodeB) + if err != nil { + return nil, err + } + + for _, c := range sub { + c.Path = path.Join(linkA.Name, c.Path) + } + + out = append(out, sub...) } - for _, lnk := range cleanA.Links() { - out = append(out, &Change{ - Type: Remove, - Path: lnk.Name, - Before: lnk.Cid, - }) + for _, l := range cleanA.Links() { + out = append(out, &Change{Type: Remove, Path: l.Name, Before: l.Cid}) } - for _, lnk := range cleanB.Links() { - out = append(out, &Change{ - Type: Add, - Path: lnk.Name, - After: lnk.Cid, - }) + + for _, l := range cleanB.Links() { + out = append(out, &Change{Type: Add, Path: l.Name, After: l.Cid}) } return out, nil @@ -177,38 +175,26 @@ type Conflict struct { // A slice of Conflicts is returned and contains pointers to the // Changes involved (which share the same path). func MergeDiffs(a, b []*Change) ([]*Change, []Conflict) { - var out []*Change - var conflicts []Conflict paths := make(map[string]*Change) for _, c := range a { paths[c.Path] = c } - for _, c := range b { - if ca, ok := paths[c.Path]; ok { - conflicts = append(conflicts, Conflict{ - A: ca, - B: c, - }) + var changes []*Change + var conflicts []Conflict + + for _, changeB := range b { + if changeA, ok := paths[changeB.Path]; ok { + conflicts = append(conflicts, Conflict{changeA, changeB}) } else { - out = append(out, c) + changes = append(changes, changeB) } + delete(paths, changeB.Path) } + for _, c := range paths { - out = append(out, c) + changes = append(changes, c) } - return out, conflicts -} -func getChange(a, b ipld.Node) ([]*Change, error) { - if a.Cid().Equals(b.Cid()) { - return []*Change{}, nil - } - return []*Change{ - { - Type: Mod, - Before: a.Cid(), - After: b.Cid(), - }, - }, nil + return changes, conflicts } diff --git a/dagutils/diff_test.go b/dagutils/diff_test.go new file mode 100644 index 0000000..9cafe13 --- /dev/null +++ b/dagutils/diff_test.go @@ -0,0 +1,127 @@ +package dagutils + +import ( + "context" + "testing" + + cid "github.com/ipfs/go-cid" + ipld "github.com/ipfs/go-ipld-format" + dag "github.com/ipfs/go-merkledag" + mdtest "github.com/ipfs/go-merkledag/test" +) + +func TestMergeDiffs(t *testing.T) { + node1 := dag.NodeWithData([]byte("one")) + node2 := dag.NodeWithData([]byte("two")) + node3 := dag.NodeWithData([]byte("three")) + node4 := dag.NodeWithData([]byte("four")) + + changesA := []*Change{ + {Add, "one", cid.Cid{}, node1.Cid()}, + {Remove, "two", node2.Cid(), cid.Cid{}}, + {Mod, "three", node3.Cid(), node4.Cid()}, + } + + changesB := []*Change{ + {Mod, "two", node2.Cid(), node3.Cid()}, + {Add, "four", cid.Cid{}, node4.Cid()}, + } + + changes, conflicts := MergeDiffs(changesA, changesB) + if len(changes) != 3 { + t.Fatal("unexpected merge changes") + } + + expect := []*Change{ + changesB[1], + changesA[0], + changesA[2], + } + + for i, change := range changes { + if change.Type != expect[i].Type { + t.Error("unexpected diff change type") + } + + if change.Path != expect[i].Path { + t.Error("unexpected diff change path") + } + + if change.Before != expect[i].Before { + t.Error("unexpected diff change before") + } + + if change.After != expect[i].After { + t.Error("unexpected diff change before") + } + } + + if len(conflicts) != 1 { + t.Fatal("unexpected merge conflicts") + } + + if conflicts[0].A != changesA[1] { + t.Error("unexpected merge conflict a") + } + + if conflicts[0].B != changesB[0] { + t.Error("unexpected merge conflict b") + } +} + +func TestDiff(t *testing.T) { + ctx := context.Background() + ds := mdtest.Mock() + + rootA := &dag.ProtoNode{} + rootB := &dag.ProtoNode{} + + child1 := dag.NodeWithData([]byte("one")) + child2 := dag.NodeWithData([]byte("two")) + child3 := dag.NodeWithData([]byte("three")) + child4 := dag.NodeWithData([]byte("four")) + + rootA.AddNodeLink("one", child1) + rootA.AddNodeLink("two", child2) + + rootB.AddNodeLink("one", child3) + rootB.AddNodeLink("four", child4) + + nodes := []ipld.Node{child1, child2, child3, child4, rootA, rootB} + if err := ds.AddMany(ctx, nodes); err != nil { + t.Fatal("failed to add nodes") + } + + changes, err := Diff(ctx, ds, rootA, rootB) + if err != nil { + t.Fatal("unexpected diff error") + } + + if len(changes) != 3 { + t.Fatal("unexpected diff changes") + } + + expect := []Change{ + {Mod, "one", child1.Cid(), child3.Cid()}, + {Remove, "two", child2.Cid(), cid.Cid{}}, + {Add, "four", cid.Cid{}, child4.Cid()}, + } + + for i, change := range changes { + if change.Type != expect[i].Type { + t.Error("unexpected diff change type") + } + + if change.Path != expect[i].Path { + t.Error("unexpected diff change path") + } + + if change.Before != expect[i].Before { + t.Error("unexpected diff change before") + } + + if change.After != expect[i].After { + t.Error("unexpected diff change before") + } + } +}