Skip to content

Commit

Permalink
fix(automl): pass params to underlying client (#9794)
Browse files Browse the repository at this point in the history
  • Loading branch information
sirtorry authored and busunkim96 committed Nov 15, 2019
1 parent c5b12d4 commit 972e5b4
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
6 changes: 5 additions & 1 deletion automl/google/cloud/automl_v1beta1/tables/tables_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2596,6 +2596,7 @@ def predict(
model=None,
model_name=None,
model_display_name=None,
params=None,
project=None,
region=None,
**kwargs
Expand Down Expand Up @@ -2642,6 +2643,9 @@ def predict(
The `model` instance you want to predict with . This must be
supplied if `model_display_name` or `model_name` are not
supplied.
params (dict[str, str]):
`feature_importance` can be set as True to enable local
explainability. The default is false.
Returns:
A :class:`~google.cloud.automl_v1beta1.types.PredictResponse`
Expand Down Expand Up @@ -2683,7 +2687,7 @@ def predict(

request = {"row": {"values": values}}

return self.prediction_client.predict(model.name, request, **kwargs)
return self.prediction_client.predict(model.name, request, params, **kwargs)

def batch_predict(
self,
Expand Down
8 changes: 6 additions & 2 deletions automl/tests/unit/gapic/v1beta1/test_tables_client_v1beta1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,7 +1117,7 @@ def test_predict_from_array(self):
client = self.tables_client({"get_model.return_value": model}, {})
client.predict(["1"], model_name="my_model")
client.prediction_client.predict.assert_called_with(
"my_model", {"row": {"values": [{"string_value": "1"}]}}
"my_model", {"row": {"values": [{"string_value": "1"}]}}, None
)

def test_predict_from_dict(self):
Expand All @@ -1134,6 +1134,7 @@ def test_predict_from_dict(self):
client.prediction_client.predict.assert_called_with(
"my_model",
{"row": {"values": [{"string_value": "1"}, {"string_value": "2"}]}},
None,
)

def test_predict_from_dict_missing(self):
Expand All @@ -1148,7 +1149,9 @@ def test_predict_from_dict_missing(self):
client = self.tables_client({"get_model.return_value": model}, {})
client.predict({"a": "1"}, model_name="my_model")
client.prediction_client.predict.assert_called_with(
"my_model", {"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}}
"my_model",
{"row": {"values": [{"string_value": "1"}, {"null_value": 0}]}},
None,
)

def test_predict_all_types(self):
Expand Down Expand Up @@ -1210,6 +1213,7 @@ def test_predict_all_types(self):
]
}
},
None,
)

def test_predict_from_array_missing(self):
Expand Down

0 comments on commit 972e5b4

Please sign in to comment.