Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
gchesnokov committed Jul 6, 2021
1 parent 1471231 commit 49a96f6
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 7 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*.egg-info
__pycache__/
.ipynb_checkpoints/
.vscode
.venv
.vscode
File renamed without changes.
5 changes: 2 additions & 3 deletions sdc/ysdc_dataset_api/evaluation/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@

import numpy as np
import torch
from scipy.special import softmax

from sdc.constants import VALID_BASE_METRICS, VALID_AGGREGATORS


def average_displacement_error(ground_truth, predicted):
"""Calculates average displacement error
r"""Calculates average displacement error
ADE(y) = (1/T) \sum_{t=1}^T || s_t - s^*_t ||_2
where T = num_timesteps, y = (s_1, ..., s_T)
Expand Down Expand Up @@ -223,7 +222,7 @@ def average_displacement_error_torch(
ground_truth: torch.Tensor,
predicted: torch.Tensor,
) -> torch.Tensor:
"""Calculates average displacement error
r"""Calculates average displacement error
ADE(y) = (1/T) \sum_{t=1}^T || s_t - s^*_t ||_2
where T = num_timesteps, y = (s_1, ..., s_T)
Expand Down
4 changes: 2 additions & 2 deletions sdc/ysdc_dataset_api/features/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _get_fm_values(self, tracks, transform):
values.append(self._get_acceleration_values(tracks, transform))
if 'yaw' in self._config:
values.append(self._get_yaw_values(tracks))
return np.concatenate(values, axis=1, dtype=np.float64)
return np.concatenate(values, axis=1).astype(np.float64)


class PedestrianTracksRenderer(TrackRendererBase):
Expand All @@ -209,7 +209,7 @@ def _get_fm_values(self, tracks, transform):
values.append(self._get_occupancy_values(tracks))
if 'velocity' in self._config:
values.append(self._get_velocity_values(tracks, transform))
return np.concatenate(values, axis=1, dtype=np.float64)
return np.concatenate(values, axis=1).astype(dtype=np.float64)


class RoadGraphRenderer(FeatureMapRendererBase):
Expand Down
1 change: 0 additions & 1 deletion sdc/ysdc_dataset_api/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
get_trajectories_weights_arrays, min_ade, min_fde,
top1_ade, top1_fde, trajectory_array_to_proto,
weighted_ade, weighted_fde)
from ..evaluation.metrics import _softmax_normalize
from ..proto import ObjectPrediction, Submission, WeightedTrajectory


Expand Down

0 comments on commit 49a96f6

Please sign in to comment.