Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Collapsing nodes #8686

Merged
merged 17 commits into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import {
averagePositionPlacement,
mouseDictatedPlacement,
nonDictatedPlacement,
previousNodeDictatedPlacement,
Expand Down Expand Up @@ -447,6 +448,54 @@ describe('Mouse dictated placement', () => {
})
})

describe('Average position placement', () => {
function environment(selectedNodeRects: Rect[], nonSelectedNodeRects: Rect[]): Environment {
return {
screenBounds,
nodeRects: [...selectedNodeRects, ...nonSelectedNodeRects],
selectedNodeRects,
get mousePosition() {
return getMousePosition()
},
}
}

function options(): { horizontalGap: number; verticalGap: number } {
return {
get horizontalGap() {
return getHorizontalGap()
},
get verticalGap() {
return getVerticalGap()
},
}
}

test('One selected, no other nodes', () => {
const X = 1100
const Y = 700
const selectedNodeRects = [rectAt(X, Y)]
const result = averagePositionPlacement(nodeSize, environment(selectedNodeRects, []), options())
expect(result).toEqual({ position: new Vec2(X, Y), pan: undefined })
})

test('Multiple selected, no other nodes', () => {
const selectedNodeRects = [rectAt(1000, 600), rectAt(1300, 800)]
const result = averagePositionPlacement(nodeSize, environment(selectedNodeRects, []), options())
expect(result).toEqual({ position: new Vec2(1150, 700), pan: undefined })
})

test('Average position occupied', () => {
const selectedNodeRects = [rectAt(1000, 600), rectAt(1300, 800)]
const result = averagePositionPlacement(
nodeSize,
environment(selectedNodeRects, [rectAt(1150, 700)]),
options(),
)
expect(result).toEqual({ position: new Vec2(1150, 744), pan: undefined })
})
})

// === Helpers for debugging ===

