Skip to content

Commit

Permalink
Add history support to OpenAI spam checker and update tests
Browse files Browse the repository at this point in the history
- Modify `check` function in `openAIChecker` to include message history.
- Introduce `OpenAIHistorySize` and `HistorySize` configurations.
- Add tests for message formatting with history and spam checks using history.
- Ensure thread-safe implementation for `LastRequests`.
- Update mock client to reset call history.
- Extend main application configuration for history size management.
  • Loading branch information
umputun committed Jan 20, 2025
1 parent 8d6ea2f commit 5e34ea0
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 14 deletions.
6 changes: 5 additions & 1 deletion app/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ type options struct {
MaxTokensRequest int `long:"max-tokens-request" env:"MAX_TOKENS_REQUEST" default:"2048" description:"openai max tokens in request"`
MaxSymbolsRequest int `long:"max-symbols-request" env:"MAX_SYMBOLS_REQUEST" default:"16000" description:"openai max symbols in request, failback if tokenizer failed"`
RetryCount int `long:"retry-count" env:"RETRY_COUNT" default:"1" description:"openai retry count"`
HistorySize int `long:"history-size" env:"HISTORY_SIZE" default:"0" description:"openai history size"`
} `group:"openai" namespace:"openai" env-namespace:"OPENAI"`

