Skip to content

Commit

Permalink
fix: snapshot file will be overwritten and mixed if multiple consumer…
Browse files Browse the repository at this point in the history
…s or producers use different name server domains (#1099)
  • Loading branch information
tuweizhong authored Sep 25, 2023
1 parent 4de354a commit 5eab91b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
14 changes: 9 additions & 5 deletions primitive/nsresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package primitive

import (
"crypto/md5"
"encoding/hex"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -139,7 +141,7 @@ func (h *HttpResolver) Resolve() []string {
}

func (h *HttpResolver) Description() string {
return fmt.Sprintf("passthrough resolver of domain:%v instance:%v", h.domain, h.instance)
return fmt.Sprintf("http resolver of domain:%v", h.domain)
}

func (h *HttpResolver) get() []string {
Expand Down Expand Up @@ -177,7 +179,7 @@ func (h *HttpResolver) get() []string {
}

func (h *HttpResolver) saveSnapshot(body []byte) error {
filePath := h.getSnapshotFilePath(h.instance)
filePath := h.getSnapshotFilePath()
err := ioutil.WriteFile(filePath, body, 0644)
if err != nil {
rlog.Error("name server snapshot save failed", map[string]interface{}{
Expand All @@ -194,7 +196,7 @@ func (h *HttpResolver) saveSnapshot(body []byte) error {
}

func (h *HttpResolver) loadSnapshot() []string {
filePath := h.getSnapshotFilePath(h.instance)
filePath := h.getSnapshotFilePath()
_, err := os.Stat(filePath)
if os.IsNotExist(err) {
rlog.Warning("name server snapshot local file not exists", map[string]interface{}{
Expand All @@ -214,7 +216,7 @@ func (h *HttpResolver) loadSnapshot() []string {
return strings.Split(string(bs), ";")
}

func (h *HttpResolver) getSnapshotFilePath(instanceName string) string {
func (h *HttpResolver) getSnapshotFilePath() string {
homeDir := ""
if usr, err := user.Current(); err == nil {
homeDir = usr.HomeDir
Expand All @@ -232,6 +234,8 @@ func (h *HttpResolver) getSnapshotFilePath(instanceName string) string {
})
}
}
filePath := path.Join(storePath, fmt.Sprintf("nameserver_addr-%s", instanceName))
hash := md5.Sum([]byte(h.domain))
domainHash := hex.EncodeToString(hash[:])
filePath := path.Join(storePath, fmt.Sprintf("nameserver_addr-%s", domainHash))
return filePath
}
14 changes: 7 additions & 7 deletions primitive/nsresolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func TestHttpResolverWithGet(t *testing.T) {
resolver.Resolve()

// check snapshot saved
filePath := resolver.getSnapshotFilePath("DEFAULT")
filePath := resolver.getSnapshotFilePath()
body := strings.Join(srvs, ";")
bs, _ := ioutil.ReadFile(filePath)
So(string(bs), ShouldEqual, body)
Expand Down Expand Up @@ -112,7 +112,7 @@ func TestHttpResolverWithGetUnitName(t *testing.T) {
resolver.Resolve()

// check snapshot saved
filePath := resolver.getSnapshotFilePath("DEFAULT")
filePath := resolver.getSnapshotFilePath()
body := strings.Join(srvs, ";")
bs, _ := ioutil.ReadFile(filePath)
So(string(bs), ShouldEqual, body)
Expand All @@ -133,7 +133,7 @@ func TestHttpResolverWithSnapshotFile(t *testing.T) {

os.Setenv("NAMESRV_ADDR", "") // clear env
// setup local snapshot file
filePath := resolver.getSnapshotFilePath("DEFAULT")
filePath := resolver.getSnapshotFilePath()
body := strings.Join(srvs, ";")
_ = ioutil.WriteFile(filePath, []byte(body), 0644)

Expand All @@ -143,7 +143,7 @@ func TestHttpResolverWithSnapshotFile(t *testing.T) {
})
}

func TesHttpReslverWithSnapshotFileOnce(t *testing.T) {
func TestHttpResolverWithSnapshotFileOnce(t *testing.T) {
Convey("Test UpdateNameServerAddress Load Local Snapshot Once", t, func() {
srvs := []string{
"192.168.100.1",
Expand All @@ -157,18 +157,18 @@ func TesHttpReslverWithSnapshotFileOnce(t *testing.T) {

os.Setenv("NAMESRV_ADDR", "") // clear env
// setup local snapshot file
filePath := resolver.getSnapshotFilePath("DEFAULT")
filePath := resolver.getSnapshotFilePath()
body := strings.Join(srvs, ";")
_ = ioutil.WriteFile(filePath, []byte(body), 0644)
// load local snapshot file first time
addrs1 := resolver.Resolve()

// change the local snapshot file to check load once
// change the local snapshot file
_ = ioutil.WriteFile(filePath, []byte("127.0.0.1;127.0.0.2"), 0644)

addrs2 := resolver.Resolve()

So(Diff(addrs1, addrs2), ShouldBeFalse)
So(Diff(addrs1, addrs2), ShouldBeTrue)
So(Diff(addrs1, srvs), ShouldBeFalse)
})
}

0 comments on commit 5eab91b

Please sign in to comment.