-
Notifications
You must be signed in to change notification settings - Fork 420
/
Copy pathsession.py
297 lines (241 loc) · 9.5 KB
/
session.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
"""
Trace queries along a session to a cassandra cluster
"""
import sys
# 3p
import cassandra.cluster
# project
from ...compat import stringify
from ...constants import ANALYTICS_SAMPLE_RATE_KEY
from ...ext import net, cassandra as cassx, errors
from ...internal.logger import get_logger
from ...pin import Pin
from ...settings import config
from ...utils.deprecation import deprecated
from ...utils.formats import deep_getattr
from ...vendor import wrapt
# Module-level logger used for the debug messages emitted by the tracing hooks.
log = get_logger(__name__)

# Maximum number of characters kept when `_sanitize_query` sets the span resource.
RESOURCE_MAX_LENGTH = 5000
# Name used as both the service and the app of the Pin attached by `patch`.
SERVICE = 'cassandra'
# Attribute names stored on the future returned by `execute_async`:
# the span currently in flight and the number of the page being fetched.
CURRENT_SPAN = '_ddtrace_current_span'
PAGE_NUMBER = '_ddtrace_page_number'

# Original (unpatched) `Cluster.connect` function, kept so that `unpatch`
# can restore it.
_connect = cassandra.cluster.Cluster.connect
def patch():
    """Add tracing to the cassandra library.

    ``Cluster.connect`` is replaced by a wrapper around the original
    function, and a Pin carrying the default service information is
    attached to the ``Cluster`` class.
    """
    cluster_cls = cassandra.cluster.Cluster
    cluster_cls.connect = wrapt.FunctionWrapper(_connect, traced_connect)
    Pin(service=SERVICE, app=SERVICE, app_type='db').onto(cassandra.cluster.Cluster)
def unpatch():
    """Remove tracing by restoring the original ``Cluster.connect``."""
    setattr(cassandra.cluster.Cluster, 'connect', _connect)
def traced_connect(func, instance, args, kwargs):
    """Wrapper around ``Cluster.connect`` that instruments the new session.

    :param func: the wrapped (original) ``connect`` callable
    :param instance: the object ``connect`` was called on (unused here)
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :returns: the session returned by ``func``, with ``execute_async``
        wrapped by `traced_execute_async`
    """
    session = func(*args, **kwargs)
    # Guard on the attribute that is actually replaced below. Checking
    # ``session.execute`` (which is never wrapped here) could not detect an
    # already-instrumented session and would allow double wrapping.
    if not isinstance(session.execute_async, wrapt.FunctionWrapper):
        # FIXME[matt] this should probably be private.
        setattr(session, 'execute_async', wrapt.FunctionWrapper(session.execute_async, traced_execute_async))
    return session
def _close_span_on_success(result, future):
    """Tag and finish the span stored on ``future`` after a successful response.

    Nothing happens (besides a debug log) when ``future`` carries no span.
    Failures while tagging are logged and swallowed; the span is always
    finished and detached from ``future``.

    :param result: the raw result handed to the success callback
    :param future: the object carrying the ``CURRENT_SPAN`` attribute
    """
    current_span = getattr(future, CURRENT_SPAN, None)
    if not current_span:
        log.debug('traced_set_final_result was not able to get the current span from the ResponseFuture')
        return

    try:
        result_set = cassandra.cluster.ResultSet(future, result)
        result_tags = _extract_result_metas(result_set)
        current_span.set_tags(result_tags)
    except Exception as err:
        log.debug('an exception occured while setting tags: %s', err)
    finally:
        current_span.finish()
        delattr(future, CURRENT_SPAN)
def traced_set_final_result(func, instance, args, kwargs):
    """Wrapper for ``_set_final_result``: close the span, then delegate.

    The first positional argument is the result; ``instance`` is the future
    the wrapped method belongs to.
    """
    _close_span_on_success(args[0], instance)
    return func(*args, **kwargs)
def _close_span_on_error(exc, future):
    """Mark the span stored on ``future`` as failed and finish it.

    Nothing happens (besides a debug log) when ``future`` carries no span.
    Failures while tagging are logged and swallowed; the span is always
    finished and detached from ``future``.

    :param exc: the exception handed to the error callback
    :param future: the object carrying the ``CURRENT_SPAN`` attribute
    """
    span = getattr(future, CURRENT_SPAN, None)
    if not span:
        log.debug('traced_set_final_exception was not able to get the current span from the ResponseFuture')
        return
    try:
        # handling the exception manually because we
        # don't have an ongoing exception here
        span.error = 1
        # Record the type first so it is kept even if building the message fails.
        span.set_tag(errors.ERROR_TYPE, exc.__class__.__name__)
        # Exceptions created without arguments have an empty ``args`` tuple;
        # fall back to their string form instead of raising IndexError.
        exc_args = getattr(exc, 'args', None)
        span.set_tag(errors.ERROR_MSG, exc_args[0] if exc_args else str(exc))
    except Exception as e:
        log.debug('traced_set_final_exception was not able to set the error, failed with error: %s', e)
    finally:
        span.finish()
        delattr(future, CURRENT_SPAN)