AbnormalSpacing struct {
Expand Down Expand Up @@ -128,7 +129,8 @@ type options struct {
Training bool `long:"training" env:"TRAINING" description:"training mode, passive spam detection only"`
SoftBan bool `long:"soft-ban" env:"SOFT_BAN" description:"soft ban mode, restrict user actions but not ban"`

Convert string `long:"convert" choice:"only" choice:"enabled" choice:"disabled" default:"enabled" description:"convert mode for txt samples and other storage files to DB"`
HistorySize int `long:"history-size" env:"LAST_MSGS_HISTORY_SIZE" default:"100" description:"history size"`
Convert string `long:"convert" choice:"only" choice:"enabled" choice:"disabled" default:"enabled" description:"convert mode for txt samples and other storage files to DB"`

Dry bool `long:"dry" env:"DRY" description:"dry mode, no bans"`
Dbg bool `long:"dbg" env:"DEBUG" description:"debug mode"`
Expand Down Expand Up @@ -441,7 +443,9 @@ func makeDetector(opts options) *tgspam.Detector {
FirstMessageOnly: !opts.ParanoidMode,
FirstMessagesCount: opts.FirstMessagesCount,
OpenAIVeto: opts.OpenAI.Veto,
OpenAIHistorySize: opts.OpenAI.HistorySize, // how many last requests sent to openai
MultiLangWords: opts.MultiLangWords,
HistorySize: opts.HistorySize, // how many last requests are stored in memory
}

// FirstMessagesCount and ParanoidMode are mutually exclusive.
Expand Down
2 changes: 1 addition & 1 deletion lib/spamcheck/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"sync"
)

// LastRequests keeps track of last N requests globally
// LastRequests keeps track of last N requests, thread-safe.
type LastRequests struct {
requests *ring.Ring
size int
Expand Down
8 changes: 7 additions & 1 deletion lib/tgspam/detector.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ type Config struct {
HTTPClient HTTPClient // http client to use for requests
MinSpamProbability float64 // minimum spam probability to consider a message spam with classifier, if 0 - ignored
OpenAIVeto bool // if true, openai will be used to veto spam messages, otherwise it will be used to veto ham messages
OpenAIHistorySize int // history size for openai
MultiLangWords int // if true, check for number of multi-lingual words
StorageTimeout time.Duration // timeout for storage operations, if not set - no timeout

Expand Down Expand Up @@ -206,7 +207,12 @@ func (d *Detector) Check(req spamcheck.Request) (spam bool, cr []spamcheck.Respo
// FirstMessageOnly or FirstMessagesCount has to be set to use openai, because it's slow and expensive to run on all messages
if d.openaiChecker != nil && (d.FirstMessageOnly || d.FirstMessagesCount > 0) {
if !spamDetected && !d.OpenAIVeto || spamDetected && d.OpenAIVeto {
spam, details := d.openaiChecker.check(cleanMsg)
var hist []spamcheck.Request // by default, openai doesn't use history
if d.OpenAIHistorySize > 0 && d.HistorySize > 0 {
// if history size is set, we use the last N messages for openai
hist = d.hamHistory.Last(d.OpenAIHistorySize)
}
spam, details := d.openaiChecker.check(cleanMsg, hist)
cr = append(cr, details)
if spamDetected && details.Error != nil {
// spam detected with other checks, but openai failed. in this case, we still return spam, but log the error
Expand Down
14 changes: 14 additions & 0 deletions lib/tgspam/mocks/openai_client.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 15 additions & 6 deletions lib/tgspam/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"github.com/umputun/tg-spam/lib/spamcheck"
)

//go:generate moq --out mocks/openai_client.go --pkg mocks --skip-ensure . openAIClient:OpenAIClientMock
//go:generate moq --out mocks/openai_client.go --pkg mocks --with-resets --skip-ensure . openAIClient:OpenAIClientMock

// openAIChecker is a wrapper for OpenAI API to check if a text is spam
type openAIChecker struct {
Expand Down Expand Up @@ -67,12 +67,21 @@ func newOpenAIChecker(client openAIClient, params OpenAIConfig) *openAIChecker {
return &openAIChecker{client: client, params: params}
}

// check checks if a text is spam
func (o *openAIChecker) check(msg string) (spam bool, cr spamcheck.Response) {
// check checks if a text is spam using OpenAI API
func (o *openAIChecker) check(msg string, history []spamcheck.Request) (spam bool, cr spamcheck.Response) {
if o.client == nil {
return false, spamcheck.Response{}
}

// update the message with the history
if len(history) > 0 {
hist := []string{"--------------------------------", "this is the history of previous messages from other users:", ""}
for _, h := range history {
hist = append(hist, fmt.Sprintf("%q: %q", h.UserName, h.Msg))
}
msg = msg + "\n" + strings.Join(hist, "\n")
}

// try to send a request several times if it fails
var resp openAIResponse
var err error
Expand All @@ -91,9 +100,9 @@ func (o *openAIChecker) check(msg string) (spam bool, cr spamcheck.Response) {
}

func (o *openAIChecker) sendRequest(msg string) (response openAIResponse, err error) {
// Reduce the request size with tokenizer and fallback to default reducer if it fails
// The API supports 4097 tokens ~16000 characters (<=4 per token) for request + result together
// The response is limited to 1000 tokens and OpenAI always reserved it for the result
// Reduce the request size with tokenizer and fallback to default reducer if it fails.
// The API supports 4097 tokens ~16000 characters (<=4 per token) for request + result together.
// The response is limited to 1000 tokens, and OpenAI always reserves it for the result.
// So the max length of the request should be 3000 tokens or ~12000 characters
reduceRequest := func(text string) (result string) {
// defaultReducer is a fallback if tokenizer fails
Expand Down
110 changes: 105 additions & 5 deletions lib/tgspam/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/sashabaranov/go-openai"
"github.com/stretchr/testify/assert"

"github.com/umputun/tg-spam/lib/spamcheck"
"github.com/umputun/tg-spam/lib/tgspam/mocks"
)

Expand Down Expand Up @@ -39,7 +40,7 @@ func TestOpenAIChecker_Check(t *testing.T) {
}},
}, nil
}
spam, details := checker.check("some text")
spam, details := checker.check("some text", nil)
t.Logf("spam: %v, details: %+v", spam, details)
assert.True(t, spam)
assert.Equal(t, "openai", details.Name)
Expand All @@ -56,7 +57,7 @@ func TestOpenAIChecker_Check(t *testing.T) {
}},
}, nil
}
spam, details := checker.check("some text")
spam, details := checker.check("some text", nil)
t.Logf("spam: %v, details: %+v", spam, details)
assert.False(t, spam)
assert.Equal(t, "openai", details.Name)
Expand All @@ -69,7 +70,7 @@ func TestOpenAIChecker_Check(t *testing.T) {
contextMoqParam context.Context, chatCompletionRequest openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
return openai.ChatCompletionResponse{}, assert.AnError
}
spam, details := checker.check("some text")
spam, details := checker.check("some text", nil)
t.Logf("spam: %v, details: %+v", spam, details)
assert.False(t, spam)
assert.Equal(t, "openai", details.Name)
Expand All @@ -86,7 +87,7 @@ func TestOpenAIChecker_Check(t *testing.T) {
}},
}, nil
}
spam, details := checker.check("some text")
spam, details := checker.check("some text", nil)
t.Logf("spam: %v, details: %+v", spam, details)
assert.False(t, spam)
assert.Equal(t, "openai", details.Name)
Expand All @@ -101,10 +102,109 @@ func TestOpenAIChecker_Check(t *testing.T) {
contextMoqParam context.Context, chatCompletionRequest openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
return openai.ChatCompletionResponse{}, nil
}
spam, details := checker.check("some text")
spam, details := checker.check("some text", nil)
t.Logf("spam: %v, details: %+v", spam, details)
assert.False(t, spam)
assert.Equal(t, "openai", details.Name)
assert.Equal(t, "OpenAI error: no choices in response", details.Details)
})
}

