diff --git a/index/index.go b/index/index.go index a5f095ec..a318428d 100644 --- a/index/index.go +++ b/index/index.go @@ -64,9 +64,9 @@ type symbolFrequencyPair struct { type symbolFrequencylist []symbolFrequencyPair -func (s symbolFrequencylist) Len() int { return len(s) } -func (s symbolFrequencylist) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s symbolFrequencylist) Less(i, j int) bool { return s[i].frequency < s[j].frequency } +func (s symbolFrequencylist) Len() int { return len(s) } +func (s symbolFrequencylist) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s symbolFrequencylist) Greater(i, j int) bool { return s[i].frequency > s[j].frequency } type indexWriterStage uint8 @@ -355,7 +355,14 @@ func (w *Writer) AddSymbols(sym map[string]int) error { for k, v := range sym { symbols = append(symbols, symbolFrequencyPair{k, v}) } - sort.Sort(sort.Reverse(symbols)) + sort.Slice(symbols, func(i, j int) bool { + // We get the symbols back as a map so we need to be sure + // to sort by symbol if the frequencies are the same. + if symbols[i].frequency == symbols[j].frequency { + return symbols[i].symbol > symbols[j].symbol + } + return symbols.Greater(i, j) + }) const headerSize = 4 @@ -874,7 +881,7 @@ func (r *Reader) Symbols() (map[string]int, error) { res[s] = 0 } for _, s := range r.symbolSlice { - res[s] = struct{}{} + res[s] = 0 } return res, nil } diff --git a/index/index_test.go b/index/index_test.go index 7de7e348..7cdd3ab9 100644 --- a/index/index_test.go +++ b/index/index_test.go @@ -232,6 +232,48 @@ func TestIndexRW_Postings(t *testing.T) { testutil.Ok(t, ir.Close()) } +func TestIndexRW_SymbolsOrder(t *testing.T) { + dir, err := ioutil.TempDir("", "test_index_order") + testutil.Ok(t, err) + defer os.RemoveAll(dir) + + fn := filepath.Join(dir, "index") + + iw, err := NewWriter(fn) + testutil.Ok(t, err) + + err = iw.AddSymbols(map[string]int{ + "a": 1, + "b": 2, + "c": 1, + "2": 4, + "3": 5, + "4": 3, + }) + + testutil.Ok(t, err) + testutil.Ok(t, iw.Close()) + + exp := []string{"3", "2", "4", "b", "c", "a"} + + ir, err := NewFileReader(fn) + testutil.Ok(t, err) + + err = ir.readSymbols(int(ir.toc.symbols)) + testutil.Ok(t, err) + + s, err := ir.Symbols() + t.Logf("symbols: %+v", s) + + testutil.Equals(t, len(ir.symbolSlice), len(exp)) + + for i := range ir.symbolSlice { + testutil.Equals(t, ir.symbolSlice[i], exp[i]) + } + + testutil.Ok(t, ir.Close()) +} + func TestPersistence_index_e2e(t *testing.T) { dir, err := ioutil.TempDir("", "test_persistence_e2e") testutil.Ok(t, err)