def traced_set_final_exception(func, instance, args, kwargs):
    """Wrapper for ``_set_final_exception``: close the span, then delegate.

    The first positional argument is the exception; ``instance`` is the
    future the wrapped method belongs to.
    """
    _close_span_on_error(args[0], instance)
    return func(*args, **kwargs)
def traced_start_fetching_next_page(func, instance, args, kwargs):
    """Wrapper for ``start_fetching_next_page`` that opens a span per page.

    The call is passed straight through when ``instance`` reports no more
    pages, or when no enabled Pin is found on the session's cluster.
    Otherwise any span still stored on ``instance`` is finished, a new span
    is started and stored, and the page counter is incremented before the
    wrapped function runs. If the wrapped function raises, the exception is
    recorded on the new span and re-raised.
    """
    if not getattr(instance, 'has_more_pages', True):
        return func(*args, **kwargs)

    session = getattr(instance, 'session', None)
    cluster = getattr(session, 'cluster', None)
    pin = Pin.get_from(cluster)
    if not (pin and pin.enabled()):
        return func(*args, **kwargs)

    # In case the current span is not finished we make sure to finish it
    stale_span = getattr(instance, CURRENT_SPAN, None)
    if stale_span:
        log.debug('previous span was not finished before fetching next page')
        stale_span.finish()

    span = _start_span_and_set_tags(pin, getattr(instance, 'query', None), session, cluster)
    setattr(instance, PAGE_NUMBER, getattr(instance, PAGE_NUMBER, 1) + 1)
    setattr(instance, CURRENT_SPAN, span)

    try:
        return func(*args, **kwargs)
    except Exception:
        with span:
            span.set_exc_info(*sys.exc_info())
        raise
def traced_execute_async(func, instance, args, kwargs):
    """Wrapper for a session's ``execute_async`` that traces the query.

    :param func: the wrapped (original) ``execute_async`` callable
    :param instance: the session ``execute_async`` was called on
    :param args: positional arguments forwarded to ``func``
    :param kwargs: keyword arguments forwarded to ``func``
    :returns: the object returned by ``func``, carrying the started span and
        with its completion and paging methods wrapped
    :raises: whatever ``func`` (or the instrumentation below) raises; the
        exception is recorded on the span before being re-raised
    """
    cluster = getattr(instance, 'cluster', None)
    pin = Pin.get_from(cluster)
    # No pin, or tracing disabled: call through untouched.
    if not pin or not pin.enabled():
        return func(*args, **kwargs)

    # The query is taken from the `query` keyword when given (and truthy),
    # otherwise from the first positional argument.
    query = kwargs.get('query') or args[0]

    span = _start_span_and_set_tags(pin, query, instance, cluster)

    try:
        result = func(*args, **kwargs)
        # Store the span and the initial page number on the returned object
        # so the wrappers installed below can find them.
        setattr(result, CURRENT_SPAN, span)
        setattr(result, PAGE_NUMBER, 1)
        # Wrap the success path so the span is tagged and finished.
        setattr(
            result,
            '_set_final_result',
            wrapt.FunctionWrapper(
                result._set_final_result,
                traced_set_final_result
            )
        )
        # Wrap the error path so the span is marked as failed and finished.
        setattr(
            result,
            '_set_final_exception',
            wrapt.FunctionWrapper(
                result._set_final_exception,
                traced_set_final_exception
            )
        )
        # Wrap paging so each subsequent page gets its own span.
        setattr(
            result,
            'start_fetching_next_page',
            wrapt.FunctionWrapper(
                result.start_fetching_next_page,
                traced_start_fetching_next_page
            )
        )
        # Since we cannot be sure that the previous methods were overwritten
        # before the call ended, we add callbacks that will be run
        # synchronously if the call already returned and we remove them right
        # after.
        result.add_callbacks(
            _close_span_on_success,
            _close_span_on_error,
            callback_args=(result,),
            errback_args=(result,)
        )
        result.clear_callbacks()
        return result
    except Exception:
        # Entering and leaving the span context finishes the span after the
        # exception details have been attached to it.
        with span:
            span.set_exc_info(*sys.exc_info())
        raise
def _start_span_and_set_tags(pin, query, session, cluster):
    """Start a ``cassandra.query`` span and attach the common tags.

    :param pin: supplies the tracer and the service name
    :param query: the query used to derive the span resource
    :param session: source of the session-level tags
    :param cluster: source of the cluster-level tags
    :returns: the started (unfinished) span
    """
    span = pin.tracer.trace('cassandra.query', service=pin.service, span_type=cassx.TYPE)
    _sanitize_query(span, query)

    session_tags = _extract_session_metas(session)  # FIXME[matt] do once?
    span.set_tags(session_tags)
    cluster_tags = _extract_cluster_metas(cluster)
    span.set_tags(cluster_tags)

    # set analytics sample rate if enabled
    sample_rate = config.cassandra.get_analytics_sample_rate()
    span.set_tag(ANALYTICS_SAMPLE_RATE_KEY, sample_rate)
    return span
