-
Notifications
You must be signed in to change notification settings - Fork 33
/
Salesforce.py
426 lines (352 loc) · 14.3 KB
/
Salesforce.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
from base64 import b64encode
from logging import getLogger
from pathlib import Path
from random import randrange
from tempfile import TemporaryDirectory
from snowfakery import data_gen_exceptions as exc
from snowfakery.data_gen_exceptions import (
DataGenError,
DataGenNameError,
DataGenValueError,
)
from snowfakery.data_generator_runtime_object_model import (
FieldFactory,
ObjectTemplate,
SimpleValue,
StructuredValue,
)
from snowfakery.output_streams import SqlDbOutputStream
from snowfakery.parse_recipe_yaml import TableInfo
from snowfakery.plugins import (
ParserMacroPlugin,
PluginOption,
PluginResult,
SnowfakeryPlugin,
memorable,
)
from snowfakery.standard_plugins.datasets import (
DatasetBase,
DatasetPluginBase,
sql_dataset,
)
# Upper bound (exclusive) on the random OFFSET that random_record picks;
# presumably mirrors Salesforce's SOQL OFFSET limit -- TODO confirm.
MAX_SALESFORCE_OFFSET = 2000  # Any way around this?

# The option names users pass on the CLI or API are short ("org_name",
# ...), but internally they are stored under these fully-qualified names
# so they cannot collide with variable names defined in a recipe.
_OPTION_NAME_PREFIX = "snowfakery.standard_plugins.Salesforce.SalesforceQuery."
plugin_option_org_name = _OPTION_NAME_PREFIX + "org_name"
plugin_option_org_config = _OPTION_NAME_PREFIX + "org_config"
plugin_option_project_config = _OPTION_NAME_PREFIX + "project_config"
class SalesforceConnection:
    """Helper layer above simple_salesforce and salesforce_bulk.

    Both clients are created lazily, on first access to ``sf`` or ``bulk``,
    from the ``(project_config, org_config)`` pair returned by the callback
    given to the constructor.
    """

    _sf = None
    # Declared so ``bulk`` returns None instead of raising AttributeError
    # when ``_sf`` was set without going through ``_get_sf_clients``.
    _bulk = None

    def __init__(self, get_project_config_and_org_config):
        """
        Args:
            get_project_config_and_org_config: zero-argument callable that
                returns ``(project_config, org_config)``; only called when a
                client is first needed.
        """
        self.get_project_config_and_org_config = get_project_config_and_org_config
        self.logger = getLogger(__name__)

    @property
    def sf(self):
        """simple_salesforce client (created on first access)."""
        if not self._sf:
            project_config, org_config = self.get_project_config_and_org_config()
            self._sf, self._bulk = self._get_sf_clients(project_config, org_config)
        return self._sf

    @property
    def bulk(self):
        """salesforce_bulk client (created together with ``sf``)."""
        self.sf  # initializes self._bulk as a side-effect
        return self._bulk

    def query(self, *args, **kwargs):
        """Query Salesforce through simple_salesforce."""
        return self.sf.query(*args, **kwargs)

    def query_single_record(self, query):
        """Query Salesforce through simple_salesforce and
        validate that query returns 1 and only 1 record.

        Returns:
            The bare field value when the record has exactly one field
            (after dropping ``attributes``), otherwise a ``PluginResult``
            wrapping the record dict.

        Raises:
            DataGenValueError: if the query returns zero or several records.
        """
        qr = self.sf.query(query)
        records = qr.get("records")
        if not records:
            raise DataGenValueError(f"No records returned by query {query}", None, None)
        if len(records) > 1:  # pragma: no cover
            raise DataGenValueError(
                f"Multiple records returned by query {query}", None, None
            )
        record = records[0]
        # Drop the API metadata so only the queried fields remain.
        record.pop("attributes", None)
        if len(record) == 1:
            return next(iter(record.values()))
        return PluginResult(record)

    def compose_query(self, context_name, **kwargs):
        """Build ``SELECT <fields> FROM <from> [WHERE <where>]``.

        Args:
            context_name: name used in error messages.
            **kwargs: must contain ``fields`` and ``from``; ``where`` is
                optional. Any other key is an error.

        Raises:
            DataGenError: on missing ``fields``/``from`` or unknown keys.
        """
        # ``**kwargs`` is already a fresh dict, so popping does not affect
        # the caller.
        fields = kwargs.pop("fields", None)
        sobject = kwargs.pop("from", None)
        where = kwargs.pop("where", None)
        if not fields:
            raise DataGenError(f"{context_name} needs a 'fields' list")
        if not sobject:
            raise DataGenError(f"{context_name} needs a 'from'")
        if kwargs:
            raise DataGenError(
                f"Unknown argument in {context_name}: {tuple(kwargs.keys())}"
            )
        # Single spaces between clauses, and no trailing space.
        query = f"SELECT {fields} FROM {sobject}"
        if where:
            query += f" WHERE {where}"
        return query

    @staticmethod
    def _get_sf_clients(project_config, org_config):
        """Create the (simple_salesforce, salesforce_bulk) client pair."""
        from cumulusci.salesforce_api.utils import get_simple_salesforce_connection

        sf = get_simple_salesforce_connection(project_config, org_config)
        return sf, _init_bulk(sf, org_config)