func TestOpenAIChecker_CheckWithHistory(t *testing.T) {
clientMock := &mocks.OpenAIClientMock{
CreateChatCompletionFunc: func(contextMoqParam context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
// verify the request contains history
assert.Contains(t, req.Messages[1].Content, "history of previous messages")
assert.Contains(t, req.Messages[1].Content, `"user1": "first message"`)
assert.Contains(t, req.Messages[1].Content, `"user2": "second message"`)
assert.Contains(t, req.Messages[1].Content, `"user1": "third message"`)

return openai.ChatCompletionResponse{
Choices: []openai.ChatCompletionChoice{{
Message: openai.ChatCompletionMessage{Content: `{"spam": true, "reason":"suspicious pattern in history", "confidence":90}`},
}},
}, nil
},
}

checker := newOpenAIChecker(clientMock, OpenAIConfig{Model: "gpt-4o-mini"})
history := []spamcheck.Request{
{Msg: "first message", UserName: "user1"},
{Msg: "second message", UserName: "user2"},
{Msg: "third message", UserName: "user1"},
}

spam, details := checker.check("current message", history)
t.Logf("spam: %v, details: %+v", spam, details)
assert.True(t, spam)
assert.Equal(t, "openai", details.Name)
assert.Equal(t, "suspicious pattern in history, confidence: 90%", details.Details)
assert.NoError(t, details.Error)
assert.Equal(t, 1, len(clientMock.CreateChatCompletionCalls()))
}

func TestOpenAIChecker_FormatMessage(t *testing.T) {
	// capturedMsg holds the user prompt sent to the API on the latest call
	var capturedMsg string
	mockClient := &mocks.OpenAIClientMock{
		CreateChatCompletionFunc: func(contextMoqParam context.Context, req openai.ChatCompletionRequest) (openai.ChatCompletionResponse, error) {
			capturedMsg = req.Messages[1].Content // remember the message being sent
			return openai.ChatCompletionResponse{
				Choices: []openai.ChatCompletionChoice{{
					Message: openai.ChatCompletionMessage{Content: `{"spam": false, "reason":"test message", "confidence":90}`},
				}},
			}, nil
		},
	}

	cases := []struct {
		name            string
		currentMsg      string
		history         []spamcheck.Request
		expectedMessage string
	}{
		{
			name:            "message with no history",
			currentMsg:      "hello world",
			history:         []spamcheck.Request{},
			expectedMessage: "hello world",
		},
		{
			name:       "message with history",
			currentMsg: "current message",
			history: []spamcheck.Request{
				{Msg: "first message", UserName: "user1"},
				{Msg: "second message", UserName: "user2"},
			},
			expectedMessage: `current message
--------------------------------
this is the history of previous messages from other users:
"user1": "first message"
"user2": "second message"`,
		},
		{
			name:       "message with empty username in history",
			currentMsg: "current message",
			history: []spamcheck.Request{
				{Msg: "first message", UserName: ""},
				{Msg: "second message", UserName: "user2"},
			},
			expectedMessage: `current message
--------------------------------
this is the history of previous messages from other users:
"": "first message"
"user2": "second message"`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			mockClient.ResetCalls() // start each case with a clean call log
			checker := newOpenAIChecker(mockClient, OpenAIConfig{Model: "gpt-4o-mini"})
			checker.check(tc.currentMsg, tc.history)
			assert.Equal(t, tc.expectedMessage, capturedMsg, "message formatting mismatch")
			assert.Equal(t, 1, len(mockClient.CreateChatCompletionCalls()))
		})
	}
}

0 comments on commit 5e34ea0

Please sign in to comment.