Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework mqtt_consumer connect/reconnect #4846

Merged
merged 3 commits into from
Oct 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 54 additions & 69 deletions plugins/inputs/mqtt_consumer/mqtt_consumer.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
package mqtt_consumer

import (
"errors"
"fmt"
"log"
"strings"
"sync"
"time"

"github.com/influxdata/telegraf"
Expand All @@ -19,6 +19,14 @@ import (
// defaultConnectionTimeout is the ConnectionTimeout assigned to a new
// MQTTConsumer when the plugin is registered in init.
// 30 Seconds is the default used by paho.mqtt.golang
var defaultConnectionTimeout = internal.Duration{Duration: 30 * time.Second}

// ConnectionState records where the plugin's MQTT client currently is in its
// connect/disconnect lifecycle. It is stored in MQTTConsumer.state.
type ConnectionState int

const (
// Disconnected is the zero value. It is set when Start begins, when a
// connect attempt fails, when the connection is lost, and after Stop
// disconnects the client. Gather starts a new connect attempt while the
// plugin is in this state.
Disconnected ConnectionState = iota
// Connecting is set immediately before a connect attempt is made.
Connecting
// Connected is set once the client's Connect call completes without error.
Connected
)

type MQTTConsumer struct {
Servers []string
Topics []string
Expand All @@ -36,16 +44,10 @@ type MQTTConsumer struct {
ClientID string `toml:"client_id"`
tls.ClientConfig

sync.Mutex
client mqtt.Client
// channel of all incoming raw mqtt messages
in chan mqtt.Message
done chan struct{}

// keep the accumulator internally:
acc telegraf.Accumulator

connected bool
client mqtt.Client
acc telegraf.Accumulator
state ConnectionState
subscribed bool
}

var sampleConfig = `
Expand Down Expand Up @@ -110,22 +112,19 @@ func (m *MQTTConsumer) SetParser(parser parsers.Parser) {
}

func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
m.Lock()
defer m.Unlock()
m.connected = false
m.state = Disconnected

if m.PersistentSession && m.ClientID == "" {
return fmt.Errorf("ERROR MQTT Consumer: When using persistent_session" +
" = true, you MUST also set client_id")
return errors.New("persistent_session requires client_id")
}

m.acc = acc
if m.QoS > 2 || m.QoS < 0 {
return fmt.Errorf("MQTT Consumer, invalid QoS value: %d", m.QoS)
return fmt.Errorf("qos value must be 0, 1, or 2: %d", m.QoS)
}

if m.ConnectionTimeout.Duration < 1*time.Second {
return fmt.Errorf("MQTT Consumer, invalid connection_timeout value: %s", m.ConnectionTimeout.Duration)
return fmt.Errorf("connection_timeout must be greater than 1s: %s", m.ConnectionTimeout.Duration)
}

opts, err := m.createOpts()
Expand All @@ -134,9 +133,7 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
}

m.client = mqtt.NewClient(opts)
m.in = make(chan mqtt.Message, 1000)
m.done = make(chan struct{})

m.state = Connecting
m.connect()

return nil
Expand All @@ -145,80 +142,68 @@ func (m *MQTTConsumer) Start(acc telegraf.Accumulator) error {
func (m *MQTTConsumer) connect() error {
if token := m.client.Connect(); token.Wait() && token.Error() != nil {
err := token.Error()
log.Printf("D! MQTT Consumer, connection error - %v", err)

m.state = Disconnected
return err
}

go m.receiver()
log.Printf("I! [inputs.mqtt_consumer]: connected %v", m.Servers)
m.state = Connected

return nil
}

func (m *MQTTConsumer) onConnect(c mqtt.Client) {
log.Printf("I! MQTT Client Connected")
if !m.PersistentSession || !m.connected {
// Only subscribe on first connection when using persistent sessions. On
// subsequent connections the subscriptions should be stored in the
// session, but the proper way to do this is to check the connection
// response to ensure a session was found.
if !m.PersistentSession || !m.subscribed {
topics := make(map[string]byte)
for _, topic := range m.Topics {
topics[topic] = byte(m.QoS)
}
subscribeToken := c.SubscribeMultiple(topics, m.recvMessage)
subscribeToken := m.client.SubscribeMultiple(topics, m.recvMessage)
subscribeToken.Wait()
if subscribeToken.Error() != nil {
m.acc.AddError(fmt.Errorf("E! MQTT Subscribe Error\ntopics: %s\nerror: %s",
m.acc.AddError(fmt.Errorf("subscription error: topics: %s: %v",
strings.Join(m.Topics[:], ","), subscribeToken.Error()))
}
m.connected = true
m.subscribed = true
}
return

return nil
}

func (m *MQTTConsumer) onConnectionLost(c mqtt.Client, err error) {
m.acc.AddError(fmt.Errorf("E! MQTT Connection lost\nerror: %s\nMQTT Client will try to reconnect", err.Error()))
m.acc.AddError(fmt.Errorf("connection lost: %v", err))
log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers)
m.state = Disconnected
return
}

// receiver() reads all incoming messages from the consumer, and parses them into
// influxdb metric points.
func (m *MQTTConsumer) receiver() {
for {
select {
case <-m.done:
return
case msg := <-m.in:
topic := msg.Topic()
metrics, err := m.parser.Parse(msg.Payload())
if err != nil {
m.acc.AddError(fmt.Errorf("E! MQTT Parse Error\nmessage: %s\nerror: %s",
string(msg.Payload()), err.Error()))
}

for _, metric := range metrics {
tags := metric.Tags()
tags["topic"] = topic
m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time())
}
}
func (m *MQTTConsumer) recvMessage(c mqtt.Client, msg mqtt.Message) {
topic := msg.Topic()
metrics, err := m.parser.Parse(msg.Payload())
if err != nil {
m.acc.AddError(err)
}
}

func (m *MQTTConsumer) recvMessage(_ mqtt.Client, msg mqtt.Message) {
m.in <- msg
for _, metric := range metrics {
tags := metric.Tags()
tags["topic"] = topic
m.acc.AddFields(metric.Name(), metric.Fields(), tags, metric.Time())
}
}

func (m *MQTTConsumer) Stop() {
m.Lock()
defer m.Unlock()

if m.connected {
close(m.done)
if m.state == Connected {
log.Printf("D! [inputs.mqtt_consumer]: disconnecting %v", m.Servers)
m.client.Disconnect(200)
m.connected = false
log.Printf("D! [inputs.mqtt_consumer]: disconnected %v", m.Servers)
m.state = Disconnected
}
}

func (m *MQTTConsumer) Gather(acc telegraf.Accumulator) error {
if !m.connected {
if m.state == Disconnected {
m.state = Connecting
log.Printf("D! [inputs.mqtt_consumer]: connecting %v", m.Servers)
m.connect()
}

Expand Down Expand Up @@ -261,7 +246,7 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) {
for _, server := range m.Servers {
// Preserve support for host:port style servers; deprecated in Telegraf 1.4.4
if !strings.Contains(server, "://") {
log.Printf("W! mqtt_consumer server %q should be updated to use `scheme://host:port` format", server)
log.Printf("W! [inputs.mqtt_consumer] server %q should be updated to use `scheme://host:port` format", server)
if tlsCfg == nil {
server = "tcp://" + server
} else {
Expand All @@ -271,10 +256,9 @@ func (m *MQTTConsumer) createOpts() (*mqtt.ClientOptions, error) {

opts.AddBroker(server)
}
opts.SetAutoReconnect(true)
opts.SetAutoReconnect(false)
opts.SetKeepAlive(time.Second * 60)
opts.SetCleanSession(!m.PersistentSession)
opts.SetOnConnectHandler(m.onConnect)
opts.SetConnectionLostHandler(m.onConnectionLost)

return opts, nil
Expand All @@ -284,6 +268,7 @@ func init() {
inputs.Add("mqtt_consumer", func() telegraf.Input {
return &MQTTConsumer{
ConnectionTimeout: defaultConnectionTimeout,
state: Disconnected,
}
})
}
106 changes: 14 additions & 92 deletions plugins/inputs/mqtt_consumer/mqtt_consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,17 @@ import (
)

const (
testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n"
testMsgNeg = "cpu_load_short,host=server01 value=-23422.0 1422568543702900257\n"
testMsgGraphite = "cpu.load.short.graphite 23422 1454780029"
testMsgJSON = "{\"a\": 5, \"b\": {\"c\": 6}}\n"
invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n"
testMsg = "cpu_load_short,host=server01 value=23422.0 1422568543702900257\n"
invalidMsg = "cpu_load_short,host=server01 1422568543702900257\n"
)

func newTestMQTTConsumer() (*MQTTConsumer, chan mqtt.Message) {
in := make(chan mqtt.Message, 100)
func newTestMQTTConsumer() *MQTTConsumer {
n := &MQTTConsumer{
Topics: []string{"telegraf"},
Servers: []string{"localhost:1883"},
in: in,
done: make(chan struct{}),
connected: true,
Topics: []string{"telegraf"},
Servers: []string{"localhost:1883"},
}

return n, in
return n
}

// Test that default client has random ID
Expand Down Expand Up @@ -79,31 +72,12 @@ func TestPersistentClientIDFail(t *testing.T) {
}

func TestRunParser(t *testing.T) {
n, in := newTestMQTTConsumer()
n := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc
defer close(n.done)

n.parser, _ = parsers.NewInfluxParser()
go n.receiver()
in <- mqttMsg(testMsgNeg)
acc.Wait(1)

if a := acc.NFields(); a != 1 {
t.Errorf("got %v, expected %v", a, 1)
}
}

func TestRunParserNegativeNumber(t *testing.T) {
n, in := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc
defer close(n.done)

n.parser, _ = parsers.NewInfluxParser()
go n.receiver()
in <- mqttMsg(testMsg)
acc.Wait(1)
n.recvMessage(nil, mqttMsg(testMsg))

if a := acc.NFields(); a != 1 {
t.Errorf("got %v, expected %v", a, 1)
Expand All @@ -112,84 +86,32 @@ func TestRunParserNegativeNumber(t *testing.T) {

// Test that the parser ignores invalid messages
func TestRunParserInvalidMsg(t *testing.T) {
n, in := newTestMQTTConsumer()
n := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc
defer close(n.done)

n.parser, _ = parsers.NewInfluxParser()
go n.receiver()
in <- mqttMsg(invalidMsg)
acc.WaitError(1)

n.recvMessage(nil, mqttMsg(invalidMsg))

if a := acc.NFields(); a != 0 {
t.Errorf("got %v, expected %v", a, 0)
}
assert.Contains(t, acc.Errors[0].Error(), "MQTT Parse Error")
assert.Len(t, acc.Errors, 1)
}

// Test that the parser parses line format messages into metrics
func TestRunParserAndGather(t *testing.T) {
n, in := newTestMQTTConsumer()
n := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc

defer close(n.done)

n.parser, _ = parsers.NewInfluxParser()
go n.receiver()
in <- mqttMsg(testMsg)
acc.Wait(1)

n.Gather(&acc)
n.recvMessage(nil, mqttMsg(testMsg))

acc.AssertContainsFields(t, "cpu_load_short",
map[string]interface{}{"value": float64(23422)})
}

// Test that the parser parses graphite format messages into metrics
func TestRunParserAndGatherGraphite(t *testing.T) {
n, in := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc
defer close(n.done)

n.parser, _ = parsers.NewGraphiteParser("_", []string{}, nil)
go n.receiver()
in <- mqttMsg(testMsgGraphite)

n.Gather(&acc)
acc.Wait(1)

acc.AssertContainsFields(t, "cpu_load_short_graphite",
map[string]interface{}{"value": float64(23422)})
}

// Test that the parser parses json format messages into metrics
func TestRunParserAndGatherJSON(t *testing.T) {
n, in := newTestMQTTConsumer()
acc := testutil.Accumulator{}
n.acc = &acc
defer close(n.done)

n.parser, _ = parsers.NewParser(&parsers.Config{
DataFormat: "json",
MetricName: "nats_json_test",
})
go n.receiver()
in <- mqttMsg(testMsgJSON)

n.Gather(&acc)

acc.Wait(1)

acc.AssertContainsFields(t, "nats_json_test",
map[string]interface{}{
"a": float64(5),
"b_c": float64(6),
})
}

func mqttMsg(val string) mqtt.Message {
return &message{
topic: "telegraf/unit_test",
Expand Down