diff --git a/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx b/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx index cdf6e98ec0c64..6969b9da4cf94 100644 --- a/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx +++ b/packages/x-tree-view/src/TreeItem/TreeItem.test.tsx @@ -1007,7 +1007,7 @@ describe('', () => { describe('range selection', () => { it('keyboard arrow', () => { const { getByTestId, queryAllByRole, getByText } = render( - + @@ -1085,7 +1085,7 @@ describe('', () => { it('keyboard arrow merge', () => { const { getByTestId, getByText, queryAllByRole } = render( - + @@ -1207,7 +1207,7 @@ describe('', () => { expect(getByTestId('eight')).to.have.attribute('aria-selected', 'true'); expect(getByTestId('nine')).to.have.attribute('aria-selected', 'true'); - fireEvent.keyDown(getByTestId('nine'), { + fireEvent.keyDown(getByTestId('five'), { key: 'Home', shiftKey: true, ctrlKey: true, diff --git a/packages/x-tree-view/src/TreeItem/useTreeItemState.ts b/packages/x-tree-view/src/TreeItem/useTreeItemState.ts index 0c2442d3ec15c..d43fcd4f2c4aa 100644 --- a/packages/x-tree-view/src/TreeItem/useTreeItemState.ts +++ b/packages/x-tree-view/src/TreeItem/useTreeItemState.ts @@ -39,7 +39,7 @@ export function useTreeItemState(itemId: string) { if (multiple) { if (event.shiftKey) { - instance.selectRange(event, { end: itemId }); + instance.expandSelectionRange(event, itemId); } else { instance.selectItem(event, itemId, true); } diff --git a/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx b/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx index 87ed49fff4dc0..c8c67baf0f9e1 100644 --- a/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx +++ b/packages/x-tree-view/src/hooks/useTreeItem2Utils/useTreeItem2Utils.tsx @@ -63,7 +63,7 @@ export const useTreeItem2Utils = ({ if (multiple) { if (event.shiftKey) { - instance.selectRange(event, { end: itemId }); + instance.expandSelectionRange(event, itemId); } else { instance.selectItem(event, itemId, true); } diff --git a/packages/x-tree-view/src/internals/models/treeView.ts b/packages/x-tree-view/src/internals/models/treeView.ts index fa6cca1edb602..552dd29bf7e2c 100644 --- a/packages/x-tree-view/src/internals/models/treeView.ts +++ b/packages/x-tree-view/src/internals/models/treeView.ts @@ -13,13 +13,6 @@ export interface TreeViewItemMeta { label?: string; } -export interface TreeViewItemRange { - start?: string | null; - end?: string | null; - next?: string | null; - current?: string; -} - export interface TreeViewModel { name: string; value: TValue; diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts index bcf20f45487f5..8ed5436915f4f 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewKeyboardNavigation/useTreeViewKeyboardNavigation.ts @@ -127,7 +127,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< case key === ' ' && canToggleItemSelection(itemId): { event.preventDefault(); if (params.multiSelect && event.shiftKey) { - instance.selectRange(event, { end: itemId }); + instance.expandSelectionRange(event, itemId); } else if (params.multiSelect) { instance.selectItem(event, itemId, true); } else { @@ -165,14 +165,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Multi select behavior when pressing Shift + ArrowDown // Toggles the selection state of the next item if (params.multiSelect && event.shiftKey && canToggleItemSelection(nextItem)) { - instance.selectRange( - event, - { - end: nextItem, - current: itemId, - }, - true, - ); + instance.selectItemFromArrowNavigation(event, itemId, nextItem); } } @@ -189,14 +182,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Multi select behavior when pressing Shift + ArrowUp // Toggles the selection state of the previous item if (params.multiSelect && event.shiftKey && canToggleItemSelection(previousItem)) { - instance.selectRange( - event, - { - end: previousItem, - current: itemId, - }, - true, - ); + instance.selectItemFromArrowNavigation(event, itemId, previousItem); } } @@ -239,12 +225,12 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focuses the first item in the tree case key === 'Home': { - instance.focusItem(event, getFirstNavigableItem(instance)); - // Multi select behavior when pressing Ctrl + Shift + Home // Selects the focused item and all items up to the first item. if (canToggleItemSelection(itemId) && params.multiSelect && ctrlPressed && event.shiftKey) { - instance.rangeSelectToFirst(event, itemId); + instance.selectRangeFromStartToItem(event, itemId); + } else { + instance.focusItem(event, getFirstNavigableItem(instance)); } event.preventDefault(); @@ -253,12 +239,12 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Focuses the last item in the tree case key === 'End': { - instance.focusItem(event, getLastNavigableItem(instance)); - // Multi select behavior when pressing Ctrl + Shirt + End // Selects the focused item and all the items down to the last item. if (canToggleItemSelection(itemId) && params.multiSelect && ctrlPressed && event.shiftKey) { - instance.rangeSelectToLast(event, itemId); + instance.selectRangeFromItemToEnd(event, itemId); + } else { + instance.focusItem(event, getLastNavigableItem(instance)); } event.preventDefault(); @@ -275,10 +261,7 @@ export const useTreeViewKeyboardNavigation: TreeViewPlugin< // Multi select behavior when pressing Ctrl + a // Selects all the items case key === 'a' && ctrlPressed && params.multiSelect && !params.disableSelection: { - instance.selectRange(event, { - start: getFirstNavigableItem(instance), - end: getLastNavigableItem(instance), - }); + instance.selectAllNavigableItems(event); event.preventDefault(); break; } diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts index bda0ed53bc1ad..02fa885cb08a9 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.ts @@ -1,11 +1,15 @@ import * as React from 'react'; -import { TreeViewPlugin, TreeViewItemRange } from '../../models'; +import { TreeViewPlugin } from '../../models'; +import { TreeViewItemId } from '../../../models'; import { + findOrderInTremauxTree, + getAllNavigableItems, getFirstNavigableItem, getLastNavigableItem, - getNavigableItemsInRange, + getNonDisabledItemsInRange, } from '../../utils/tree'; import { UseTreeViewSelectionSignature } from './useTreeViewSelection.types'; +import { convertSelectedItemsToArray, getLookupFromArray } from './useTreeViewSelection.utils'; export const useTreeViewSelection: TreeViewPlugin = ({ instance, @@ -13,8 +17,20 @@ export const useTreeViewSelection: TreeViewPlugin models, }) => { const lastSelectedItem = React.useRef(null); - const lastSelectionWasRange = React.useRef(false); - const currentRangeSelection = React.useRef([]); + const lastSelectedRange = React.useRef<{ [itemId: string]: boolean }>({}); + + const selectedItemsMap = React.useMemo(() => { + const temp = new Map(); + if (Array.isArray(models.selectedItems.value)) { + models.selectedItems.value.forEach((id) => { + temp.set(id, true); + }); + } else if (models.selectedItems.value != null) { + temp.set(models.selectedItems.value, true); + } + + return temp; + }, [models.selectedItems.value]); const setSelectedItems = ( event: React.SyntheticEvent, @@ -53,122 +69,108 @@ export const useTreeViewSelection: TreeViewPlugin models.selectedItems.setControlledValue(newSelectedItems); }; - const isItemSelected = (itemId: string) => - Array.isArray(models.selectedItems.value) - ? models.selectedItems.value.indexOf(itemId) !== -1 - : models.selectedItems.value === itemId; + const isItemSelected = (itemId: string) => selectedItemsMap.has(itemId); const selectItem = (event: React.SyntheticEvent, itemId: string, multiple = false) => { if (params.disableSelection) { return; } + let newSelected: typeof models.selectedItems.value; if (multiple) { - if (Array.isArray(models.selectedItems.value)) { - let newSelected: string[]; - if (models.selectedItems.value.indexOf(itemId) !== -1) { - newSelected = models.selectedItems.value.filter((id) => id !== itemId); - } else { - newSelected = [itemId].concat(models.selectedItems.value); - } - - setSelectedItems(event, newSelected); + const cleanSelectedItems = convertSelectedItemsToArray(models.selectedItems.value); + if (instance.isItemSelected(itemId)) { + newSelected = cleanSelectedItems.filter((id) => id !== itemId); + } else { + newSelected = [itemId].concat(cleanSelectedItems); } } else { - const newSelected = params.multiSelect ? [itemId] : itemId; - setSelectedItems(event, newSelected); + newSelected = params.multiSelect ? [itemId] : itemId; } + + setSelectedItems(event, newSelected); lastSelectedItem.current = itemId; - lastSelectionWasRange.current = false; - currentRangeSelection.current = []; + lastSelectedRange.current = {}; }; - const handleRangeArrowSelect = (event: React.SyntheticEvent, items: TreeViewItemRange) => { - let base = (models.selectedItems.value as string[]).slice(); - const { start, next, current } = items; - - if (!next || !current) { + const selectRange = (event: React.SyntheticEvent, [start, end]: [string, string]) => { + if (params.disableSelection || !params.multiSelect) { return; } - if (currentRangeSelection.current.indexOf(current) === -1) { - currentRangeSelection.current = []; - } + let newSelectedItems = convertSelectedItemsToArray(models.selectedItems.value).slice(); - if (lastSelectionWasRange.current) { - if (currentRangeSelection.current.indexOf(next) !== -1) { - base = base.filter((id) => id === start || id !== current); - currentRangeSelection.current = currentRangeSelection.current.filter( - (id) => id === start || id !== current, - ); - } else { - base.push(next); - currentRangeSelection.current.push(next); - } - } else { - base.push(next); - currentRangeSelection.current.push(current, next); + // If the last selection was a range selection, + // remove the items that were part of the last range from the model + if (Object.keys(lastSelectedRange.current).length > 0) { + newSelectedItems = newSelectedItems.filter((id) => !lastSelectedRange.current[id]); } - setSelectedItems(event, base); + + // Add to the model the items that are part of the new range and not already part of the model. + const selectedItemsLookup = getLookupFromArray(newSelectedItems); + const range = getNonDisabledItemsInRange(instance, start, end); + const itemsToAddToModel = range.filter((id) => !selectedItemsLookup[id]); + newSelectedItems = newSelectedItems.concat(itemsToAddToModel); + + setSelectedItems(event, newSelectedItems); + lastSelectedRange.current = getLookupFromArray(range); }; - const handleRangeSelect = ( - event: React.SyntheticEvent, - items: { start: string; end: string }, - ) => { - let base = (models.selectedItems.value as string[]).slice(); - const { start, end } = items; - // If last selection was a range selection ignore items that were selected. - if (lastSelectionWasRange.current) { - base = base.filter((id) => currentRangeSelection.current.indexOf(id) === -1); + const expandSelectionRange = (event: React.SyntheticEvent, itemId: string) => { + if (lastSelectedItem.current != null) { + const [start, end] = findOrderInTremauxTree(instance, itemId, lastSelectedItem.current); + selectRange(event, [start, end]); } + }; - let range = getNavigableItemsInRange(instance, start, end); - range = range.filter((item) => !instance.isItemDisabled(item)); - currentRangeSelection.current = range; - let newSelected = base.concat(range); - newSelected = newSelected.filter((id, i) => newSelected.indexOf(id) === i); - setSelectedItems(event, newSelected); + const selectRangeFromStartToItem = (event: React.SyntheticEvent, itemId: string) => { + selectRange(event, [getFirstNavigableItem(instance), itemId]); }; - const selectRange = (event: React.SyntheticEvent, items: TreeViewItemRange, stacked = false) => { - if (params.disableSelection) { + const selectRangeFromItemToEnd = (event: React.SyntheticEvent, itemId: string) => { + selectRange(event, [itemId, getLastNavigableItem(instance)]); + }; + + const selectAllNavigableItems = (event: React.SyntheticEvent) => { + if (params.disableSelection || !params.multiSelect) { return; } - const { start = lastSelectedItem.current, end, current } = items; - if (stacked) { - handleRangeArrowSelect(event, { start, next: end, current }); - } else if (start != null && end != null) { - handleRangeSelect(event, { start, end }); - } - lastSelectionWasRange.current = true; + const navigableItems = getAllNavigableItems(instance); + setSelectedItems(event, navigableItems); + + lastSelectedRange.current = getLookupFromArray(navigableItems); }; - const rangeSelectToFirst = (event: React.KeyboardEvent, itemId: string) => { - if (!lastSelectedItem.current) { - lastSelectedItem.current = itemId; + const selectItemFromArrowNavigation = ( + event: React.SyntheticEvent, + currentItem: string, + nextItem: string, + ) => { + if (params.disableSelection || !params.multiSelect) { + return; } - const start = lastSelectionWasRange.current ? lastSelectedItem.current : itemId; + let newSelectedItems = convertSelectedItemsToArray(models.selectedItems.value).slice(); - instance.selectRange(event, { - start, - end: getFirstNavigableItem(instance), - }); - }; + if (Object.keys(lastSelectedRange.current).length === 0) { + newSelectedItems.push(nextItem); + lastSelectedRange.current = { [currentItem]: true, [nextItem]: true }; + } else { + if (!lastSelectedRange.current[currentItem]) { + lastSelectedRange.current = {}; + } - const rangeSelectToLast = (event: React.KeyboardEvent, itemId: string) => { - if (!lastSelectedItem.current) { - lastSelectedItem.current = itemId; + if (lastSelectedRange.current[nextItem]) { + newSelectedItems = newSelectedItems.filter((id) => id !== currentItem); + delete lastSelectedRange.current[currentItem]; + } else { + newSelectedItems.push(nextItem); + lastSelectedRange.current[nextItem] = true; + } } - const start = lastSelectionWasRange.current ? lastSelectedItem.current : itemId; - - instance.selectRange(event, { - start, - end: getLastNavigableItem(instance), - }); + setSelectedItems(event, newSelectedItems); }; return { @@ -178,9 +180,11 @@ export const useTreeViewSelection: TreeViewPlugin instance: { isItemSelected, selectItem, - selectRange, - rangeSelectToLast, - rangeSelectToFirst, + selectAllNavigableItems, + expandSelectionRange, + selectRangeFromStartToItem, + selectRangeFromItemToEnd, + selectItemFromArrowNavigation, }, contextValue: { selection: { diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts index 474a7b8f82e44..99c8e8d12d99c 100644 --- a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.types.ts @@ -1,14 +1,49 @@ import * as React from 'react'; -import type { DefaultizedProps, TreeViewItemRange, TreeViewPluginSignature } from '../../models'; +import type { DefaultizedProps, TreeViewPluginSignature } from '../../models'; import { UseTreeViewItemsSignature } from '../useTreeViewItems'; import { UseTreeViewExpansionSignature } from '../useTreeViewExpansion'; export interface UseTreeViewSelectionInstance { isItemSelected: (itemId: string) => boolean; - selectItem: (event: React.SyntheticEvent, itemId: string, multiple?: boolean) => void; - selectRange: (event: React.SyntheticEvent, items: TreeViewItemRange, stacked?: boolean) => void; - rangeSelectToFirst: (event: React.KeyboardEvent, itemId: string) => void; - rangeSelectToLast: (event: React.KeyboardEvent, itemId: string) => void; + selectItem: ( + event: React.SyntheticEvent, + itemId: string, + keepExistingSelection?: boolean, + ) => void; + /** + * Select all the navigable items in the tree. + * @param {React.SyntheticEvent} event The event source of the callback. + */ + selectAllNavigableItems: (event: React.SyntheticEvent) => void; + /** + * Expand the current selection range up to the given item. + * @param {React.SyntheticEvent} event The event source of the callback. + * @param {string} itemId The id of the item to expand the selection to. + */ + expandSelectionRange: (event: React.SyntheticEvent, itemId: string) => void; + /** + * Expand the current selection range from the first navigable item to the given item. + * @param {React.SyntheticEvent} event The event source of the callback. + * @param {string} itemId The id of the item up to which the selection range should be expanded. + */ + selectRangeFromStartToItem: (event: React.SyntheticEvent, itemId: string) => void; + /** + * Expand the current selection range from the given item to the last navigable item. + * @param {React.SyntheticEvent} event The event source of the callback. + * @param {string} itemId The id of the item from which the selection range should be expanded. + */ + selectRangeFromItemToEnd: (event: React.SyntheticEvent, itemId: string) => void; + /** + * Update the selection when navigating with ArrowUp / ArrowDown keys. + * @param {React.SyntheticEvent} event The event source of the callback. + * @param {string} currentItemId The id of the active item before the keyboard navigation. + * @param {string} nextItemId The id of the active item after the keyboard navigation. + */ + selectItemFromArrowNavigation: ( + event: React.SyntheticEvent, + currentItemId: string, + nextItemId: string, + ) => void; } type TreeViewSelectionValue = Multiple extends true diff --git a/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts new file mode 100644 index 0000000000000..bb022e13338c6 --- /dev/null +++ b/packages/x-tree-view/src/internals/plugins/useTreeViewSelection/useTreeViewSelection.utils.ts @@ -0,0 +1,24 @@ +/** + * Transform the `selectedItems` model to be an array if it was a string or null. + * @param {string[] | string | null} model The raw model. + * @returns {string[]} The converted model. + */ +export const convertSelectedItemsToArray = (model: string[] | string | null): string[] => { + if (Array.isArray(model)) { + return model; + } + + if (model != null) { + return [model]; + } + + return []; +}; + +export const getLookupFromArray = (array: string[]) => { + const lookup: { [itemId: string]: boolean } = {}; + array.forEach((itemId) => { + lookup[itemId] = true; + }); + return lookup; +}; diff --git a/packages/x-tree-view/src/internals/utils/tree.ts b/packages/x-tree-view/src/internals/utils/tree.ts index d3938323ca60e..1293374356c05 100644 --- a/packages/x-tree-view/src/internals/utils/tree.ts +++ b/packages/x-tree-view/src/internals/utils/tree.ts @@ -124,7 +124,7 @@ export const getFirstNavigableItem = (instance: TreeViewInstance<[UseTreeViewIte * Another way to put it is which item is shallower in a trémaux tree * https://en.wikipedia.org/wiki/Tr%C3%A9maux_tree */ -const findOrderInTremauxTree = ( +export const findOrderInTremauxTree = ( instance: TreeViewInstance<[UseTreeViewItemsSignature]>, itemAId: string, itemBId: string, @@ -185,20 +185,57 @@ const findOrderInTremauxTree = ( : [itemBId, itemAId]; }; -export const getNavigableItemsInRange = ( +export const getNonDisabledItemsInRange = ( instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>, itemAId: string, itemBId: string, ) => { + const getNextItem = (itemId: string) => { + // If the item is expanded and has some children, return the first of them. + if (instance.isItemExpandable(itemId) && instance.isItemExpanded(itemId)) { + return instance.getItemOrderedChildrenIds(itemId)[0]; + } + + let itemMeta = instance.getItemMeta(itemId); + while (itemMeta != null) { + // Try to find the first navigable sibling after the current item. + const siblings = instance.getItemOrderedChildrenIds(itemMeta.parentId); + const currentItemIndex = instance.getItemIndex(itemMeta.id); + + if (currentItemIndex < siblings.length - 1) { + return siblings[currentItemIndex + 1]; + } + + // If the item is the last of its siblings, go up a level to the parent and try again. + itemMeta = instance.getItemMeta(itemMeta.parentId!); + } + + throw new Error('Invalid range'); + }; + const [first, last] = findOrderInTremauxTree(instance, itemAId, itemBId); const items = [first]; - let current = first; while (current !== last) { - current = getNextNavigableItem(instance, current)!; - items.push(current); + current = getNextItem(current); + if (!instance.isItemDisabled(current)) { + items.push(current); + } } return items; }; + +export const getAllNavigableItems = ( + instance: TreeViewInstance<[UseTreeViewItemsSignature, UseTreeViewExpansionSignature]>, +) => { + let item: string | null = getFirstNavigableItem(instance); + const navigableItems: string[] = []; + while (item != null) { + navigableItems.push(item); + item = getNextNavigableItem(instance, item); + } + + return navigableItems; +};