Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
resolving tabular task & test serve (#1551)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda authored May 11, 2023
1 parent 5c42e23 commit d3421c0
Show file tree
Hide file tree
Showing 24 changed files with 114 additions and 111 deletions.
8 changes: 5 additions & 3 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,11 +108,13 @@ jobs:
key: flash-datasets-${{ hashFiles('tests/examples/test_scripts.py') }}
restore-keys: flash-datasets-

# ToDO
#- name: DocTests
# run: |
# pytest src/ -vv # --reruns 3 --reruns-delay 2

- name: Tests
env:
FIFTYONE_DO_NOT_TRACK: true
run: |
# FixMe: include doctests for src/
coverage run --source flash -m pytest \
tests/core \
tests/deprecated_api \
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ torchmetrics >0.7.0, <0.11.0 # strict
pytorch-lightning >1.6.0, <1.9.0 # strict
pyDeprecate >0.1.0
pandas >1.1.0, <=1.5.2
jsonargparse[signatures] >=3.17.0, <=4.9.0
jsonargparse[signatures] >4.0.0, <=4.9.0
click >=7.1.2, <=8.1.3
protobuf <=3.20.1
fsspec[http] >=2022.5.0,<=2022.7.1
Expand Down
11 changes: 5 additions & 6 deletions requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# NOTE: all pins for latest are for CI consistency unless it is `strict`, then it is also forced in setup

coverage[toml]
codecov >2.1
pytest >7.2, <7.4
pytest-doctestplus >0.9.0
pytest-rerunfailures >10.0
pytest-forked
pytest-mock
pytest >6.2, <7.0
pytest-doctestplus >0.12.0
pytest-rerunfailures >11.0.0
pytest-forked ==1.6.0
pytest-mock ==3.10.0

scikit-learn
torch_optimizer
4 changes: 2 additions & 2 deletions tests/audio/classification/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import patch

import pytest

Expand All @@ -23,7 +23,7 @@
@pytest.mark.skipif(not _TOPIC_AUDIO_AVAILABLE, reason="audio libraries aren't installed.")
def test_cli():
cli_args = ["flash", "audio_classification", "--trainer.fast_dev_run", "True"]
with mock.patch("sys.argv", cli_args):
with patch("sys.argv", cli_args):
try:
main()
except SystemExit:
Expand Down
4 changes: 2 additions & 2 deletions tests/audio/speech_recognition/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.import os
from typing import Any
from unittest import mock
from unittest.mock import patch

import numpy as np
import pytest
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_modules_to_freeze():


@pytest.mark.skipif(not _TOPIC_SERVE_AVAILABLE, reason="serve libraries aren't installed.")
@mock.patch("flash._IS_TESTING", True)
@patch("flash._IS_TESTING", True)
def test_serve():
model = SpeechRecognition(backbone=TEST_BACKBONE)
model.eval()
Expand Down
54 changes: 27 additions & 27 deletions tests/core/data/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from unittest.mock import ANY, MagicMock, call, patch

import pytest
import torch
Expand All @@ -26,12 +26,12 @@


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@mock.patch("pickle.dumps") # need to mock pickle or we get pickle error
@mock.patch("torch.save") # need to mock torch.save, or we get pickle error
@patch("pickle.dumps") # need to mock pickle or we get pickle error
@patch("torch.save") # need to mock torch.save, or we get pickle error
def test_flash_callback(_, __, tmpdir):
"""Test the callback hook system for fit."""

callback_mock = mock.MagicMock()
callback_mock = MagicMock()

inputs = [(torch.rand(1), torch.rand(1))]
transform = InputTransform()
Expand All @@ -48,10 +48,10 @@ def test_flash_callback(_, __, tmpdir):
_ = next(iter(dm.train_dataloader()))

assert callback_mock.method_calls == [
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
]

class CustomModel(Task):
Expand Down Expand Up @@ -89,23 +89,23 @@ def test_step(self, batch, batch_idx):
trainer.fit(CustomModel(), datamodule=dm)

assert callback_mock.method_calls == [
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
mock.call.on_load_sample(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_collate(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.TRAINING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.TRAINING),
mock.call.on_load_sample(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_sample_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_collate(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform(mock.ANY, RunningStage.VALIDATING),
mock.call.on_per_batch_transform_on_device(mock.ANY, RunningStage.VALIDATING),
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
call.on_load_sample(ANY, RunningStage.VALIDATING),
call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
call.on_collate(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
call.on_load_sample(ANY, RunningStage.TRAINING),
call.on_per_sample_transform(ANY, RunningStage.TRAINING),
call.on_collate(ANY, RunningStage.TRAINING),
call.on_per_batch_transform(ANY, RunningStage.TRAINING),
call.on_per_batch_transform_on_device(ANY, RunningStage.TRAINING),
call.on_load_sample(ANY, RunningStage.VALIDATING),
call.on_per_sample_transform(ANY, RunningStage.VALIDATING),
call.on_collate(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform(ANY, RunningStage.VALIDATING),
call.on_per_batch_transform_on_device(ANY, RunningStage.VALIDATING),
]
6 changes: 3 additions & 3 deletions tests/core/data/test_data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Dict
from unittest import mock
from unittest.mock import MagicMock, NonCallableMock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -426,8 +426,8 @@ def validation_step(self, batch, batch_idx):


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@pytest.mark.parametrize("sampler, callable", [(mock.MagicMock(), True), (mock.NonCallableMock(), False)])
@mock.patch("flash.core.data.data_module.DataLoader")
@pytest.mark.parametrize("sampler, callable", [(MagicMock(), True), (NonCallableMock(), False)])
@patch("flash.core.data.data_module.DataLoader")
def test_dataloaders_with_sampler(mock_dataloader, sampler, callable):
train_input = TestInput(RunningStage.TRAINING, [1])
datamodule = DataModule(
Expand Down
7 changes: 3 additions & 4 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from itertools import chain
from numbers import Number
from typing import Any, Tuple
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import ANY, MagicMock

import pytest
import pytorch_lightning as pl
Expand Down Expand Up @@ -387,12 +386,12 @@ def test_optimizer_learning_rate():
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10), nn.LogSoftmax())

ClassificationTask(model, optimizer="test").configure_optimizers()
mock_optimizer.assert_called_once_with(mock.ANY)
mock_optimizer.assert_called_once_with(ANY)

mock_optimizer.reset_mock()

ClassificationTask(model, optimizer="test", learning_rate=10).configure_optimizers()
mock_optimizer.assert_called_once_with(mock.ANY, lr=10)
mock_optimizer.assert_called_once_with(ANY, lr=10)

mock_optimizer.reset_mock()

Expand Down
Loading

0 comments on commit d3421c0

Please sign in to comment.