diff --git a/aws/sqs/sqs.go b/aws/sqs/sqs.go index 8b9778a..fb738a7 100644 --- a/aws/sqs/sqs.go +++ b/aws/sqs/sqs.go @@ -18,9 +18,122 @@ import ( ) type Raw struct { - Body string - ReceiptHandle string - Attributes map[string]string + Body string + ReceiptHandle string + Attributes map[string]string + MessageAttributes map[string]string +} + +type messageInput struct { + // generic + ctx context.Context + + // send + msgAttributes map[string]string + delay *int32 + group string + dedupe string + + // receive + visibilityTimeout *int32 + waitTime *int32 + queueAttributes []types.QueueAttributeName +} + +type MessageResponse struct { + MessageId string +} + +type messageInputOptionFunc func(*messageInput) + +func (s *SQS) SendMessage(queue string, body string, options ...messageInputOptionFunc) (MessageResponse, error) { + msg := &messageInput{} + for _, option := range options { + option(msg) + } + + awsMsg := sqs.SendMessageInput{ + QueueUrl: aws.String(queue), + MessageBody: aws.String(body), + } + + // delayed message + if msg.delay != nil { + awsMsg.DelaySeconds = *msg.delay + } + + // message with attributes + if len(msg.msgAttributes) > 0 { + awsMsg.MessageAttributes = map[string]types.MessageAttributeValue{} + for k, v := range msg.msgAttributes { + awsMsg.MessageAttributes[k] = types.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: aws.String(v), + } + } + } + + // message for FIFO queue + if msg.group != "" { + awsMsg.MessageGroupId = aws.String(msg.group) + awsMsg.MessageDeduplicationId = aws.String(msg.dedupe) + } + + // context + ctx := msg.ctx + if ctx == nil { + ctx = context.TODO() + } + + response, err := s.client.SendMessage(ctx, &awsMsg) + + return MessageResponse{MessageId: *response.MessageId}, err +} + +func WithDelay(delay int32) messageInputOptionFunc { + return func(m *messageInput) { + m.delay = new(int32) + *m.delay = delay + } +} + +func WithWait(waitTimeSeconds int32) messageInputOptionFunc { + return func(m *messageInput) { + m.waitTime = new(int32) + *m.waitTime = waitTimeSeconds + } +} + +func WithTimeout(visibilityTimeout int32) messageInputOptionFunc { + return func(m *messageInput) { + m.visibilityTimeout = new(int32) + *m.visibilityTimeout = visibilityTimeout + } +} + +func WithMessageAttributes(attributes map[string]string) messageInputOptionFunc { + return func(m *messageInput) { + m.msgAttributes = attributes + } +} + +func WithContext(ctx context.Context) messageInputOptionFunc { + return func(m *messageInput) { + m.ctx = ctx + } +} + +func WithQueueAttributes(attributes []types.QueueAttributeName) messageInputOptionFunc { + return func(m *messageInput) { + m.queueAttributes = attributes + } +} + +func WithFifo(group, dedupe string) messageInputOptionFunc { + return func(m *messageInput) { + m.group = group + m.dedupe = dedupe + } } type SQS struct { @@ -97,13 +210,13 @@ func (s *SQS) Ready() bool { // Applications should be able to handle duplicate or out of order messages, // and should back off on Receive error. func (s *SQS) Receive(queueURL string, visibilityTimeout int32) (Raw, error) { - return s.ReceiveWithContext(context.TODO(), queueURL, visibilityTimeout) + return s.ReceiveMessage(queueURL, WithTimeout(visibilityTimeout)) } // ReceiveWithAttributes is the same as Receive except that Queue Attributes can be requested // to be received with the message. func (s *SQS) ReceiveWithAttributes(queueURL string, visibilityTimeout int32, attrs []types.QueueAttributeName) (Raw, error) { - return s.ReceiveWithContextAttributes(context.TODO(), queueURL, visibilityTimeout, attrs) + return s.ReceiveMessage(queueURL, WithTimeout(visibilityTimeout), WithQueueAttributes(attrs)) } // ReceiveWithContextAttributes by context and Queue Attributes, @@ -112,26 +225,43 @@ func (s *SQS) ReceiveWithAttributes(queueURL string, visibilityTimeout int32, at // when system stop signal is received, an error with message '... context canceled' will be returned // which can be used to safely stop the system func (s *SQS) ReceiveWithContextAttributes(ctx context.Context, queueURL string, visibilityTimeout int32, attrs []types.QueueAttributeName) (Raw, error) { - input := sqs.ReceiveMessageInput{ - QueueUrl: aws.String(queueURL), - MaxNumberOfMessages: 1, - VisibilityTimeout: visibilityTimeout, - WaitTimeSeconds: 20, - AttributeNames: attrs, - } - return s.receiveMessage(ctx, &input) + return s.ReceiveMessage(queueURL, WithContext(ctx), WithTimeout(visibilityTimeout), WithWait(20), WithQueueAttributes(attrs)) } -// receiveMessage is the common code used internally to receive an SQS message based -// on the provided input. -func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput) (Raw, error) { +func (s *SQS) ReceiveMessage(queue string, options ...messageInputOptionFunc) (Raw, error) { + msg := &messageInput{} + for _, option := range options { + option(msg) + } - for { - r, err := s.client.ReceiveMessage(ctx, input) - if err != nil { - return Raw{}, err - } + awsMsg := sqs.ReceiveMessageInput{ + QueueUrl: aws.String(queue), + MessageAttributeNames: []string{"All"}, + AttributeNames: msg.queueAttributes, + } + + // visibility + if msg.visibilityTimeout != nil { + awsMsg.VisibilityTimeout = *msg.visibilityTimeout + } + + // wait time + if msg.waitTime != nil { + awsMsg.WaitTimeSeconds = *msg.waitTime + } + + // context + ctx := msg.ctx + if ctx == nil { + ctx = context.TODO() + } + + r, err := s.client.ReceiveMessage(ctx, &awsMsg) + if err != nil { + return Raw{}, err + } + for { switch { case r == nil || len(r.Messages) == 0: // no message received @@ -139,10 +269,20 @@ func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput case len(r.Messages) == 1: raw := r.Messages[0] + msgAttributes := map[string]string{} + for k, v := range raw.MessageAttributes { + if aws.ToString(v.DataType) == "Binary" { + continue + } + + msgAttributes[k] = aws.ToString(v.StringValue) + } + m := Raw{ - Body: aws.ToString(raw.Body), - ReceiptHandle: aws.ToString(raw.ReceiptHandle), - Attributes: raw.Attributes, + Body: aws.ToString(raw.Body), + ReceiptHandle: aws.ToString(raw.ReceiptHandle), + Attributes: raw.Attributes, + MessageAttributes: msgAttributes, } return m, nil case len(r.Messages) > 1: @@ -156,13 +296,7 @@ func (s *SQS) receiveMessage(ctx context.Context, input *sqs.ReceiveMessageInput // when system stop signal is received, an error with message '... context canceled' will be returned // which can be used to safely stop the system func (s *SQS) ReceiveWithContext(ctx context.Context, queueURL string, visibilityTimeout int32) (Raw, error) { - input := sqs.ReceiveMessageInput{ - QueueUrl: aws.String(queueURL), - MaxNumberOfMessages: 1, - VisibilityTimeout: visibilityTimeout, - WaitTimeSeconds: 20, - } - return s.receiveMessage(ctx, &input) + return s.ReceiveMessage(queueURL, WithContext(ctx), WithTimeout(visibilityTimeout)) } // Delete deletes the message referred to by receiptHandle from the queue. @@ -179,49 +313,20 @@ func (s *SQS) Delete(queueURL, receiptHandle string) error { // Send sends the message body to the SQS queue referred to by queueURL. func (s *SQS) Send(queueURL string, body string) error { - params := sqs.SendMessageInput{ - QueueUrl: aws.String(queueURL), - MessageBody: aws.String(body), - } - - _, err := s.client.SendMessage(context.TODO(), ¶ms) - + _, err := s.SendMessage(queueURL, body) return err } // SendWithDelay is the same as Send but adds a delay (in seconds) before sending. func (s *SQS) SendWithDelay(queueURL string, body string, delay int32) error { - params := sqs.SendMessageInput{ - QueueUrl: aws.String(queueURL), - MessageBody: aws.String(body), - DelaySeconds: delay, - } - - _, err := s.client.SendMessage(context.TODO(), ¶ms) - + _, err := s.SendMessage(queueURL, body, WithDelay(delay)) return err } // SendFifoMessage puts a message onto the given AWS SQS queue. func (s *SQS) SendFifoMessage(queue, group, dedupe string, msg []byte) (string, error) { - var id *string - if dedupe != "" { - id = aws.String(dedupe) - } - params := sqs.SendMessageInput{ - MessageBody: aws.String(string(msg)), - QueueUrl: aws.String(queue), - MessageGroupId: aws.String(group), - MessageDeduplicationId: id, - } - output, err := s.client.SendMessage(context.TODO(), ¶ms) - if err != nil { - return "", err - } - if id := output.MessageId; id != nil { - return *id, nil - } - return "", nil + r, err := s.SendMessage(queue, string(msg), WithFifo(group, dedupe)) + return r.MessageId, err } // Leverage the sendbatch api for uploading large numbers of messages diff --git a/aws/sqs/sqs_integration_test.go b/aws/sqs/sqs_integration_test.go index 6842b09..dd56ce7 100644 --- a/aws/sqs/sqs_integration_test.go +++ b/aws/sqs/sqs_integration_test.go @@ -22,6 +22,9 @@ const ( awsRegion = "ap-southeast-2" testQueue = "test-queue" testMessage = "test message" + + testMessageAttributeKey = "test-key" + testMessageAttributeValue = "test-value" ) // helper functions @@ -79,6 +82,9 @@ func awsCmdSendMessage() { "send-message", "--queue-url", awsCmdQueueURL(), "--message-body", testMessage, + "--message-attributes", fmt.Sprintf( + "%s={DataType=String, StringValue=\"%s\"}", + testMessageAttributeKey, testMessageAttributeValue), "--region", awsRegion, "--endpoint-url", cutomAWSEndpointURL).Run(); err != nil { @@ -103,6 +109,29 @@ func awsCmdReceiveMessage() string { } } +func awsCmdReceiveMessageWithAttributes() (string, map[string]string) { + if out, err := exec.Command( + "aws", "sqs", + "receive-message", + "--queue-url", awsCmdQueueURL(), + "--message-attribute-names", "All", + "--region", awsRegion, + "--endpoint-url", cutomAWSEndpointURL).CombinedOutput(); err != nil { + + panic(err) + } else { + var payload map[string][]map[string]interface{} + json.Unmarshal(out, &payload) + + msgAttributes := map[string]string{} + for k, v := range payload["Messages"][0]["MessageAttributes"].(map[string]interface{}) { + msgAttributes[k] = v.(map[string]interface{})["StringValue"].(string) + } + + return payload["Messages"][0]["Body"].(string), msgAttributes + } +} + func awsCmdQueueCount() int { if out, err := exec.Command( "aws", "sqs", @@ -188,6 +217,7 @@ func TestSQSReceive(t *testing.T) { // ASSERT assert.Nil(t, err) assert.Equal(t, testMessage, receivedMessage.Body) + assert.Equal(t, testMessageAttributeValue, receivedMessage.MessageAttributes[testMessageAttributeKey]) } func TestSQSReceiveWithAttributes(t *testing.T) { @@ -247,6 +277,44 @@ func TestSQSSend(t *testing.T) { assert.Equal(t, testMessage, awsCmdReceiveMessage()) } +func TestSQSSendMessage(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("Error creating sqs client: %v", err)) + + // ACTION + _, err = client.SendMessage(awsCmdQueueURL(), testMessage) + + // ASSERT + assert.Nil(t, err) + assert.Equal(t, testMessage, awsCmdReceiveMessage()) +} + +func TestSQSSendMessageWithAttributes(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("Error creating sqs client: %v", err)) + + // ACTION + _, err = client.SendMessage( + awsCmdQueueURL(), + testMessage, + WithMessageAttributes(map[string]string{testMessageAttributeKey: testMessageAttributeValue})) + + // ASSERT + assert.Nil(t, err) + + msg, attributes := awsCmdReceiveMessageWithAttributes() + assert.Equal(t, testMessage, msg) + assert.Equal(t, testMessageAttributeValue, attributes[testMessageAttributeKey]) +} + func TestSQSSendWithDelay(t *testing.T) { // ARRANGE setup() @@ -274,6 +342,33 @@ func TestSQSSendWithDelay(t *testing.T) { assert.True(t, timeElapsed < timeout) } +func TestSQSSendMessageWithDelay(t *testing.T) { + // ARRANGE + setup() + defer teardown() + + client, err := New() + require.Nil(t, err, fmt.Sprintf("Error creating sqs client: %v", err)) + + // ACTION + _, err = client.SendMessage(awsCmdQueueURL(), testMessage, WithDelay(5)) // delay seconds + assert.Nil(t, err) // fail fast here to avoid loop queue check + + start := time.Now() + timeout := 10 * time.Second + for awsCmdQueueCount() < 1 { + time.Sleep(500 * time.Millisecond) + if time.Since(start) > timeout { + break + } + } + timeElapsed := time.Since(start) + + // ASSERT + assert.True(t, timeElapsed > 5*time.Second) + assert.True(t, timeElapsed < timeout) +} + func TestGetQueueUrl(t *testing.T) { // ARRANGE setup()