Skip to content

Commit

Permalink
#2 fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
MrGolden1 committed Apr 23, 2022
1 parent d53200d commit 7a192be
Showing 1 changed file with 46 additions and 28 deletions.
74 changes: 46 additions & 28 deletions sort/src/Py_SORT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ static PyObject *Py_SORT_run(Py_SORT *self, PyObject *args)
return NULL;
}

// check if format is valid
if (format < 0 || format > 2)
{
PyErr_SetString(PyExc_ValueError, "Format must be 0, 1 or 2");
return NULL;
}

// get dtype
PyArrayObject *array = (PyArrayObject *)py_array;
int dtype = array->descr->type_num;
Expand Down Expand Up @@ -178,38 +185,43 @@ static PyObject *Py_SORT_run(Py_SORT *self, PyObject *args)
rects.reserve(n);

int xmin, ymin, width, height;
for (int i = 0; i < n; i++)
switch (format)
{
npy_intp *indices = (npy_intp *)PyArray_GETPTR2(array, i, 0);
switch (format)
case 0:
for (int i = 0; i < n; i++)
{
case 0: // [xmin, ymin, w, h]
xmin = *indices++;
ymin = *indices++;
width = *indices++;
height = *indices++;
break;
case 1: // [xcenter, ycenter, w, h]
xmin = *indices++;
ymin = *indices++;
width = *indices++;
height = *indices++;
xmin -= width / 2;
ymin -= height / 2;
break;
case 2: // [xmin, ymin, xmax, ymax]
xmin = *indices++;
ymin = *indices++;
width = *indices++ - xmin;
height = *indices++ - ymin;
break;
default:
PyErr_SetString(PyExc_TypeError, "Invalid format");
return NULL;
xmin = *(int *)PyArray_GETPTR2(py_array, i, 0);
ymin = *(int *)PyArray_GETPTR2(py_array, i, 1);
width = *(int *)PyArray_GETPTR2(py_array, i, 2);
height = *(int *)PyArray_GETPTR2(py_array, i, 3);
rects.push_back(cv::Rect(xmin, ymin, width, height));
}
break;
case 1:
for (int i = 0; i < n; i++)
{
width = *(int *)PyArray_GETPTR2(py_array, i, 2);
height = *(int *)PyArray_GETPTR2(py_array, i, 3);
xmin = *(int *)PyArray_GETPTR2(py_array, i, 0) - width / 2;
ymin = *(int *)PyArray_GETPTR2(py_array, i, 1) - height / 2;
rects.push_back(cv::Rect(xmin, ymin, width, height));
}
break;
case 2:
for (int i = 0; i < n; i++)
{
xmin = *(int *)PyArray_GETPTR2(py_array, i, 0);
ymin = *(int *)PyArray_GETPTR2(py_array, i, 1);
width = *(int *)PyArray_GETPTR2(py_array, i, 2) - xmin;
height = *(int *)PyArray_GETPTR2(py_array, i, 3) - ymin;
rects.push_back(cv::Rect(xmin, ymin, width, height));
}
rects.push_back(cv::Rect(xmin, ymin, width, height));
break;
default:
PyErr_SetString(PyExc_TypeError, "Format must be 0, 1 or 2");
return NULL;
}
// // print rects

// for (int i = 0; i < n; i++)
// {
// printf("%d: [%d, %d, %d, %d]\n", i, rects[i].x, rects[i].y, rects[i].width, rects[i].height);
Expand Down Expand Up @@ -239,6 +251,12 @@ static PyObject *Py_SORT_get_tracks(Py_SORT *self, PyObject *args)
int format = 0; // default format
if (!PyArg_ParseTuple(args, "|i", &format))
return NULL;
// check if format is valid
if (format < 0 || format > 2)
{
PyErr_SetString(PyExc_ValueError, "Format must be 0, 1 or 2");
return NULL;
}

// convert tracks to numpy array
int hited = 0;
Expand Down

0 comments on commit 7a192be

Please sign in to comment.