function generateVueCodeForNonDictatedPlacement(newNode: Rect, rects: Rect[]) {
Expand Down
51 changes: 51 additions & 0 deletions app/gui2/src/components/ComponentBrowser/placement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,54 @@ export function mouseDictatedPlacement(
const nodeRadius = nodeSize.y / 2
return { position: mousePosition.add(new Vec2(nodeRadius, nodeRadius)) }
}

/** The new node should appear at the average position of selected nodes.
*
* If the desired place is already occupied by non-selected node, it should be moved down to the closest free space.
*
* Specifically, this code, in order:
* - calculates the average position of selected nodes
* - searches for all vertical spans below the initial position,
* that horizontally intersect the initial position (no horizontal gap is required between
* the new node and old nodes)
* - shifts the node down (if required) until there is sufficient vertical space -
* the height of the node, in addition to the specified gap both above and below the node.
*/
export function averagePositionPlacement(
nodeSize: Vec2,
{ screenBounds, selectedNodeRects, nodeRects }: Environment,
{ verticalGap = theme.node.vertical_gap }: PlacementOptions = {},
): Placement {
let totalPosition = new Vec2(0, 0)
let selectedNodeRectsCount = 0
for (const rect of selectedNodeRects) {
totalPosition = totalPosition.add(rect.pos)
selectedNodeRectsCount++
}
const initialPosition = totalPosition.scale(1.0 / selectedNodeRectsCount)
const nonSelectedNodeRects = []
outer: for (const rect of nodeRects) {
for (const sel of selectedNodeRects) {
if (sel.equals(rect)) {
continue outer
}
}
nonSelectedNodeRects.push(rect)
}
let top = initialPosition.y
const initialRect = new Rect(initialPosition, nodeSize)
const nodeRectsSorted = Array.from(nonSelectedNodeRects).sort((a, b) => a.top - b.top)
for (const rect of nodeRectsSorted) {
if (initialRect.intersectsX(rect) && rect.bottom + verticalGap > top) {
if (rect.top - (top + nodeSize.y) < verticalGap) {
top = rect.bottom + verticalGap
}
}
}
const finalPosition = new Vec2(initialPosition.x, top)
if (new Rect(finalPosition, nodeSize).within(screenBounds)) {
return { position: finalPosition }
} else {
return { position: finalPosition, pan: finalPosition.sub(initialPosition) }
}
}
18 changes: 15 additions & 3 deletions app/gui2/src/components/GraphEditor.vue
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import { useGraphStore } from '@/stores/graph'
import type { RequiredImport } from '@/stores/graph/imports'
import { useProjectStore } from '@/stores/project'
import { groupColorVar, useSuggestionDbStore } from '@/stores/suggestionDatabase'
import { assert, bail } from '@/util/assert'
import { BodyBlock } from '@/util/ast/abstract'
import { colorFromString } from '@/util/colors'
import { Rect } from '@/util/data/rect'
import { Vec2 } from '@/util/data/vec2'
Expand Down Expand Up @@ -255,11 +257,21 @@ const graphBindingsHandler = graphBindings.handler({
},
collapse() {
if (keyboardBusy()) return false
const selected = nodeSelection.selected
const selected = new Set(nodeSelection.selected)
if (selected.size == 0) return
try {
const info = prepareCollapsedInfo(nodeSelection.selected, graphStore.db)
performCollapse(info)
const info = prepareCollapsedInfo(selected, graphStore.db)
const currentMethod = projectStore.executionContext.getStackTop()
const currentMethodName = graphStore.db.stackItemToMethodName(currentMethod)
if (currentMethodName == null) {
bail(`Cannot get the method name for the current execution stack item. ${currentMethod}`)
}
graphStore.editAst((module) => {
if (graphStore.moduleRoot == null) bail(`Module root is missing.`)
const topLevel = module.get(graphStore.moduleRoot)
assert(topLevel instanceof BodyBlock)
return performCollapse(info, module, topLevel, graphStore.db, currentMethodName)
})
} catch (err) {
console.log(`Error while collapsing, this is not normal. ${err}`)
}
Expand Down
4 changes: 3 additions & 1 deletion app/gui2/src/components/GraphEditor/GraphNode.vue
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import { Vec2 } from '@/util/data/vec2'
import { displayedIconOf } from '@/util/getIconName'
import { setIfUndefined } from 'lib0/map'
import type { ExprId, VisualizationIdentifier } from 'shared/yjsModel'
import { computed, ref, watch, watchEffect } from 'vue'
import { computed, onUnmounted, ref, watch, watchEffect } from 'vue'

const MAXIMUM_CLICK_LENGTH_MS = 300
const MAXIMUM_CLICK_DISTANCE_SQ = 50
Expand Down Expand Up @@ -73,6 +73,8 @@ const outputPortsSet = computed(() => {
const widthOverridePx = ref<number>()
const nodeId = computed(() => props.node.rootSpan.exprId)

onUnmounted(() => graph.unregisterNodeRect(nodeId.value))

const rootNode = ref<HTMLElement>()
const contentNode = ref<HTMLElement>()
const nodeSize = useResizeObserver(rootNode)
Expand Down
136 changes: 112 additions & 24 deletions app/gui2/src/components/GraphEditor/collapsing.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { GraphDb } from '@/stores/graph/graphDatabase'
import { Ast } from '@/util/ast'
import { moduleMethodNames } from '@/util/ast/abstract'
import { unwrap } from '@/util/data/result'
import { tryIdentifier, type Identifier } from '@/util/qualifiedName'
import assert from 'assert'
Expand Down Expand Up @@ -40,8 +41,8 @@ interface RefactoredInfo {
id: ExprId
/** The pattern of the refactored node. Included for convinience, collapsing does not affect it. */
pattern: string
/** The new expression of the refactored node. A call to the extracted function with the list of necessary arguments. */
expression: string
/** The list of necessary arguments for a call of the collapsed function. */
arguments: Identifier[]
}

// === prepareCollapsedInfo ===
Expand All @@ -55,19 +56,20 @@ export function prepareCollapsedInfo(selected: Set<ExprId>, graphDb: GraphDb): C
const leaves = new Set([...selected])
const inputs: Identifier[] = []
let output: Output | null = null
for (const [targetExprId, sourceExprIds] of graphDb.connections.allReverse()) {
for (const [targetExprId, sourceExprIds] of graphDb.allConnections.allReverse()) {
const target = graphDb.getExpressionNodeId(targetExprId)
if (target == null) throw new Error(`Connection target node for id ${targetExprId} not found.`)
if (target == null) continue
for (const sourceExprId of sourceExprIds) {
const source = graphDb.getPatternExpressionNodeId(sourceExprId)
if (source == null)
throw new Error(`Connection source node for id ${sourceExprId} not found.`)
const startsInside = selected.has(source)
const startsInside = source != null && selected.has(source)
const endsInside = selected.has(target)
const stringIdentifier = graphDb.getOutputPortIdentifier(sourceExprId)
if (stringIdentifier == null) throw new Error(`Source node (${source}) has no pattern.`)
if (stringIdentifier == null)
throw new Error(`Source node (${source}) has no output identifier.`)
const identifier = unwrap(tryIdentifier(stringIdentifier))
leaves.delete(source)
if (source != null) {
leaves.delete(source)
}
if (!startsInside && endsInside) {
inputs.push(identifier)
} else if (startsInside && !endsInside) {
Expand Down Expand Up @@ -105,21 +107,107 @@ export function prepareCollapsedInfo(selected: Set<ExprId>, graphDb: GraphDb): C
refactored: {
id: output.node,
pattern,
expression: 'Main.collapsed' + (inputs.length > 0 ? ' ' : '') + inputs.join(' '),
arguments: inputs,
},
}
}

// === performRefactoring ===
/** Generate a safe method name for a collapsed function using `baseName` as a prefix. */
function findSafeMethodName(module: Ast.Module, baseName: string): string {
const allIdentifiers = moduleMethodNames(module)
if (!allIdentifiers.has(baseName)) {
return baseName
}
let index = 1
while (allIdentifiers.has(`${baseName}${index}`)) {
index++
}
return `${baseName}${index}`
}

// === performCollapse ===

// We support working inside `Main` module of the project at the moment.
const MODULE_NAME = 'Main'
const COLLAPSED_FUNCTION_NAME = 'collapsed'

/** Perform the actual AST refactoring for collapsing nodes. */
export function performCollapse(_info: CollapsedInfo) {
// The general flow of this function:
// 1. Create a new function with a unique name and a list of arguments from the `ExtractedInfo`.
// 2. Move all nodes with `ids` from the `ExtractedInfo` into this new function. Use the order of their original definition.
// 3. Use a single identifier `output.identifier` as the return value of the function.
// 4. Change the expression of the `RefactoredInfo.id` node to the `RefactoredINfo.expression`
throw new Error('Not yet implemented, requires AST editing.')
export function performCollapse(
info: CollapsedInfo,
module: Ast.Module,
topLevel: Ast.BodyBlock,
db: GraphDb,
currentMethodName: string,
): Ast.MutableModule {
const functionAst = Ast.findModuleMethod(module, currentMethodName)
if (!(functionAst instanceof Ast.Function) || !(functionAst.body instanceof Ast.BodyBlock)) {
throw new Error(`Expected a collapsable function, found ${functionAst}.`)
}
const functionBlock = functionAst.body
const posToInsert = findInsertionPos(module, topLevel, currentMethodName)
const collapsedName = findSafeMethodName(module, COLLAPSED_FUNCTION_NAME)
const astIdsToExtract = new Set(
[...info.extracted.ids].map((nodeId) => db.nodeIdToNode.get(nodeId)?.outerExprId),
)
const astIdToReplace = db.nodeIdToNode.get(info.refactored.id)?.outerExprId
const collapsed = []
const refactored = []
const edit = module.edit()
const lines = functionBlock.lines()
for (const line of lines) {
const astId = line.expression?.node.exprId
const ast = astId != null ? module.get(astId) : null
if (ast == null) continue
if (astIdsToExtract.has(astId)) {
collapsed.push(ast)
if (astId === astIdToReplace) {
const newAst = collapsedCallAst(info, collapsedName, edit)
refactored.push({ expression: { node: newAst.exprId } })
}
} else {
refactored.push({ expression: { node: ast.exprId } })
}
}
const outputIdentifier = info.extracted.output?.identifier
if (outputIdentifier != null) {
collapsed.push(Ast.Ident.new(edit, outputIdentifier))
}
// Update the definiton of refactored function.
new Ast.BodyBlock(edit, functionBlock.exprId, refactored)

const args: Ast.Ast[] = info.extracted.inputs.map((arg) => Ast.Ident.new(edit, arg))
const collapsedFunction = Ast.Function.new(edit, collapsedName, args, collapsed, true)
topLevel.insert(edit, posToInsert, collapsedFunction)
return edit
}

/** Prepare a method call expression for collapsed method. */
function collapsedCallAst(
info: CollapsedInfo,
collapsedName: string,
edit: Ast.MutableModule,
): Ast.Ast {
const pattern = info.refactored.pattern
const args = info.refactored.arguments
const functionName = `${MODULE_NAME}.${collapsedName}`
const expression = functionName + (args.length > 0 ? ' ' : '') + args.join(' ')
const assignment = Ast.Assignment.new(edit, pattern, Ast.parse(expression, edit))
return assignment
}

/** Find the position before the current method to insert a collapsed one. */
function findInsertionPos(
module: Ast.Module,
topLevel: Ast.BodyBlock,
currentMethodName: string,
): number {
const currentFuncPosition = topLevel.lines().findIndex((line) => {
const node = line.expression?.node
const expr = node ? module.get(node.exprId)?.innerExpression() : null
return expr instanceof Ast.Function && expr.name?.code() === currentMethodName
})

return currentFuncPosition === -1 ? 0 : currentFuncPosition
}

// === Tests ===
Expand Down Expand Up @@ -148,7 +236,7 @@ if (import.meta.vitest) {
}
refactored: {
replace: string
with: { pattern: string; expression: string }
with: { pattern: string; arguments: string[] }
}
}
}
Expand All @@ -166,7 +254,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'c = A + B',
with: { pattern: 'c', expression: 'Main.collapsed a' },
with: { pattern: 'c', arguments: ['a'] },
},
},
},
Expand All @@ -182,7 +270,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'd = a + b',
with: { pattern: 'd', expression: 'Main.collapsed a b' },
with: { pattern: 'd', arguments: ['a', 'b'] },
},
},
},
Expand All @@ -198,7 +286,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'c = 50 + d',
with: { pattern: 'c', expression: 'Main.collapsed' },
with: { pattern: 'c', arguments: [] },
},
},
},
Expand All @@ -219,7 +307,7 @@ if (import.meta.vitest) {
},
refactored: {
replace: 'vector = range.to_vector',
with: { pattern: 'vector', expression: 'Main.collapsed number1 number2' },
with: { pattern: 'vector', arguments: ['number1', 'number2'] },
},
},
},
Expand Down Expand Up @@ -261,6 +349,6 @@ if (import.meta.vitest) {
expect(extracted.ids).toEqual(new Set(expectedIds))
expect(refactored.id).toEqual(expectedRefactoredId)
expect(refactored.pattern).toEqual(expectedRefactored.with.pattern)
expect(refactored.expression).toEqual(expectedRefactored.with.expression)
expect(refactored.arguments).toEqual(expectedRefactored.with.arguments)
})
}
Loading
Loading