def _extract_session_metas(session):
    """Return the tags derived from ``session``.

    The result holds the lower-cased keyspace when the session has a truthy
    ``keyspace`` attribute, and is empty otherwise.
    """
    keyspace = getattr(session, 'keyspace', None)
    if not keyspace:
        return {}
    # FIXME the keyspace can be overridden explicitly in the query itself
    # e.g. 'select * from trace.hash_to_resource'
    return {cassx.KEYSPACE: keyspace.lower()}
def _extract_cluster_metas(cluster):
    """Return the tags derived from ``cluster``.

    The cluster name (from ``cluster.metadata.cluster_name``) and the port
    are each included only when present and truthy.
    """
    metas = {}

    cluster_name = deep_getattr(cluster, 'metadata.cluster_name')
    if cluster_name:
        metas[cassx.CLUSTER] = cluster_name

    port = getattr(cluster, 'port', None)
    if port:
        metas[net.TARGET_PORT] = port

    return metas
def _extract_result_metas(result):
    """Return the tags derived from a query result.

    :param result: the result object, or ``None``
    :returns: a dict that may contain the target host, consistency level,
        keyspace, pagination flag, page number and row count, depending on
        which attributes ``result`` and its ``response_future`` expose;
        empty when ``result`` is ``None``
    """
    metas = {}
    if result is None:
        return metas

    future = getattr(result, 'response_future', None)

    if future:
        # get the host: prefer `coordinator_host`, fall back to the address
        # of `_current_host` when only that one is available
        host = getattr(future, 'coordinator_host', None)
        if host:
            metas[net.TARGET_HOST] = host
        elif hasattr(future, '_current_host'):
            address = deep_getattr(future, '_current_host.address')
            if address:
                metas[net.TARGET_HOST] = address

        query = getattr(future, 'query', None)
        if getattr(query, 'consistency_level', None):
            metas[cassx.CONSISTENCY_LEVEL] = query.consistency_level
        if getattr(query, 'keyspace', None):
            metas[cassx.KEYSPACE] = query.keyspace.lower()

        page_number = getattr(future, PAGE_NUMBER, 1)
        # Default to False like every other lookup in this function: a future
        # without `has_more_pages` must not raise and discard the tags
        # collected so far.
        has_more_pages = getattr(future, 'has_more_pages', False)
        # Coerce so the tag is always a real boolean.
        is_paginated = bool(has_more_pages or page_number > 1)
        metas[cassx.PAGINATED] = is_paginated
        if is_paginated:
            metas[cassx.PAGE_NUMBER] = page_number

    if hasattr(result, 'current_rows'):
        result_rows = result.current_rows or []
        metas[cassx.ROW_COUNT] = len(result_rows)

    return metas
def _sanitize_query(span, query):
    """Set ``span.resource`` (and batch tags) from ``query``.

    The resource is chosen by the name of the query's type, passed through
    ``stringify`` and truncated to ``RESOURCE_MAX_LENGTH`` characters. For a
    batch statement the ``cassandra.query`` tag and the
    ``cassandra.batch_size`` metric are set as well.
    """
    # TODO (aaditya): fix this hacky type check. we need it to avoid circular imports
    type_name = type(query).__name__

    if type_name in ('SimpleStatement', 'PreparedStatement'):
        # reset query if a string is available
        resource = getattr(query, 'query_string', query)
    elif type_name == 'BatchStatement':
        resource = 'BatchStatement'
        # Each element in `_statements_and_parameters` is:
        #   (is_prepared, statement, parameters)
        #  ref:https://github.com/datastax/python-driver/blob/13d6d72be74f40fcef5ec0f2b3e98538b3b87459/cassandra/query.py#L844
        #
        # For prepared statements, the `statement` value is just the query_id
        #   which is not a statement and when trying to join with other strings
        #   raises an error in python3 around joining bytes to unicode, so this
        #   just filters out prepared statements from this tag value
        batch_entries = query._statements_and_parameters
        unprepared = [entry[1] for entry in batch_entries[:2] if not entry[0]]
        span.set_tag('cassandra.query', '; '.join(unprepared))
        span.set_metric('cassandra.batch_size', len(batch_entries))
    elif type_name == 'BoundStatement':
        prepared = getattr(query, 'prepared_statement', None)
        resource = getattr(prepared, 'query_string', None) if prepared else None
    elif type_name == 'str':
        resource = query
    else:
        resource = 'unknown-query-type'  # FIXME[matt] what else do to here?

    span.resource = stringify(resource)[:RESOURCE_MAX_LENGTH]
#
# DEPRECATED
#
@deprecated(message='Use patching instead (see the docs).', version='1.0.0')
def get_traced_cassandra(*args, **kwargs):
    """Deprecated entry point; forwards every argument to `_get_traced_cluster`."""
    traced_cluster = _get_traced_cluster(*args, **kwargs)
    return traced_cluster
def _get_traced_cluster(*args, **kwargs):
    """Return the ``cassandra.cluster.Cluster`` class; all arguments are ignored."""
    cluster_cls = cassandra.cluster.Cluster
    return cluster_cls