diff --git a/flox/core.py b/flox/core.py
index 1c9599a1..e082c7e2 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:
 
 
 def _is_first_last_reduction(func: T_Agg) -> bool:
-    return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
+    if isinstance(func, Aggregation):  # Aggregation objects are compared by their .name
+        func = func.name
+    return func in ["nanfirst", "nanlast", "first", "last"]
 
 
 def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
@@ -1642,7 +1644,12 @@ def dask_groupby_agg(
     #       This allows us to discover groups at compute time, support argreductions, lower intermediate
     #       memory usage (but method="cohorts" would also work to reduce memory in some cases)
     labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
-    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
+    do_grouped_combine = (
+        _is_arg_reduction(agg)
+        or labels_are_unknown
+        or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
+    )
+    do_simple_combine = not do_grouped_combine
 
     if method == "blockwise":
         #  use the "non dask" code path, but applied blockwise
@@ -1986,8 +1993,13 @@ def _validate_reindex(
     expected_groups,
     any_by_dask: bool,
     is_dask_array: bool,
+    array_dtype: Any,
 ) -> bool | None:
     # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex))  # noqa
+    def first_or_last():  # helper closing over func and array_dtype of the enclosing call
+        return func in ["first", "last"] or (
+            _is_first_last_reduction(func) and array_dtype.kind != "f"  # "f" is NumPy's floating-point kind code
+        )
 
     all_numpy = not is_dask_array and not any_by_dask
     if reindex is True and not all_numpy:
@@ -1997,7 +2009,7 @@ def _validate_reindex(
             raise ValueError(
                 "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
             )
-        if func in ["first", "last"]:
+        if first_or_last():
             raise ValueError("reindex must be None or False when func is 'first' or 'last.")
 
     if reindex is None:
@@ -2008,9 +2020,10 @@ def _validate_reindex(
         if all_numpy:
             return True
 
-        if func in ["first", "last"]:
+        if first_or_last():
             # have to do the grouped_combine since there's no good fill_value
-            reindex = False
+            # Also needed for nanfirst, nanlast with no-NaN dtypes
+            return False
 
         if method == "blockwise":
             # for grouping by dask arrays, we set reindex=True
@@ -2413,7 +2426,13 @@ def groupby_reduce(
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
 
     reindex = _validate_reindex(
-        reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
+        reindex,
+        func,
+        method,
+        expected_groups,
+        any_by_dask,
+        is_duck_dask_array(array),
+        array.dtype,
     )
 
     if not is_duck_array(array):
@@ -2601,7 +2620,7 @@ def groupby_reduce(
 
         # TODO: clean this up
         reindex = _validate_reindex(
-            reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
+            reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
         )
 
         if TYPE_CHECKING:
diff --git a/tests/test_core.py b/tests/test_core.py
index e12e695d..22864a05 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
         )
 
 
+@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
+@pytest.mark.parametrize(
+    "func",
+    [
+        # "first", "last",  # NOTE(review): disabled without explanation; confirm whether these should be covered
+        "nanfirst",
+        "nanlast",
+    ],
+)
+@pytest.mark.parametrize(
+    "chunks",
+    [
+        None,  # None keeps the input a plain NumPy array (no dask wrapping below)
+        pytest.param(1, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(2, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(3, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+    ],
+)
+def test_first_last_useless(func, chunks, group_idx):
+    array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8)  # integer dtype: cannot hold NaN
+    if chunks is not None:
+        array = dask.array.from_array(array, chunks=chunks)
+    actual, _ = groupby_reduce(array, np.array(group_idx), func=func, engine="numpy")
+    expected = np.array([[0, 0], [0, 0]], dtype=np.int8)  # one all-zero column per label (0 and 1)
+    assert_equal(actual, expected)
+
+
 @pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
 @pytest.mark.parametrize("axis", [(0, 1)])
 def test_first_last_disallowed(axis, func):