diff --git a/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx b/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx
index cdf6e98ec0c64..6969b9da4cf94 100644
--- a/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx
+++ b/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx
@@ -1007,7 +1007,7 @@ describe('', () => {
describe('range selection', () => {
it('keyboard arrow', () => {
const { getByTestId, queryAllByRole, getByText } = render(
-
+
@@ -1085,7 +1085,7 @@ describe('', () => {
it('keyboard arrow merge', () => {
const { getByTestId, getByText, queryAllByRole } = render(
-
+
@@ -1207,7 +1207,7 @@ describe('', () => {
expect(getByTestId('eight')).to.have.attribute('aria-selected', 'true');
expect(getByTestId('nine')).to.have.attribute('aria-selected', 'true');
- fireEvent.keyDown(getByTestId('nine'), {
+ fireEvent.keyDown(getByTestId('five'), {
key: 'Home',
shiftKey: true,
ctrlKey: true,
diff --git a/packages/x-tree-view/src/TreeItem/useTreeItemState.ts b/packages/x-tree-view/src/TreeItem/useTreeItemState.ts
index 0c2442d3ec15c..d43fcd4f2c4aa 100644
--- a/packages/x-tree-view/src/TreeItem/useTreeItemState.ts
+++ b/packages/x-tree-view/src/TreeItem/useTreeItemState.ts
@@ -39,7 +39,7 @@ export function useTreeItemState(itemId: string) {
if (multiple) {
if (event.shiftKey) {
- instance.selectRange(event, { end: itemId });
+ instance.expandSelectionRange(event, itemId);
} else {
instance.selectItem(event, itemId, true);
}
diff --git a/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx b/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx
index 87ed49fff4dc0..c8c67baf0f9e1 100644
--- a/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx
+++ b/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx
@@ -63,7 +63,7 @@ export const useTreeItem2Utils = ({
if (multiple) {
if (event.shiftKey) {
- instance.selectRange(event, { end: itemId });
+ instance.expandSelectionRange(event, itemId);
} else {
instance.selectItem(event, itemId, true);
}
diff --git a/packages/x-tree-view/src/internals/models/treeView.ts b/packages/x-tree-view/src/internals/models/treeView.ts
index fa6cca1edb602..552dd29bf7e2c 100644
--- a/packages/x-tree-view/src/internals/models/treeView.ts
+++ b/packages/x-tree-view/src/internals/models/treeView.ts
@@ -13,13 +13,6 @@ export interface TreeViewItemMeta {
label?: string;
}
-export interface TreeViewItemRange {
- start?: string | null;
- end?: string | null;
- next?: string | null;
- current?: string;
-}
-
export interface TreeViewModel {
name: string;
value: TValue;
diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts
index bcf20f45487f5..8ed5436915f4f 100644
--- a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts
+++ b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts
@@ -127,7 +127,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
case key === ' ' && canToggleItemSelection(itemId): {
event.preventDefault();
if (params.multiSelect && event.shiftKey) {
- instance.selectRange(event, { end: itemId });
+ instance.expandSelectionRange(event, itemId);
} else if (params.multiSelect) {
instance.selectItem(event, itemId, true);
} else {
@@ -165,14 +165,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
// Multi select behavior when pressing Shift + ArrowDown
// Toggles the selection state of the next item
if (params.multiSelect && event.shiftKey && canToggleItemSelection(nextItem)) {
- instance.selectRange(
- event,
- {
- end: nextItem,
- current: itemId,
- },
- true,
- );
+ instance.selectItemFromArrowNavigation(event, itemId, nextItem);
}
}
@@ -189,14 +182,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
// Multi select behavior when pressing Shift + ArrowUp
// Toggles the selection state of the previous item
if (params.multiSelect && event.shiftKey && canToggleItemSelection(previousItem)) {
- instance.selectRange(
- event,
- {
- end: previousItem,
- current: itemId,
- },
- true,
- );
+ instance.selectItemFromArrowNavigation(event, itemId, previousItem);
}
}
@@ -239,12 +225,12 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
// Focuses the first item in the tree
case key === 'Home': {
- instance.focusItem(event, getFirstNavigableItem(instance));
-
// Multi select behavior when pressing Ctrl + Shift + Home
// Selects the focused item and all items up to the first item.
if (canToggleItemSelection(itemId) && params.multiSelect && ctrlPressed && event.shiftKey) {
- instance.rangeSelectToFirst(event, itemId);
+ instance.selectRangeFromStartToItem(event, itemId);
+ } else {
+ instance.focusItem(event, getFirstNavigableItem(instance));
}
event.preventDefault();
@@ -253,12 +239,12 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
// Focuses the last item in the tree
case key === 'End': {
- instance.focusItem(event, getLastNavigableItem(instance));
-
// Multi select behavior when pressing Ctrl + Shirt + End
// Selects the focused item and all the items down to the last item.
if (canToggleItemSelection(itemId) && params.multiSelect && ctrlPressed && event.shiftKey) {
- instance.rangeSelectToLast(event, itemId);
+ instance.selectRangeFromItemToEnd(event, itemId);
+ } else {
+ instance.focusItem(event, getLastNavigableItem(instance));
}
event.preventDefault();
@@ -275,10 +261,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin<
// Multi select behavior when pressing Ctrl + a
// Selects all the items
case key === 'a' && ctrlPressed && params.multiSelect && !params.disableSelection: {
- instance.selectRange(event, {
- start: getFirstNavigableItem(instance),
- end: getLastNavigableItem(instance),
- });
+ instance.selectAllNavigableItems(event);
event.preventDefault();
break;
}
diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts
index bda0ed53bc1ad..02fa885cb08a9 100644
--- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts
+++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts
@@ -1,11 +1,15 @@
import * as React from 'react';
-import { TreeViewPlugin, TreeViewItemRange } from '../../models';
+import { TreeViewPlugin } from '../../models';
+import { TreeViewItemId } from '../../../models';
import {
+ findOrderInTremauxTree,
+ getAllNavigableItems,
getFirstNavigableItem,
getLastNavigableItem,
- getNavigableItemsInRange,
+ getNonDisabledItemsInRange,
} from '../../utils/tree';
import { UseTreeViewSelectionSignature } from './useTreeViewSelection.types';
+import { convertSelectedItemsToArray, getLookupFromArray } from './useTreeViewSelection.utils';
export const useTreeViewSelection: TreeViewPlugin = ({
instance,
@@ -13,8 +17,20 @@ export const useTreeViewSelection: TreeViewPlugin
models,
}) => {
const lastSelectedItem = React.useRef(null);
- const lastSelectionWasRange = React.useRef(false);
- const currentRangeSelection = React.useRef([]);
+ const lastSelectedRange = React.useRef<{ [itemId: string]: boolean }>({});
+
+ const selectedItemsMap = React.useMemo(() => {
+ const temp = new Map();
+ if (Array.isArray(models.selectedItems.value)) {
+ models.selectedItems.value.forEach((id) => {
+ temp.set(id, true);
+ });
+ } else if (models.selectedItems.value != null) {
+ temp.set(models.selectedItems.value, true);
+ }
+
+ return temp;
+ }, [models.selectedItems.value]);
const setSelectedItems = (
event: React.SyntheticEvent,
@@ -53,122 +69,108 @@ export const useTreeViewSelection: TreeViewPlugin
models.selectedItems.setControlledValue(newSelectedItems);
};
- const isItemSelected = (itemId: string) =>
- Array.isArray(models.selectedItems.value)
- ? models.selectedItems.value.indexOf(itemId) !== -1
- : models.selectedItems.value === itemId;
+ const isItemSelected = (itemId: string) => selectedItemsMap.has(itemId);
const selectItem = (event: React.SyntheticEvent, itemId: string, multiple = false) => {
if (params.disableSelection) {
return;
}
+ let newSelected: typeof models.selectedItems.value;
if (multiple) {
- if (Array.isArray(models.selectedItems.value)) {
- let newSelected: string[];
- if (models.selectedItems.value.indexOf(itemId) !== -1) {
- newSelected = models.selectedItems.value.filter((id) => id !== itemId);
- } else {
- newSelected = [itemId].concat(models.selectedItems.value);
- }
-
- setSelectedItems(event, newSelected);
+ const cleanSelectedItems = convertSelectedItemsToArray(models.selectedItems.value);
+ if (instance.isItemSelected(itemId)) {
+ newSelected = cleanSelectedItems.filter((id) => id !== itemId);
+ } else {
+ newSelected = [itemId].concat(cleanSelectedItems);
}
} else {
- const newSelected = params.multiSelect ? [itemId] : itemId;
- setSelectedItems(event, newSelected);
+ newSelected = params.multiSelect ? [itemId] : itemId;
}
+
+ setSelectedItems(event, newSelected);
lastSelectedItem.current = itemId;
- lastSelectionWasRange.current = false;
- currentRangeSelection.current = [];
+ lastSelectedRange.current = {};
};
- const handleRangeArrowSelect = (event: React.SyntheticEvent, items: TreeViewItemRange) => {
- let base = (models.selectedItems.value as string[]).slice();
- const { start, next, current } = items;
-
- if (!next || !current) {
+ const selectRange = (event: React.SyntheticEvent, [start, end]: [string, string]) => {
+ if (params.disableSelection || !params.multiSelect) {
return;
}
- if (currentRangeSelection.current.indexOf(current) === -1) {
- currentRangeSelection.current = [];
- }
+ let newSelectedItems = convertSelectedItemsToArray(models.selectedItems.value).slice();
- if (lastSelectionWasRange.current) {
- if (currentRangeSelection.current.indexOf(next) !== -1) {
- base = base.filter((id) => id === start || id !== current);
- currentRangeSelection.current = currentRangeSelection.current.filter(
- (id) => id === start || id !== current,
- );
- } else {
- base.push(next);
- currentRangeSelection.current.push(next);
- }
- } else {
- base.push(next);
- currentRangeSelection.current.push(current, next);
+ // If the last selection was a range selection,
+ // remove the items that were part of the last range from the model
+ if (Object.keys(lastSelectedRange.current).length > 0) {
+ newSelectedItems = newSelectedItems.filter((id) => !lastSelectedRange.current[id]);
}
- setSelectedItems(event, base);
+
+ // Add to the model the items that are part of the new range and not already part of the model.
+ const selectedItemsLookup = getLookupFromArray(newSelectedItems);
+ const range = getNonDisabledItemsInRange(instance, start, end);
+ const itemsToAddToModel = range.filter((id) => !selectedItemsLookup[id]);
+ newSelectedItems = newSelectedItems.concat(itemsToAddToModel);
+
+ setSelectedItems(event, newSelectedItems);
+ lastSelectedRange.current = getLookupFromArray(range);
};
- const handleRangeSelect = (
- event: React.SyntheticEvent,
- items: { start: string; end: string },
- ) => {
- let base = (models.selectedItems.value as string[]).slice();
- const { start, end } = items;
- // If last selection was a range selection ignore items that were selected.
- if (lastSelectionWasRange.current) {
- base = base.filter((id) => currentRangeSelection.current.indexOf(id) === -1);
+ const expandSelectionRange = (event: React.SyntheticEvent, itemId: string) => {
+ if (lastSelectedItem.current != null) {
+ const [start, end] = findOrderInTremauxTree(instance, itemId, lastSelectedItem.current);
+ selectRange(event, [start, end]);
}
+ };
- let range = getNavigableItemsInRange(instance, start, end);
- range = range.filter((item) => !instance.isItemDisabled(item));
- currentRangeSelection.current = range;
- let newSelected = base.concat(range);
- newSelected = newSelected.filter((id, i) => newSelected.indexOf(id) === i);
- setSelectedItems(event, newSelected);
+ const selectRangeFromStartToItem = (event: React.SyntheticEvent, itemId: string) => {
+ selectRange(event, [getFirstNavigableItem(instance), itemId]);
};
- const selectRange = (event: React.SyntheticEvent, items: TreeViewItemRange, stacked = false) => {
- if (params.disableSelection) {
+ const selectRangeFromItemToEnd = (event: React.SyntheticEvent, itemId: string) => {
+ selectRange(event, [itemId, getLastNavigableItem(instance)]);
+ };
+
+ const selectAllNavigableItems = (event: React.SyntheticEvent) => {
+ if (params.disableSelection || !params.multiSelect) {
return;
}
- const { start = lastSelectedItem.current, end, current } = items;
- if (stacked) {
- handleRangeArrowSelect(event, { start, next: end, current });
- } else if (start != null && end != null) {
- handleRangeSelect(event, { start, end });
- }
- lastSelectionWasRange.current = true;
+ const navigableItems = getAllNavigableItems(instance);
+ setSelectedItems(event, navigableItems);
+
+ lastSelectedRange.current = getLookupFromArray(navigableItems);
};
- const rangeSelectToFirst = (event: React.KeyboardEvent, itemId: string) => {
- if (!lastSelectedItem.current) {
- lastSelectedItem.current = itemId;
+ const selectItemFromArrowNavigation = (
+ event: React.SyntheticEvent,
+ currentItem: string,
+ nextItem: string,
+ ) => {
+ if (params.disableSelection || !params.multiSelect) {
+ return;
}
- const start = lastSelectionWasRange.current ? lastSelectedItem.current : itemId;
+ let newSelectedItems = convertSelectedItemsToArray(models.selectedItems.value).slice();
- instance.selectRange(event, {
- start,
- end: getFirstNavigableItem(instance),
- });
- };
+ if (Object.keys(lastSelectedRange.current).length === 0) {
+ newSelectedItems.push(nextItem);
+ lastSelectedRange.current = { [currentItem]: true, [nextItem]: true };
+ } else {
+ if (!lastSelectedRange.current[currentItem]) {
+ lastSelectedRange.current = {};
+ }
- const rangeSelectToLast = (event: React.KeyboardEvent, itemId: string) => {
- if (!lastSelectedItem.current) {
- lastSelectedItem.current = itemId;
+ if (lastSelectedRange.current[nextItem]) {
+ newSelectedItems = newSelectedItems.filter((id) => id !== currentItem);
+ delete lastSelectedRange.current[currentItem];
+ } else {
+ newSelectedItems.push(nextItem);
+ lastSelectedRange.current[nextItem] = true;
+ }
}
- const start = lastSelectionWasRange.current ? lastSelectedItem.current : itemId;
-
- instance.selectRange(event, {
- start,
- end: getLastNavigableItem(instance),
- });
+ setSelectedItems(event, newSelectedItems);
};
return {
@@ -178,9 +180,11 @@ export const useTreeViewSelection: TreeViewPlugin
instance: {
isItemSelected,
selectItem,
- selectRange,
- rangeSelectToLast,
- rangeSelectToFirst,
+ selectAllNavigableItems,
+ expandSelectionRange,
+ selectRangeFromStartToItem,
+ selectRangeFromItemToEnd,
+ selectItemFromArrowNavigation,
},
contextValue: {
selection: {
diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts
index 474a7b8f82e44..99c8e8d12d99c 100644
--- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts
+++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts
@@ -1,14 +1,49 @@
import * as React from 'react';
-import type { DefaultizedProps, TreeViewItemRange, TreeViewPluginSignature } from '../../models';
+import type { DefaultizedProps, TreeViewPluginSignature } from '../../models';
import { UseTreeViewItemsSignature } from '../useTreeViewItems';
import { UseTreeViewExpansionSignature } from '../useTreeViewExpansion';
export interface UseTreeViewSelectionInstance {
isItemSelected: (itemId: string) => boolean;
- selectItem: (event: React.SyntheticEvent, itemId: string, multiple?: boolean) => void;
- selectRange: (event: React.SyntheticEvent, items: TreeViewItemRange, stacked?: boolean) => void;
- rangeSelectToFirst: (event: React.KeyboardEvent, itemId: string) => void;
- rangeSelectToLast: (event: React.KeyboardEvent, itemId: string) => void;
+ selectItem: (
+ event: React.SyntheticEvent,
+ itemId: string,
+ keepExistingSelection?: boolean,
+ ) => void;
+ /**
+ * Select all the navigable items in the tree.
+ * @param {React.SyntheticEvent} event The event source of the callback.
+ */
+ selectAllNavigableItems: (event: React.SyntheticEvent) => void;
+ /**
+ * Expand the current selection range up to the given item.
+ * @param {React.SyntheticEvent} event The event source of the callback.
+ * @param {string} itemId The id of the item to expand the selection to.
+ */
+ expandSelectionRange: (event: React.SyntheticEvent, itemId: string) => void;
+ /**
+ * Expand the current selection range from the first navigable item to the given item.
+ * @param {React.SyntheticEvent} event The event source of the callback.
+ * @param {string} itemId The id of the item up to which the selection range should be expanded.
+ */
+ selectRangeFromStartToItem: (event: React.SyntheticEvent, itemId: string) => void;
+ /**
+ * Expand the current selection range from the given item to the last navigable item.
+ * @param {React.SyntheticEvent} event The event source of the callback.
+ * @param {string} itemId The id of the item from which the selection range should be expanded.
+ */
+ selectRangeFromItemToEnd: (event: React.SyntheticEvent, itemId: string) => void;
+ /**
+ * Update the selection when navigating with ArrowUp / ArrowDown keys.
+ * @param {React.SyntheticEvent} event The event source of the callback.
+ * @param {string} currentItemId The id of the active item before the keyboard navigation.
+ * @param {string} nextItemId The id of the active item after the keyboard navigation.
+ */
+ selectItemFromArrowNavigation: (
+ event: React.SyntheticEvent,
+ currentItemId: string,
+ nextItemId: string,
+ ) => void;
}
type TreeViewSelectionValue = Multiple extends true
diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts
new file mode 100644
index 0000000000000..bb022e13338c6
--- /dev/null
+++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts
@@ -0,0 +1,24 @@
+/**
+ * Transform the `selectedItems` model to be an array if it was a string or null.
+ * @param {string[] | string | null} model The raw model.
+ * @returns {string[]} The converted model.
+ */
+export const convertSelectedItemsToArray = (model: string[] | string | null): string[] => {
+ if (Array.isArray(model)) {
+ return model;
+ }
+
+ if (model != null) {
+ return [model];
+ }
+
+ return [];
+};
+
+export const getLookupFromArray = (array: string[]) => {
+ const lookup: { [itemId: string]: boolean } = {};
+ array.forEach((itemId) => {
+ lookup[itemId] = true;
+ });
+ return lookup;
+};
diff --git a/packages/x-tree-view/src/internals/utils/tree.ts b/packages/x-tree-view/src/internals/utils/tree.ts
index d3938323ca60e..1293374356c05 100644
--- a/packages/x-tree-view/src/internals/utils/tree.ts
+++ b/packages/x-tree-view/src/internals/utils/tree.ts
@@ -124,7 +124,7 @@ export const getFirstNavigableItem = (instance: TreeViewInstance<[UseTreeViewIte
* Another way to put it is which item is shallower in a trémaux tree
* https://en.wikipedia.org/wiki/Tr%C3%A9maux_tree
*/
-const findOrderInTremauxTree = (
+export const findOrderInTremauxTree = (
instance: TreeViewInstance<[UseTreeViewItemsSignature]>,
itemAId: string,
itemBId: string,
@@ -185,20 +185,57 @@ const findOrderInTremauxTree = (
: [itemBId, itemAId];
};
-export const getNavigableItemsInRange = (
+export const getNonDisabledItemsInRange = (
instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>,
itemAId: string,
itemBId: string,
) => {
+ const getNextItem = (itemId: string) => {
+ // If the item is expanded and has some children, return the first of them.
+ if (instance.isItemExpandable(itemId) && instance.isItemExpanded(itemId)) {
+ return instance.getItemOrderedChildrenIds(itemId)[0];
+ }
+
+ let itemMeta = instance.getItemMeta(itemId);
+ while (itemMeta != null) {
+ // Try to find the first navigable sibling after the current item.
+ const siblings = instance.getItemOrderedChildrenIds(itemMeta.parentId);
+ const currentItemIndex = instance.getItemIndex(itemMeta.id);
+
+ if (currentItemIndex < siblings.length - 1) {
+ return siblings[currentItemIndex + 1];
+ }
+
+ // If the item is the last of its siblings, go up a level to the parent and try again.
+ itemMeta = instance.getItemMeta(itemMeta.parentId!);
+ }
+
+ throw new Error('Invalid range');
+ };
+
const [first, last] = findOrderInTremauxTree(instance, itemAId, itemBId);
const items = [first];
-
let current = first;
while (current !== last) {
- current = getNextNavigableItem(instance, current)!;
- items.push(current);
+ current = getNextItem(current);
+ if (!instance.isItemDisabled(current)) {
+ items.push(current);
+ }
}
return items;
};
+
+export const getAllNavigableItems = (
+ instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>,
+) => {
+ let item: string | null = getFirstNavigableItem(instance);
+ const navigableItems: string[] = [];
+ while (item != null) {
+ navigableItems.push(item);
+ item = getNextNavigableItem(instance, item);
+ }
+
+ return navigableItems;
+};