# write_test.py
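"""Benchmark pipeline: generates fake names with Faker and writes them to a
Cloud Spanner table through the custom SpannerWriteFn DoFn below."""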
from __future__ import absolute_import

import argparse
import datetime
import uuid

import apache_beam as beam
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import SetupOptions

from google.cloud._helpers import _microseconds_from_datetime
from google.cloud._helpers import UTC
from google.cloud.spanner import Client
from google.cloud.spanner_v1.session import Session

EXISTING_INSTANCES = []
LABEL_KEY = u'python-bigtable-beam'
label_stamp = datetime.datetime.utcnow().replace(tzinfo=UTC)
label_stamp_micros = _microseconds_from_datetime(label_stamp)
LABELS = {LABEL_KEY: str(label_stamp_micros)}


class GenerateRow(beam.DoFn):
    """Expands a grouped id range into (row_id, fake_name) tuples."""

    def __init__(self):
        from apache_beam.metrics import Metrics
        self.generate_row = Metrics.counter(self.__class__, 'generate_row')

    def __setstate__(self, options):
        # Counters are not picklable; recreate the counter when the DoFn is
        # deserialized on a worker.
        from apache_beam.metrics import Metrics
        self.generate_row = Metrics.counter(self.__class__, 'generate_row')

    def process(self, ranges):
        from faker import Faker
        fake = Faker()
        # After GroupByKey the element is (start, [end]); emit one fake name
        # per id in [start, end).
        for row_id in range(int(ranges[0]), int(ranges[1][0])):
            self.generate_row.inc()
            yield (row_id, fake.name())
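
# A minimal sketch of what GenerateRow consumes and emits (the names are
# whatever Faker returns; the ones shown are for illustration only):
#
#   fn = GenerateRow()
#   list(fn.process(('0', ['3'])))
#   # -> [(0, 'Alice A.'), (1, 'Bob B.'), (2, 'Carol C.')]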


class CreateAll(object):
    """Creates the Spanner database and its table if they do not exist."""

    def __init__(self, project_id, instance_id, database_id, table_id):
        self.project_id = project_id
        self.instance_id = instance_id
        self.database_id = database_id
        self.table_id = table_id
        self.client = Client(project=self.project_id)

    def create_table(self):
        instance = self.client.instance(self.instance_id)
        database = instance.database(self.database_id)
        if not database.exists():
            database = instance.database(self.database_id, ddl_statements=[
                """CREATE TABLE """ + self.table_id + """ (
                    keyId INT64 NOT NULL,
                    Name STRING(1024),
                ) PRIMARY KEY (keyId)""",
            ])
            operation = database.create()
            operation.result()
            print('Database and Table Created')
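
# A minimal usage sketch (assumes the Spanner instance already exists and
# credentials are configured; the ids below are placeholders):
#
#   CreateAll('my-project', 'my-instance', 'my-database',
#             'pythontable').create_table()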


class SpannerWriteFn(beam.DoFn):
    """Writes rows to Spanner, committing one transaction per batch."""

    def __init__(self, project_id, instance_id,
                 database_id, table_id, columns,
                 max_num_mutations=10000,
                 batch_size_bytes=0):
        from google.cloud.spanner import Client
        from apache_beam.metrics import Metrics
        super(SpannerWriteFn, self).__init__()
        self.beam_options = {'project_id': project_id,
                             'instance_id': instance_id,
                             'database_id': database_id,
                             'table_id': table_id,
                             'columns': columns,
                             'max_num_mutations': max_num_mutations,
                             'batch_size_bytes': batch_size_bytes}
        client = Client(project=self.beam_options['project_id'])
        instance = client.instance(self.beam_options['instance_id'])
        database = instance.database(self.beam_options['database_id'])
        # Create a session; it is rebuilt in __setstate__ on each worker.
        self.session = Session(database)
        self.session.create()
        self.written_row = Metrics.counter(self.__class__, 'Written Row')

    def __getstate__(self):
        # Pickle only the plain-dict options; the client, session and
        # counter are recreated in __setstate__.
        return self.beam_options

    def __setstate__(self, options):
        from google.cloud.spanner import Client
        from google.cloud.spanner_v1.session import Session
        from apache_beam.metrics import Metrics
        self.beam_options = options
        client = Client(project=self.beam_options['project_id'])
        instance = client.instance(self.beam_options['instance_id'])
        database = instance.database(self.beam_options['database_id'])
        # Create a session
        self.session = Session(database)
        self.session.create()
        self.written_row = Metrics.counter(self.__class__, 'Written Row')

    def start_bundle(self):
        self.transaction = self.session.transaction()
        self.transaction.begin()
        self.values = []

    def _insert(self):
        # Flush the buffered rows in the current transaction, if any.
        if len(self.values) > 0:
            self.transaction.insert(
                table=self.beam_options['table_id'],
                columns=self.beam_options['columns'],
                values=self.values)
            self.transaction.commit()
            self.written_row.inc(len(self.values))
            self.values = []

    def process(self, element):
        # Commit and start a fresh transaction once the batch is full.
        if len(self.values) >= self.beam_options['max_num_mutations']:
            self._insert()
            self.transaction = self.session.transaction()
            self.transaction.begin()
        self.values.append(element)

    def finish_bundle(self):
        self._insert()
        self.transaction = None
        self.values = []

    def display_data(self):
        return {
            'projectId': DisplayDataItem(self.beam_options['project_id'],
                                         label='Spanner Project Id'),
            'instanceId': DisplayDataItem(self.beam_options['instance_id'],
                                          label='Spanner Instance Id'),
            'databaseId': DisplayDataItem(self.beam_options['database_id'],
                                          label='Spanner Database Id'),
            'tableId': DisplayDataItem(self.beam_options['table_id'],
                                       label='Spanner Table Id'),
        }
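
# A minimal sketch of exercising SpannerWriteFn on its own with the
# DirectRunner (project/instance/database ids are placeholders; each
# element must match the `columns` tuple):
#
#   with beam.Pipeline() as p:
#       _ = (p
#            | beam.Create([(1, u'Alice'), (2, u'Bob')])
#            | beam.ParDo(SpannerWriteFn('my-project', 'my-instance',
#                                        'my-database', 'pythontable',
#                                        columns=('keyId', 'Name'))))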


def run(argv=None):
    # `argv=[]` in the original is a mutable default argument; use None and
    # build a fresh list per call instead.
    if argv is None:
        argv = []
    project_id = 'grass-clump-479'
    instance_id = 'python-write'
    guid = str(uuid.uuid4())[:8]
    database_id = 'pythontest' + guid
    # A second guid so the job name does not share the database suffix.
    guid = str(uuid.uuid4())[:8]
    table_id = 'pythontable'
    jobname = 'spanner-write-' + guid
    argv.extend([
        '--experiments=beam_fn_api',
        '--project={}'.format(project_id),
        '--instance={}'.format(instance_id),
        '--job_name={}'.format(jobname),
        '--requirements_file=requirements.txt',
        '--disk_size_gb=50',
        '--region=us-central1',
        '--runner=dataflow',
        '--autoscaling_algorithm=NONE',
        '--num_workers=5',
        '--staging_location=gs://juantest/stage',
        '--temp_location=gs://juantest/temp',
    ])
    # argv belongs in parse_known_args, not the ArgumentParser constructor
    # (whose first positional argument is the program name).
    parser = argparse.ArgumentParser()
    known_args, pipeline_args = parser.parse_known_args(argv)

    create_table = CreateAll(project_id, instance_id, database_id, table_id)
    print('ProjectID:', project_id)
    print('InstanceID:', instance_id)
    print('DatabaseID:', database_id)
    print('TableID:', table_id)
    print('JobID:', jobname)
    create_table.create_table()

    row_count = 10000000
    row_limit = 1000
    # Split the id space into at most row_limit contiguous ranges (integer
    # division; `/` and `xrange` in the original were Python 2).
    row_step = row_count if row_count <= row_limit else row_count // row_limit

    pipeline_options = PipelineOptions(argv)
    pipeline_options.view_as(SetupOptions).save_main_session = True
    p = beam.Pipeline(options=pipeline_options)
    _ = (p
         | 'Ranges' >> beam.Create([(str(i), str(i + row_step))
                                    for i in range(0, row_count, row_step)])
         | 'Group' >> beam.GroupByKey()
         | 'Generate' >> beam.ParDo(GenerateRow())
         | 'Write' >> beam.ParDo(SpannerWriteFn(project_id,
                                                instance_id,
                                                database_id,
                                                table_id,
                                                columns=('keyId', 'Name'))))
    # On the Dataflow runner, run() returns once the job is submitted; chain
    # .wait_until_finish() if the caller should block until it completes.
    p.run()


if __name__ == '__main__':
    run()