def _init_bulk(sf, org_config):
    """Build a salesforce_bulk client reusing the session of an existing connection.

    The host is derived from ``org_config.instance_url`` with ``https://``
    removed and any trailing slash stripped; the session id comes from
    ``org_config.access_token`` and the API version from ``sf.sf_version``.
    """
    from salesforce_bulk import SalesforceBulk

    bare_host = org_config.instance_url.replace("https://", "").rstrip("/")
    bulk_kwargs = dict(
        host=bare_host,
        sessionId=org_config.access_token,
        API_version=sf.sf_version,
    )
    return SalesforceBulk(**bulk_kwargs)
def check_orgconfig(config):
    """Plugin-option validator: return ``config`` unchanged if it is a CumulusCI config.

    Raises:
        TypeError: if ``config`` is not a ``cumulusci.core.config.BaseConfig``.
    """
    from cumulusci.core.config import BaseConfig

    if not isinstance(config, BaseConfig):
        raise TypeError(f"Should be a CCI Config, not {type(config)}")
    return config
class SalesforceConnectionMixin:
    """Mixin giving a plugin a lazily-created ``SalesforceConnection``.

    The org is taken from the ``org_config``/``project_config`` plugin
    options when both are present, otherwise from the CumulusCI keychain
    using the ``org_name`` option. Relies on ``self.context`` being provided
    by the class it is mixed into (presumably ``SnowfakeryPlugin`` -- TODO
    confirm).
    """

    _sf_connection = None
    _runtime = None
    allowed_options = [
        PluginOption(plugin_option_org_name, str),
        PluginOption(plugin_option_org_config, check_orgconfig),
        PluginOption(plugin_option_project_config, check_orgconfig),
    ]

    @property
    def sf_connection(self):
        """The shared ``SalesforceConnection`` (created on first access)."""
        assert self.context
        if not self._sf_connection:
            self._sf_connection = SalesforceConnection(
                self.get_project_config_and_org_config
            )
        return self._sf_connection

    def get_project_config_and_org_config(self):
        """Return ``(project_config, org_config)`` for the target org.

        Uses the config objects passed as plugin options when both are set;
        otherwise falls back to the CumulusCI CLI keychain.
        """
        fieldvars = self.context.field_vars()
        project_config = fieldvars.get(plugin_option_project_config)
        org_config = fieldvars.get(plugin_option_org_config)
        if not project_config or not org_config:
            project_config, org_config = self._get_org_info_from_cli_keychain()
        return project_config, org_config

    def _get_org_info_from_cli_keychain(self):
        """Look up the org named by the ``org_name`` option in the CCI keychain."""
        org_name = self.get_org_name()  # from command line argument
        runtime = self._get_CliRuntime()  # from CCI CliRuntime
        _, org_config = runtime.get_org(org_name)
        return runtime.project_config, org_config

    def _get_CliRuntime(self):
        """Return a cached CumulusCI ``CliRuntime``, creating it on first use.

        Raises:
            DataGenError: if the runtime cannot be created.
        """
        if self._runtime:
            return self._runtime  # pragma: no cover
        try:
            from cumulusci.cli.runtime import CliRuntime

            self._runtime = CliRuntime(load_keychain=True)
        except Exception as e:  # pragma: no cover
            # Chain the original error so its traceback is preserved.
            raise DataGenError("CumulusCI Runtime cannot be loaded", *e.args) from e
        return self._runtime

    def get_org_name(self):
        """Look up the org_name in the scope.

        Raises:
            DataGenNameError: if the ``org_name`` option was not supplied.
        """
        fieldvars = self.context.field_vars()
        try:
            return fieldvars[plugin_option_org_name]
        except KeyError:
            # The KeyError adds nothing for the user; suppress it.
            raise DataGenNameError(
                "Orgname is not specified. Use --plugin-option org_name <yourorgname>",
                None,
                None,
            ) from None
