
Commit

Merge branch 'master' into model-variable-agg-options
Robert McMahan committed Mar 5, 2024
2 parents d9c2467 + 4acbdd2 commit 2bdd284
Showing 11 changed files with 98 additions and 33 deletions.
6 changes: 3 additions & 3 deletions backend/controller/ml_model/compiler.py
@@ -275,11 +275,11 @@ def _compile_template(self,
'is_calculating_conversion_values': step == Step.CALCULATING_CONVERSION_VALUES
},
'name': self.ml_model.name,
'model_project': self.project_id,
'model_dataset': self.ml_model.bigquery_dataset.name,
'ga4_measurement_id': self.ga4_measurement_id,
'ga4_api_secret': self.ga4_api_secret,
'dataset_location': self.ml_model.bigquery_dataset.location,
'location': self.ml_model.bigquery_dataset.location,
'project': self.project_id,
'dataset': self.ml_model.bigquery_dataset.name,
'type': {
'name': self.ml_model.type,
'is_regression': self.ml_model.type in ModelTypes.REGRESSION,
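As a minimal sketch (not CRMint's actual compiler code), and assuming standard Jinja2 rendering of the templates below, the renamed keys flow from the compiler's parameter dict into the SQL and pipeline templates roughly like this; the project, dataset, and location values are hypothetical stand-ins for self.project_id and the model's BigQuery dataset:

from jinja2 import Template

# Hypothetical values standing in for the compiler's template parameters.
params = {
    'project': 'my-gcp-project',
    'dataset': 'crmint_ml',
    'location': 'US',
}

sql = Template(
    'CREATE OR REPLACE MODEL `{{project}}.{{dataset}}.predictive_model`'
).render(**params)
# sql == 'CREATE OR REPLACE MODEL `my-gcp-project.crmint_ml.predictive_model`'

Renaming model_project/model_dataset/dataset_location to project/dataset/location keeps the parameter names consistent across the SQL templates and the pipeline JSON templates touched in this commit.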
6 changes: 6 additions & 0 deletions backend/controller/ml_model/shared.py
@@ -41,6 +41,12 @@ class Timespan:
_consider_datetime: bool

def __init__(self, timespans: list[dict[str, Any]], consider_datetime: bool = False) -> None:
"""Uses timespans provided to create an accurate start/end for each step in the modeling process.
Args:
timespans: The set of timespans including training, predictive, and exclusion periods.
consider_datetime: Whether or not it should consider the timespan as a datetime or just a date.
"""
self._consider_datetime = consider_datetime
for timespan in timespans:
setattr(self, '_' + timespan['name'], int(timespan['value']))
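A minimal usage sketch, assuming timespan names such as 'training' and 'predictive' (as described in the docstring); per the constructor, each value is stored as a private integer attribute named after the timespan:

# Hypothetical 90-day training and 30-day predictive windows.
timespans = [
    {'name': 'training', 'value': '90'},
    {'name': 'predictive', 'value': '30'},
]
ts = Timespan(timespans, consider_datetime=False)
# ts._training == 90, ts._predictive == 30, ts._consider_datetime is False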
10 changes: 5 additions & 5 deletions backend/controller/ml_model/templates/model_bqml.sql
@@ -1,5 +1,5 @@
{% if step.is_training %}
CREATE OR REPLACE MODEL `{{model_project}}.{{model_dataset}}.predictive_model`
CREATE OR REPLACE MODEL `{{project}}.{{dataset}}.predictive_model`
OPTIONS (
MODEL_TYPE = "{{type.name}}",
-- inject the selected hyper parameters
@@ -15,7 +15,7 @@ OPTIONS (
INPUT_LABEL_COLS = ["label"]
) AS
{% elif step.is_predicting %}
CREATE OR REPLACE TABLE `{{model_project}}.{{model_dataset}}.predictions` AS (
CREATE OR REPLACE TABLE `{{project}}.{{dataset}}.predictions` AS (
SELECT
unique_id,
{% if google_analytics.in_source %}
@@ -26,9 +26,9 @@ SELECT
plp.prob AS probability,
{% endif %}
predicted_label
FROM ML.PREDICT(MODEL `{{model_project}}.{{model_dataset}}.predictive_model`, (
FROM ML.PREDICT(MODEL `{{project}}.{{dataset}}.predictive_model`, (
{% elif step.is_calculating_conversion_values %}
CREATE OR REPLACE TABLE `{{model_project}}.{{model_dataset}}.conversion_values` AS (
CREATE OR REPLACE TABLE `{{project}}.{{dataset}}.conversion_values` AS (
SELECT
normalized_probability,
(SUM(label) / COUNT(1)) * {{output.parameters.average_conversion_value}} AS value,
@@ -47,7 +47,7 @@ SELECT
p.label,
plp.prob AS probability,
NTILE({{conversion_rate_segments}}) OVER (ORDER BY plp.prob ASC) AS normalized_probability
FROM ML.PREDICT(MODEL `{{model_project}}.{{model_dataset}}.predictive_model`, (
FROM ML.PREDICT(MODEL `{{project}}.{{dataset}}.predictive_model`, (
{% endif %}
WITH
{% if first_party.in_source %}
8 changes: 4 additions & 4 deletions backend/controller/ml_model/templates/output.sql
@@ -6,7 +6,7 @@ SET _LATEST_TABLE_SUFFIX = (
WHERE REGEXP_CONTAINS(table_id, r"^(events_[0-9]{8})$")
);
{% endif %}
CREATE OR REPLACE TABLE `{{model_project}}.{{model_dataset}}.output` AS (
CREATE OR REPLACE TABLE `{{project}}.{{dataset}}.output` AS (
WITH
{% if google_analytics.in_source %}
events AS (
@@ -44,8 +44,8 @@ CREATE OR REPLACE TABLE `{{model_project}}.{{model_dataset}}.output` AS (
ROUND(MAX(cv.value), 4) AS value,
MAX(cv.normalized_probability) AS normalized_score,
MAX(p.probability) * 100 AS score
FROM `{{model_project}}.{{model_dataset}}.predictions` p
LEFT OUTER JOIN `{{model_project}}.{{model_dataset}}.conversion_values` cv
FROM `{{project}}.{{dataset}}.predictions` p
LEFT OUTER JOIN `{{project}}.{{dataset}}.conversion_values` cv
ON p.probability BETWEEN cv.probability_range_start AND cv.probability_range_end
{% if google_analytics.in_source %}
GROUP BY 1,2,3
@@ -63,7 +63,7 @@ CREATE OR REPLACE TABLE `{{model_project}}.{{model_dataset}}.output` AS (
{% endif %}
IF(predicted_label > 0, ROUND(predicted_label, 4), 0) AS value,
IF(predicted_label > 0, ROUND(predicted_label, 4), 0) AS revenue
FROM `{{model_project}}.{{model_dataset}}.predictions`
FROM `{{project}}.{{dataset}}.predictions`
),
{% endif %}
{% if output.destination.is_google_analytics_mp_event %}
16 changes: 8 additions & 8 deletions backend/controller/ml_model/templates/predictive_pipeline.json
@@ -18,7 +18,7 @@
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
}
]
},
@@ -39,7 +39,7 @@
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
}
]
},
@@ -56,17 +56,17 @@
{
"name": "bq_project_id",
"type": "{{ParamType.STRING}}",
"value": "{{model_project}}"
"value": "{{project}}"
},
{
"name": "bq_dataset_id",
"type": "{{ParamType.STRING}}",
"value": "{{model_dataset}}"
"value": "{{dataset}}"
},
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
},
{
"name": "bq_table_id",
@@ -105,17 +105,17 @@
{
"name": "bq_project_id",
"type": "{{ParamType.STRING}}",
"value": "{{model_project}}"
"value": "{{project}}"
},
{
"name": "bq_dataset_id",
"type": "{{ParamType.STRING}}",
"value": "{{model_dataset}}"
"value": "{{dataset}}"
},
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
},
{
"name": "bq_table_id",
4 changes: 2 additions & 2 deletions backend/controller/ml_model/templates/training_pipeline.json
@@ -17,7 +17,7 @@
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
}
]
}{% if type.is_classification %},
@@ -38,7 +38,7 @@
{
"name": "bq_dataset_location",
"type": "{{ParamType.STRING}}",
"value": "{{dataset_location}}"
"value": "{{location}}"
}
]
}
17 changes: 13 additions & 4 deletions backend/jobs/workers/bigquery/bq_worker.py
@@ -1,4 +1,4 @@
# Copyright 2020 Google Inc. All rights reserved.
# Copyright 2024 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,11 +14,11 @@

"""CRMint's abstract worker dealing with BigQuery."""


import os
import time

from google.api_core.client_info import ClientInfo
from google.cloud import bigquery

from jobs.workers import worker


@@ -42,7 +42,16 @@ class BQWorker(worker.Worker):
]

def _get_client(self):
return bigquery.Client(client_options={'scopes': self._SCOPES})
client_info = None
if 'REPORT_USAGE_ID' in os.environ:
client_id = os.getenv('REPORT_USAGE_ID')
opt_out = not bool(client_id)
if not opt_out:
client_info = ClientInfo(user_agent='cloud-solutions/crmint-usage-v3')
return bigquery.Client(
client_options={'scopes': self._SCOPES},
client_info=client_info,
)

def _get_prefix(self):
return f'{self._pipeline_id}_{self._job_id}_{self.__class__.__name__}'
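A short sketch of how the REPORT_USAGE_ID environment variable drives usage attribution in _get_client; the constructor arguments mirror the unit test below, the usage id is hypothetical, and actually creating the client requires valid Google Cloud credentials:

import os

from jobs.workers.bigquery import bq_worker  # run from the backend package

os.environ['REPORT_USAGE_ID'] = 'some-usage-id'  # an empty string opts out
worker_inst = bq_worker.BQWorker({}, 0, 0)
client = worker_inst._get_client()
# With an id set, the client is built with
# client_info.user_agent == 'cloud-solutions/crmint-usage-v3'.
# With REPORT_USAGE_ID unset or empty, client_info stays None and no
# usage attribution is attached.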
46 changes: 46 additions & 0 deletions backend/tests/jobs/unit/workers/bq_worker_tests.py
@@ -1,11 +1,14 @@
"""Tests for bq_worker."""

from unittest import mock
import os

from absl.testing import absltest
from absl.testing import parameterized

from google.auth import credentials
from google.cloud import bigquery
from google.api_core.client_info import ClientInfo

from jobs.workers import worker
from jobs.workers.bigquery import bq_worker
@@ -82,6 +85,49 @@ def test_generates_proper_bq_table_name_from_params(self):
self.assertEqual('a_project.a_dataset_id.a_table_id',
worker._generate_qualified_bq_table_name())

class BQWorkerGetClientTest(parameterized.TestCase):

@parameterized.parameters(
{
'report_usage_id_present': True,
'client_info_user_agent': 'cloud-solutions/crmint-usage-v3',
},
{
'report_usage_id_present': False,
'client_info_user_agent': None
},
)
def test_get_client_handles_report_usage_id(
self, report_usage_id_present, client_info_user_agent):
report_usage_id = 'some-usage-id' if report_usage_id_present else ''
with (
mock.patch.dict(
os.environ,
{'REPORT_USAGE_ID': report_usage_id}
if report_usage_id_present
else {},
),
mock.patch('os.getenv', return_value=report_usage_id) as getenv_mock,
mock.patch('google.cloud.bigquery.Client') as client_mock,
):
worker_inst = bq_worker.BQWorker({}, 0, 0)
worker_inst._get_client()

if report_usage_id_present:
getenv_mock.assert_called_with('REPORT_USAGE_ID')
else:
getenv_mock.assert_not_called()

client_mock.assert_called_once()
_, kwargs = client_mock.call_args
if report_usage_id_present:
self.assertIsInstance(kwargs['client_info'], ClientInfo)
self.assertEqual(
kwargs['client_info'].user_agent, client_info_user_agent
)
else:
self.assertIsNone(kwargs.get('client_info'))


if __name__ == '__main__':
absltest.main()
2 changes: 1 addition & 1 deletion cli/appcli.py
@@ -54,7 +54,7 @@ def _ask_permission(self):
fg='yellow')
msg += click.style(pkg_name, fg='red', bold=True)
msg += click.style(
' better! \nMay we anonymously report usage statistics to improve the'
' better! \nMay we anonymously report usage statistics to improve the '
'tool over time? \nMore info: https://github.com/google/crmint & '
'https://google.github.io/crmint',
fg='yellow')
12 changes: 6 additions & 6 deletions frontend/package-lock.json

Some generated files are not rendered by default.

4 changes: 4 additions & 0 deletions terraform/services.tf
@@ -209,6 +209,10 @@ resource "google_cloud_run_service" "jobs_run" {
name = "PUBSUB_VERIFICATION_TOKEN"
value = random_id.pubsub_verification_token.b64_url
}
env {
name = "REPORT_USAGE_ID"
value = var.report_usage_id
}
}

timeout_seconds = 900 # 15min
