From 2f20ca6b6774a9983afbb9e542dee3f5c5c97b93 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment?=
 <49512274+ludwigVonKoopa@users.noreply.github.com>
Date: Mon, 21 Jun 2021 22:04:59 +0200
Subject: [PATCH] create particles in class method

---
 src/py_eddy_tracker/observations/groups.py    | 40 +------------------
 .../observations/observation.py               | 32 +++++++++++++++
 2 files changed, 34 insertions(+), 38 deletions(-)

diff --git a/src/py_eddy_tracker/observations/groups.py b/src/py_eddy_tracker/observations/groups.py
index 3d028e12..64a81a36 100644
--- a/src/py_eddy_tracker/observations/groups.py
+++ b/src/py_eddy_tracker/observations/groups.py
@@ -3,9 +3,8 @@
 
 from numba import njit
 from numba import types as nb_types
-from numpy import arange, array, int32, interp, median, where, zeros
+from numpy import arange, int32, interp, median, where, zeros
 
-from ..poly import create_vertice, reduce_size, winding_number_poly
 from .observation import EddiesObservations
 
 logger = logging.getLogger("pet")
@@ -88,41 +87,6 @@ def advect(x, y, c, t0, n_days):
     return t, x, y
 
 
-@njit(cache=True)
-def _create_meshed_particles(lons, lats, step):
-    x_out, y_out, i_out = list(), list(), list()
-    for i, (lon, lat) in enumerate(zip(lons, lats)):
-        lon_min, lon_max = lon.min(), lon.max()
-        lat_min, lat_max = lat.min(), lat.max()
-        lon_min -= lon_min % step
-        lon_max -= lon_max % step - step * 2
-        lat_min -= lat_min % step
-        lat_max -= lat_max % step - step * 2
-
-        for x in arange(lon_min, lon_max, step):
-            for y in arange(lat_min, lat_max, step):
-                if winding_number_poly(x, y, create_vertice(*reduce_size(lon, lat))):
-                    x_out.append(x), y_out.append(y), i_out.append(i)
-    return array(x_out), array(y_out), array(i_out)
-
-
-def create_particles(eddies, step):
-    """create particles only inside speed contour. Avoid creating too large numpy arrays, only to me masked
-
-    :param eddies: network where eddies are
-    :type eddies: network
-    :param step: step for particles
-    :type step: float
-    :return: lon, lat and indices of particles in contour speed
-    :rtype: tuple(np.array)
-    """
-
-    lon = eddies.contour_lon_s
-    lat = eddies.contour_lat_s
-
-    return _create_meshed_particles(lon, lat, step)
-
-
 def particle_candidate(c, eddies, step_mesh, t_start, i_target, pct, **kwargs):
     """Select particles within eddies, advect them, return target observation and associated percentages
 
@@ -141,7 +105,7 @@ def particle_candidate(c, eddies, step_mesh, t_start, i_target, pct, **kwargs):
     # to be able to get global index
     translate_start = where(m_start)[0]
 
-    x, y, i_start = create_particles(e, step_mesh)
+    x, y, i_start = e.create_particles(step_mesh)
 
     # Advection
     t_end, x, y = advect(x, y, c, t_start, **kwargs)
diff --git a/src/py_eddy_tracker/observations/observation.py b/src/py_eddy_tracker/observations/observation.py
index aa73b28d..d969f800 100644
--- a/src/py_eddy_tracker/observations/observation.py
+++ b/src/py_eddy_tracker/observations/observation.py
@@ -69,6 +69,7 @@
     poly_indexs,
     reduce_size,
     vertice_overlap,
+    winding_number_poly
 )
 
 logger = logging.getLogger("pet")
@@ -2274,6 +2275,19 @@ def nb_days(self):
         """
         return self.period[1] - self.period[0] + 1
 
+    def create_particles(self, step, intern=True):
+        """create particles only inside speed contour. Avoid creating too large numpy arrays, only to me masked
+
+        :param step: step for particles
+        :type step: float
+        :param bool intern: If true use speed contour instead of effective contour
+        :return: lon, lat and indices of particles
+        :rtype: tuple(np.array)
+        """
+
+        xname, yname = self.intern(intern)
+        return _create_meshed_particles(self[xname], self[yname], step)
+
 
 @njit(cache=True)
 def grid_count_(grid, i, j):
@@ -2430,6 +2444,24 @@ def grid_stat(x_c, y_c, grid, x, y, result, circular=False, method="mean"):
             result[elt] = v_max
 
 
+@njit(cache=True)
+def _create_meshed_particles(lons, lats, step):
+    x_out, y_out, i_out = list(), list(), list()
+    for i, (lon, lat) in enumerate(zip(lons, lats)):
+        lon_min, lon_max = lon.min(), lon.max()
+        lat_min, lat_max = lat.min(), lat.max()
+        lon_min -= lon_min % step
+        lon_max -= lon_max % step - step * 2
+        lat_min -= lat_min % step
+        lat_max -= lat_max % step - step * 2
+
+        for x in arange(lon_min, lon_max, step):
+            for y in arange(lat_min, lat_max, step):
+                if winding_number_poly(x, y, create_vertice(*reduce_size(lon, lat))):
+                    x_out.append(x), y_out.append(y), i_out.append(i)
+    return array(x_out), array(y_out), array(i_out)
+
+
 class VirtualEddiesObservations(EddiesObservations):
     """Class to work with virtual obs"""