diff --git a/MANIFEST.in b/MANIFEST.in index 60c4b158..ce4e1f0e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ -include deepforest/data/deepforest_config.yml +include deepforest_config.yml include deepforest/data/testfile_deepforest.csv include deepforest/data/testfile_multi.csv include deepforest/data/classes.csv diff --git a/deepforest/data/deepforest_config.yml b/deepforest/data/deepforest_config.yml deleted file mode 100644 index 0a541578..00000000 --- a/deepforest/data/deepforest_config.yml +++ /dev/null @@ -1,60 +0,0 @@ -# Config file for DeepForest pytorch module - -# Cpu workers for data loaders -# Dataloaders -workers: 1 -devices: auto -accelerator: auto -batch_size: 1 - -# Model Architecture -architecture: 'retinanet' -num_classes: 1 -nms_thresh: 0.05 - -# Architecture specific params -retinanet: - # Non-max supression of overlapping predictions - score_thresh: 0.1 - -train: - csv_file: - root_dir: - - # Optimizer initial learning rate - lr: 0.001 - scheduler: - type: - params: - # Common parameters - T_max: 10 - eta_min: 0.00001 - lr_lambda: "lambda epoch: 0.95 ** epoch" # For lambdaLR and multiplicativeLR - step_size: 30 # For stepLR - gamma: 0.1 # For stepLR, multistepLR, and exponentialLR - milestones: [50, 100] # For multistepLR - - # ReduceLROnPlateau parameters (used if type is not explicitly mentioned) - mode: "min" - factor: 0.1 - patience: 10 - threshold: 0.0001 - threshold_mode: "rel" - cooldown: 0 - min_lr: 0 - eps: 1e-08 - - # Print loss every n epochs - epochs: 1 - # Useful debugging flag in pytorch lightning, set to True to get a single batch of training to test settings. - fast_dev_run: False - # pin images to GPU memory for fast training. This depends on GPU size and number of images. - preload_images: False - -validation: - # callback args - csv_file: - root_dir: - # Intersection over union evaluation - iou_threshold: 0.4 - val_accuracy_interval: 20 diff --git a/tests/test_data.py b/tests/test_data.py index b978a96c..dec50e40 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -13,9 +13,3 @@ def test_get_data(): assert os.path.exists(deepforest.get_data("OSBS_029.tif")) assert os.path.exists(deepforest.get_data("SOAP_061.png")) assert os.path.exists(deepforest.get_data("classes.csv")) - -# Assert that the included config file matches the front of the repo. -def test_matching_config(ROOT): - config = read_config("{}/deepforest_config.yml".format(os.path.dirname(ROOT))) - config_from_data_dir = read_config("{}/data/deepforest_config.yml".format(ROOT)) - assert config == config_from_data_dir \ No newline at end of file