class Salesforce(ParserMacroPlugin, SnowfakeryPlugin, SalesforceConnectionMixin):
    """Salesforce plugin: the ``SpecialObject`` parser macro plus the
    ``ProfileId``/``Profile``/``ContentFile`` recipe functions."""

    def __init__(self, *args, **kwargs):
        # Pass a single ``None`` positional argument when none were given
        # (presumably a base initializer requires one -- TODO confirm).
        args = args or [None]
        super().__init__(*args, **kwargs)

    def SpecialObject(self, context, args) -> ObjectTemplate:
        """Currently there is only one special object defined: PersonContact"""
        sobj, nickname = self._parse_special_args(args)
        line_info = context.line_num()
        if sobj == "PersonContact":
            return self._render_person_contact(context, sobj, nickname, line_info)
        else:
            raise exc.DataGenError(
                f"Unknown special object '{sobj}'. Did you mean 'PersonContact'?",
                None,
                None,
            )

    def _render_person_contact(self, context, sobj, nickname, line_info):
        """Generate the code to render a person contact as CCI expects.

        Code generation is a better strategy for this than a runtime
        plugin because some analysis of the table structures happens
        at parse time.
        """
        fields = [
            FieldFactory(
                "IsPersonAccount",
                SimpleValue("true", **line_info),
                **line_info,
            ),
            FieldFactory(
                "AccountId",
                StructuredValue("reference", ["Account"], **line_info),
                **line_info,
            ),
        ]
        new_template = ObjectTemplate(
            sobj,
            filename=line_info["filename"],
            line_num=line_info["line_num"],
            nickname=nickname,
            fields=fields,
        )
        context.register_template(new_template)
        return new_template

    def _parse_special_args(self, args):
        """Parse args of SpecialObject.

        ``args`` is either the object name as a string, or a dict with a
        string ``name`` and an optional string ``nickname``.

        Returns:
            ``(sobj, nickname)``; ``nickname`` is ``None`` if not given.

        Raises:
            exc.DataGenError: on an unsupported ``args`` type, or a missing or
            non-string ``name``, or a non-string ``nickname``.
        """
        nickname = None
        if isinstance(args, str):
            sobj = args
        elif isinstance(args, dict):
            # ``.get`` so a missing ``name`` is reported as a DataGenError
            # instead of a bare KeyError.
            sobj = args.get("name")
            if not isinstance(sobj, str):
                raise exc.DataGenError(
                    f"`name` argument should be a string, not `{sobj}`: ({type(sobj)})"
                )
            nickname = args.get("nickname")
            if nickname and not isinstance(nickname, str):
                raise exc.DataGenError(
                    f"`nickname` argument should be a string, not `{nickname}`: ({type(nickname)})"
                )
        else:
            raise exc.DataGenError(
                f"`name` argument should be a string, not `{args}`: ({type(args)})"
            )
        return sobj, nickname

    class Functions:
        @memorable
        def ProfileId(self, name):
            """Return the Id of the Profile whose Name is ``name``."""
            # Escape backslashes and single quotes so names such as
            # "Partner's Profile" form a valid SOQL string literal.
            escaped = str(name).replace("\\", "\\\\").replace("'", "\\'")
            query = f"select Id from Profile where Name='{escaped}'"
            return self.context.plugin.sf_connection.query_single_record(query)

        Profile = ProfileId

        def ContentFile(self, file: str):
            """Return the base64 (ASCII) encoding of ``file``'s bytes.

            ``file`` is resolved relative to the directory containing
            ``self.context.current_filename``.
            """
            template_path = Path(self.context.current_filename).parent
            with open(template_path / file, "rb") as data:
                return b64encode(data.read()).decode("ascii")
