Skip to content

Commit

Permalink
Support setting principal and SASL extensions in oauth_cb, handle fai…
Browse files Browse the repository at this point in the history
…lures (#1402)

* Support setting principal and SASL extensions in oauth_cb and handle token failures

* removed global variables

Co-authored-by: Emanuele Sabellico <[email protected]>
  • Loading branch information
Manicben and emasab authored Aug 2, 2022
1 parent 9ea3aae commit f8b6468
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 38 deletions.
103 changes: 97 additions & 6 deletions src/confluent_kafka/src/confluent_kafka.c
Original file line number Diff line number Diff line change
Expand Up @@ -1522,13 +1522,73 @@ static void log_cb (const rd_kafka_t *rk, int level,
CallState_resume(cs);
}

/**
* @brief Translate Python \p key and \p value to C types and set on
* provided \p extensions char* array at the provided index.
*
* @returns 1 on success or 0 if an exception was raised.
*/
static int py_extensions_to_c (char **extensions, Py_ssize_t idx,
PyObject *key, PyObject *value) {
PyObject *ks, *ks8, *vo8 = NULL;
const char *k;
const char *v;
Py_ssize_t ksize = 0;
Py_ssize_t vsize = 0;

if (!(ks = cfl_PyObject_Unistr(key))) {
PyErr_SetString(PyExc_TypeError,
"expected extension key to be unicode "
"string");
return 0;
}

k = cfl_PyUnistr_AsUTF8(ks, &ks8);
ksize = (Py_ssize_t)strlen(k);

if (cfl_PyUnistr(_Check(value))) {
/* Unicode string, translate to utf-8. */
v = cfl_PyUnistr_AsUTF8(value, &vo8);
if (!v) {
Py_DECREF(ks);
Py_XDECREF(ks8);
return 0;
}
vsize = (Py_ssize_t)strlen(v);
} else {
PyErr_Format(PyExc_TypeError,
"expected extension value to be "
"unicode string, not %s",
((PyTypeObject *)PyObject_Type(value))->
tp_name);
Py_DECREF(ks);
Py_XDECREF(ks8);
return 0;
}

extensions[idx] = (char*)malloc(ksize);
strcpy(extensions[idx], k);
extensions[idx + 1] = (char*)malloc(vsize);
strcpy(extensions[idx + 1], v);

Py_DECREF(ks);
Py_XDECREF(ks8);
Py_XDECREF(vo8);

return 1;
}

static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
void *opaque) {
Handle *h = opaque;
PyObject *eo, *result;
CallState *cs;
const char *token;
double expiry;
const char *principal = "";
PyObject *extensions = NULL;
char **rd_extensions = NULL;
Py_ssize_t rd_extensions_size = 0;
char err_msg[2048];
rd_kafka_resp_err_t err_code;

Expand All @@ -1539,26 +1599,57 @@ static void oauth_cb (rd_kafka_t *rk, const char *oauthbearer_config,
Py_DECREF(eo);

if (!result) {
goto err;
goto fail;
}
if (!PyArg_ParseTuple(result, "sd", &token, &expiry)) {
if (!PyArg_ParseTuple(result, "sd|sO!", &token, &expiry, &principal, &PyDict_Type, &extensions)) {
Py_DECREF(result);
PyErr_Format(PyExc_TypeError,
PyErr_SetString(PyExc_TypeError,
"expect returned value from oauth_cb "
"to be (token_str, expiry_time) tuple");
goto err;
}

if (extensions) {
int len = (int)PyDict_Size(extensions);
rd_extensions = (char **)malloc(2 * len * sizeof(char *));
Py_ssize_t pos = 0;
PyObject *ko, *vo;
while (PyDict_Next(extensions, &pos, &ko, &vo)) {
if (!py_extensions_to_c(rd_extensions, rd_extensions_size, ko, vo)) {
Py_DECREF(result);
free(rd_extensions);
goto err;
}
rd_extensions_size = rd_extensions_size + 2;
}
}

err_code = rd_kafka_oauthbearer_set_token(h->rk, token,
(int64_t)(expiry * 1000),
"", NULL, 0, err_msg,
principal, (const char **)rd_extensions, rd_extensions_size, err_msg,
sizeof(err_msg));
Py_DECREF(result);
if (err_code) {
if (rd_extensions) {
for(int i = 0; i < rd_extensions_size; i++) {
free(rd_extensions[i]);
}
free(rd_extensions);
}

if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
PyErr_Format(PyExc_ValueError, "%s", err_msg);
goto err;
goto fail;
}
goto done;

fail:
err_code = rd_kafka_oauthbearer_set_token_failure(h->rk, "OAuth callback raised exception");
if (err_code != RD_KAFKA_RESP_ERR_NO_ERROR) {
PyErr_SetString(PyExc_ValueError, "Failed to set token failure");
goto err;
}
PyErr_Clear();
goto done;
err:
CallState_crash(cs);
rd_kafka_yield(h->rk);
Expand Down
99 changes: 67 additions & 32 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,49 +24,41 @@ def test_version():
assert confluent_kafka.version()[0] == confluent_kafka.__version__


# global variable for error_cb call back function
seen_error_cb = False


def test_error_cb():
""" Tests error_cb. """
seen_error_cb = False

def error_cb(error_msg):
global seen_error_cb
nonlocal seen_error_cb
seen_error_cb = True
acceptable_error_codes = (confluent_kafka.KafkaError._TRANSPORT, confluent_kafka.KafkaError._ALL_BROKERS_DOWN)
assert error_msg.code() in acceptable_error_codes

conf = {'bootstrap.servers': 'localhost:65531', # Purposely cause connection refused error
'group.id': 'test',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'error_cb': error_cb
}

kc = confluent_kafka.Consumer(**conf)
kc.subscribe(["test"])
while not seen_error_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)

kc.close()


# global variable for stats_cb call back function
seen_stats_cb = False


def test_stats_cb():
""" Tests stats_cb. """
seen_stats_cb = False

def stats_cb(stats_json_str):
global seen_stats_cb
nonlocal seen_stats_cb
seen_stats_cb = True
stats_json = json.loads(stats_json_str)
assert len(stats_json['name']) > 0

conf = {'group.id': 'test',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'statistics.interval.ms': 200,
'stats_cb': stats_cb
Expand All @@ -76,22 +68,20 @@ def stats_cb(stats_json_str):

kc.subscribe(["test"])
while not seen_stats_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)
kc.close()


seen_stats_cb_check_no_brokers = False


def test_conf_none():
""" Issue #133
Test that None can be passed for NULL by setting bootstrap.servers
to None. If None would be converted to a string then a broker would
show up in statistics. Verify that it doesnt. """
seen_stats_cb_check_no_brokers = False

def stats_cb_check_no_brokers(stats_json_str):
""" Make sure no brokers are reported in stats """
global seen_stats_cb_check_no_brokers
nonlocal seen_stats_cb_check_no_brokers
stats = json.loads(stats_json_str)
assert len(stats['brokers']) == 0, "expected no brokers in stats: %s" % stats_json_str
seen_stats_cb_check_no_brokers = True
Expand All @@ -101,9 +91,8 @@ def stats_cb_check_no_brokers(stats_json_str):
'stats_cb': stats_cb_check_no_brokers}

p = confluent_kafka.Producer(conf)
p.poll(timeout=1)
p.poll(timeout=0.1)

global seen_stats_cb_check_no_brokers
assert seen_stats_cb_check_no_brokers


Expand All @@ -130,23 +119,19 @@ def test_throttle_event_types():
assert str(throttle_event) == "broker/0 throttled for 10000 ms"


# global variable for oauth_cb call back function
seen_oauth_cb = False


def test_oauth_cb():
""" Tests oauth_cb. """
seen_oauth_cb = False

def oauth_cb(oauth_config):
global seen_oauth_cb
nonlocal seen_oauth_cb
seen_oauth_cb = True
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 300.0

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'socket.timeout.ms': '100',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
Expand All @@ -155,7 +140,59 @@ def oauth_cb(oauth_config):
kc = confluent_kafka.Consumer(**conf)

while not seen_oauth_cb:
kc.poll(timeout=1)
kc.poll(timeout=0.1)
kc.close()


def test_oauth_cb_principal_sasl_extensions():
""" Tests oauth_cb. """
seen_oauth_cb = False

def oauth_cb(oauth_config):
nonlocal seen_oauth_cb
seen_oauth_cb = True
assert oauth_config == 'oauth_cb'
return 'token', time.time() + 300.0, oauth_config, {"extone": "extoneval", "exttwo": "exttwoval"}

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 100, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = confluent_kafka.Consumer(**conf)

while not seen_oauth_cb:
kc.poll(timeout=0.1)
kc.close()


def test_oauth_cb_failure():
""" Tests oauth_cb. """
oauth_cb_count = 0

def oauth_cb(oauth_config):
nonlocal oauth_cb_count
oauth_cb_count += 1
assert oauth_config == 'oauth_cb'
if oauth_cb_count == 2:
return 'token', time.time() + 100.0, oauth_config, {"extthree": "extthreeval"}
raise Exception

conf = {'group.id': 'test',
'security.protocol': 'sasl_plaintext',
'sasl.mechanisms': 'OAUTHBEARER',
'session.timeout.ms': 1000, # Avoid close() blocking too long
'sasl.oauthbearer.config': 'oauth_cb',
'oauth_cb': oauth_cb
}

kc = confluent_kafka.Consumer(**conf)

while oauth_cb_count < 2:
kc.poll(timeout=0.1)
kc.close()


Expand Down Expand Up @@ -194,11 +231,9 @@ def test_unordered_dict(init_func):
client.poll(0)


# global variable for on_delivery call back function
seen_delivery_cb = False


def test_topic_config_update():
seen_delivery_cb = False

# *NOTE* default.topic.config has been deprecated.
# This example remains to ensure backward-compatibility until its removal.
confs = [{"message.timeout.ms": 600000, "default.topic.config": {"message.timeout.ms": 1000}},
Expand All @@ -207,7 +242,7 @@ def test_topic_config_update():

def on_delivery(err, msg):
# Since there is no broker, produced messages should time out.
global seen_delivery_cb
nonlocal seen_delivery_cb
seen_delivery_cb = True
assert err.code() == confluent_kafka.KafkaError._MSG_TIMED_OUT

Expand Down

0 comments on commit f8b6468

Please sign in to comment.