class SOQLDatasetImpl(DatasetBase):
    """Dataset whose rows come from a SOQL query.

    The query results are written to a temporary SQLite database, which is
    then iterated with the generic SQL dataset iterator.
    """

    iterator = None
    tempdir = None

    def __init__(self, plugin, *args, **kwargs):
        # CumulusCI is only imported when a SOQL dataset is actually built.
        from cumulusci.tasks.bulkdata.step import (
            DataOperationStatus,
            get_query_operation,
        )

        self.get_query_operation = get_query_operation
        self.DataOperationStatus = DataOperationStatus
        self.plugin = plugin
        super().__init__(*args, **kwargs)

    @property
    def sf_connection(self):
        """The owning plugin's Salesforce connection."""
        return self.plugin.sf_connection

    def _load_dataset(self, iteration_mode, rootpath, kwargs):
        """Run the SOQL query described by ``kwargs`` and return a row iterator.

        Raises:
            DataGenError: if the query raises or does not finish successfully.
        """
        from cumulusci.tasks.bulkdata.step import DataApi

        soql = self.sf_connection.compose_query("SOQLDataset", **kwargs)
        field_spec = kwargs.get("fields")
        target_sobject = kwargs.get("from")
        column_names = [column.strip() for column in field_spec.split(",")]

        operation = self.get_query_operation(
            sobject=target_sobject,
            fields=column_names,
            api_options={},
            context=self.sf_connection,
            query=soql,
            api=DataApi.SMART,
        )
        try:
            operation.query()
        except Exception as err:
            raise DataGenError(f"Unable to query records for {soql}: {err}") from err

        succeeded = operation.job_result.status is self.DataOperationStatus.SUCCESS
        if not succeeded:
            raise DataGenError(
                f"Unable to query records for {soql}: {','.join(operation.job_result.job_errors)}"
            )

        workdir, row_iterator = create_tempfile_sql_db_iterator(
            iteration_mode, column_names, operation.get_results()
        )
        # Tie the temp directory's lifetime to the iterator's cleanup stack.
        row_iterator.cleanup.push(workdir)
        return row_iterator

    def close(self):
        """Nothing to release here; cleanup is attached to the iterator."""
def create_tempfile_sql_db_iterator(mode, fieldnames, results):
    """Load ``results`` into a temporary SQLite DB and open a dataset over it.

    Returns:
        ``(tempdir, iterator)``: the temp directory owning the DB file and a
        ``sql_dataset`` iterator over its ``data`` table in ``mode``.
    """
    holder, url = _create_db(fieldnames, results)
    return holder, sql_dataset(url, "data", mode)
def _create_db(fieldnames, results):
    """Write ``results`` into a new SQLite database inside a temp directory.

    Args:
        fieldnames: column names for the ``data`` table, in row order.
        results: iterable of row sequences aligned with ``fieldnames``.

    Returns:
        ``(tempdir, dburl)``; the caller owns ``tempdir`` and must keep it
        alive (and eventually clean it up) while the DB is in use.
    """
    tempdir = TemporaryDirectory()
    try:
        db_path = Path(tempdir.name) / "queryresults.db"
        # TODO: try a real tempdb: "sqlite:///"
        dburl = f"sqlite:///{db_path}"
        with SqlDbOutputStream.from_url(dburl) as db:
            ti = TableInfo("data")
            ti.fields = {fieldname: None for fieldname in fieldnames}
            db.create_or_validate_tables({"data": ti})
            for row in results:
                db.write_row("data", dict(zip(fieldnames, row)))
            db.flush()
            db.close()
    except BaseException:
        # Don't leak the temp directory when loading fails part-way.
        tempdir.cleanup()
        raise
    return tempdir, dburl
class SOQLDataset(SalesforceConnectionMixin, DatasetPluginBase):
    """Dataset plugin that reads its rows from a SOQL query."""

    def __init__(self, *args, **kwargs):
        # The implementation is built before the base initializer runs and
        # keeps a back-reference to this plugin for its Salesforce connection.
        impl = SOQLDatasetImpl(self)
        self.dataset_impl = impl
        super().__init__(*args, **kwargs)
class SalesforceQuery(SalesforceConnectionMixin, SnowfakeryPlugin):
    """Plugin exposing ``random_record`` and ``find_record`` SOQL lookups."""

    class Functions:
        @property
        def _sf_connection(self):
            return self.context.plugin.sf_connection

        def random_record(self, *args, fields="Id", where=None, **kwargs):
            """Query a random record.

            The sobject is given positionally or as ``from:``. The matching
            count is fetched once per distinct FROM/WHERE and cached in the
            plugin's context vars; the random offset is drawn from
            ``[0, min(count, MAX_SALESFORCE_OFFSET))``.

            Raises:
                DataGenError: if no records match.
                ValueError: on malformed arguments.
            """
            context_vars = self.context.context_vars()
            count_cache = context_vars.setdefault("count_query_cache", {})

            # "from" has to be handled separately because its a Python keyword
            query_from = self._parse_from_from_args(args, kwargs)

            # TODO: Test WHERE
            where_clause = f" WHERE {where}" if where else ""
            count_query = f"SELECT count() FROM {query_from}{where_clause}"
            # Reuse a previously computed count instead of issuing the count
            # query on every call.
            mx = count_cache.get(count_query)
            if mx is None:
                count_result = self._sf_connection.query(count_query)
                mx = min(count_result["totalSize"], MAX_SALESFORCE_OFFSET)
                count_cache[count_query] = mx
            if mx < 1:
                raise DataGenError(
                    f"No records found matching {query_from}{where_clause}"
                )
            rand_offset = randrange(0, mx)
            query = f"SELECT {fields} FROM {query_from}{where_clause} LIMIT 1 OFFSET {rand_offset}"
            # todo: use CompositeParallelSalesforce to cache 200 at a time
            return self._sf_connection.query_single_record(query)

        @memorable
        def find_record(self, *args, fields="Id", where=None, **kwargs):
            """Find a particular record (the first one matching ``where``)."""
            query_from = self._parse_from_from_args(args, kwargs)
            where_clause = f" WHERE {where}" if where else ""
            query = f"SELECT {fields} FROM {query_from}{where_clause} LIMIT 1"
            return self._sf_connection.query_single_record(query)

        def _parse_from_from_args(self, args, kwargs):
            """Extract the sobject name from one positional string or ``from:``.

            Raises:
                ValueError: on unknown keyword arguments, on both a positional
                sobject and ``from:`` being supplied, on more than one or a
                non-string positional argument, or when no sobject is given.
            """
            has_from_kwarg = "from" in kwargs
            query_from = kwargs.pop("from", None)
            if kwargs:
                raise ValueError(f"Unknown arguments: {tuple(kwargs.keys())}")
            if args:
                # Previously a positional sobject was silently ignored when
                # ``from:`` was also given; reject the ambiguity instead.
                if has_from_kwarg:
                    raise ValueError(
                        f"Specify the sobject positionally or with 'from:', not both: {args}"
                    )
                if len(args) != 1 or not isinstance(args[0], str):
                    raise ValueError(f"Only one string argument allowed, not: {args}")
                query_from = args[0]
            if not query_from:
                raise ValueError("Must supply 'from:'")
            return query_from