Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[compiler] Add lowerContextAccess pass #30548

Merged
merged 12 commits into from
Aug 7, 2024
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ import {
import {validateLocalsNotReassignedAfterRender} from '../Validation/ValidateLocalsNotReassignedAfterRender';
import {outlineFunctions} from '../Optimization/OutlineFunctions';
import {propagatePhiTypes} from '../TypeInference/PropagatePhiTypes';
import {lowerContextAccess} from '../Optimization/LowerContextAccess';

export type CompilerPipelineValue =
| {kind: 'ast'; name: string; value: CodegenFunction}
Expand Down Expand Up @@ -204,6 +205,10 @@ function* runWithEnvironment(
validateNoCapitalizedCalls(hir);
}

if (env.config.enableLowerContextAccess) {
lowerContextAccess(hir);
}

analyseFunctions(hir);
yield log({kind: 'hir', name: 'AnalyseFunctions', value: hir});

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
/**
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/

import {
ArrayExpression,
BasicBlock,
CallExpression,
Destructure,
Environment,
GeneratedSource,
HIRFunction,
IdentifierId,
Instruction,
LoadLocal,
Place,
PropertyLoad,
isUseContextHookType,
makeBlockId,
makeInstructionId,
markInstructionIds,
promoteTemporary,
reversePostorderBlocks,
} from '../HIR';
import {createTemporaryPlace} from '../HIR/HIRBuilder';
import {enterSSA} from '../SSA';
import {inferTypes} from '../TypeInference';

export function lowerContextAccess(fn: HIRFunction): void {
const contextAccess: Map<IdentifierId, CallExpression> = new Map();
const contextKeys: Map<IdentifierId, Array<string>> = new Map();

// collect context access and keys
for (const [, block] of fn.body.blocks) {
for (const instr of block.instructions) {
const {value, lvalue} = instr;

if (
value.kind === 'CallExpression' &&
isUseContextHookType(value.callee.identifier)
) {
contextAccess.set(lvalue.identifier.id, value);
continue;
}

if (value.kind !== 'Destructure') {
continue;
}

const destructureId = value.value.identifier.id;
if (!contextAccess.has(destructureId)) {
continue;
}

const keys = getContextKeys(value);
if (keys === null) {
return;
}

if (contextKeys.has(destructureId)) {
/*
* TODO(gsn): Add support for accessing context over multiple
* statements.
*/
gsathya marked this conversation as resolved.
Show resolved Hide resolved
return;
} else {
contextKeys.set(destructureId, keys);
}
}
}

if (contextAccess.size > 0) {
for (const [, block] of fn.body.blocks) {
let nextInstructions: Array<Instruction> | null = null;

for (let i = 0; i < block.instructions.length; i++) {
const instr = block.instructions[i];
const {lvalue, value} = instr;
if (
value.kind === 'CallExpression' &&
isUseContextHookType(value.callee.identifier) &&
contextKeys.has(lvalue.identifier.id)
) {
const keys = contextKeys.get(lvalue.identifier.id)!;
const selectorFnInstr = emitSelectorFn(fn.env, keys);
if (nextInstructions === null) {
nextInstructions = block.instructions.slice(0, i);
}
nextInstructions.push(selectorFnInstr);

const selectorFn = selectorFnInstr.lvalue;
value.args.push(selectorFn);
}

if (nextInstructions) {
nextInstructions.push(instr);
}
}
if (nextInstructions) {
block.instructions = nextInstructions;
}
}
markInstructionIds(fn.body);
}
}

function getContextKeys(value: Destructure): Array<string> | null {
const keys = [];
const pattern = value.lvalue.pattern;

switch (pattern.kind) {
case 'ArrayPattern': {
return null;
}

case 'ObjectPattern': {
for (const place of pattern.properties) {
if (
place.kind !== 'ObjectProperty' ||
place.type !== 'property' ||
place.key.kind !== 'identifier' ||
place.place.identifier.name === null ||
place.place.identifier.name.kind !== 'named'
) {
return null;
}
keys.push(place.key.name);
}
return keys;
}
}
}

function emitPropertyLoad(
env: Environment,
obj: Place,
property: string,
): {instructions: Array<Instruction>; element: Place} {
const loadObj: LoadLocal = {
kind: 'LoadLocal',
place: obj,
loc: GeneratedSource,
};
const object: Place = createTemporaryPlace(env, GeneratedSource);
const loadLocalInstr: Instruction = {
lvalue: object,
value: loadObj,
id: makeInstructionId(0),
loc: GeneratedSource,
};

const loadProp: PropertyLoad = {
kind: 'PropertyLoad',
object,
property,
loc: GeneratedSource,
};
const element: Place = createTemporaryPlace(env, GeneratedSource);
const loadPropInstr: Instruction = {
lvalue: element,
value: loadProp,
id: makeInstructionId(0),
loc: GeneratedSource,
};
return {
instructions: [loadLocalInstr, loadPropInstr],
element: element,
};
}

function emitSelectorFn(env: Environment, keys: Array<string>): Instruction {
const obj: Place = createTemporaryPlace(env, GeneratedSource);
promoteTemporary(obj.identifier);
const instr: Array<Instruction> = [];
const elements = [];
for (const key of keys) {
const {instructions, element: prop} = emitPropertyLoad(env, obj, key);
instr.push(...instructions);
elements.push(prop);
}

const arrayInstr = emitArrayInstr(elements, env);
instr.push(arrayInstr);

const block: BasicBlock = {
kind: 'block',
id: makeBlockId(0),
instructions: instr,
terminal: {
id: makeInstructionId(0),
kind: 'return',
loc: GeneratedSource,
value: arrayInstr.lvalue,
},
preds: new Set(),
phis: new Set(),
};

const fn: HIRFunction = {
loc: GeneratedSource,
id: null,
fnType: 'Other',
env,
params: [obj],
returnType: null,
context: [],
effects: null,
body: {
entry: block.id,
blocks: new Map([[block.id, block]]),
},
generator: false,
async: false,
directives: [],
};

reversePostorderBlocks(fn.body);
markInstructionIds(fn.body);
enterSSA(fn);
inferTypes(fn);

const fnInstr: Instruction = {
id: makeInstructionId(0),
value: {
kind: 'FunctionExpression',
name: null,
loweredFunc: {
func: fn,
dependencies: [],
},
type: 'ArrowFunctionExpression',
loc: GeneratedSource,
},
lvalue: createTemporaryPlace(env, GeneratedSource),
loc: GeneratedSource,
};
return fnInstr;
}

function emitArrayInstr(elements: Array<Place>, env: Environment): Instruction {
const array: ArrayExpression = {
kind: 'ArrayExpression',
elements,
loc: GeneratedSource,
};
const arrayLvalue: Place = createTemporaryPlace(env, GeneratedSource);
const arrayInstr: Instruction = {
id: makeInstructionId(0),
value: array,
lvalue: arrayLvalue,
loc: GeneratedSource,
};
return arrayInstr;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@

## Input

```javascript
// @enableLowerContextAccess
function App() {
const {foo} = useContext(MyContext);
const {bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}

```

## Code

```javascript
import { c as _c } from "react/compiler-runtime"; // @enableLowerContextAccess
function App() {
const $ = _c(3);
const { foo } = useContext(MyContext, _temp);
const { bar } = useContext(MyContext, _temp2);
let t0;
if ($[0] !== foo || $[1] !== bar) {
t0 = <Bar foo={foo} bar={bar} />;
$[0] = foo;
$[1] = bar;
$[2] = t0;
} else {
t0 = $[2];
}
return t0;
}
function _temp2(t0) {
return [t0.bar];
}
function _temp(t0) {
return [t0.foo];
}

```

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
// @enableLowerContextAccess
function App() {
const {foo} = useContext(MyContext);
const {bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

## Input

```javascript
// @enableLowerContextAccess
function App() {
const {foo, bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}

```

## Code

```javascript
import { c as _c } from "react/compiler-runtime"; // @enableLowerContextAccess
function App() {
const $ = _c(3);
const { foo, bar } = useContext(MyContext, _temp);
let t0;
if ($[0] !== foo || $[1] !== bar) {
t0 = <Bar foo={foo} bar={bar} />;
$[0] = foo;
$[1] = bar;
$[2] = t0;
} else {
t0 = $[2];
}
return t0;
}
function _temp(t0) {
return [t0.foo, t0.bar];
}
Comment on lines +19 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are you thinking about sequencing the next steps here? Asking because the generated code doesn't seem like it could work yet: the temp function returns an array, but the context value is destructured into an object and there's nothing that tells us how to map the array into the object.

Seems like we'd need to convert the destructuring into an ArrayPattern to match? Also, if we take Jack's suggestion of compiling into a custom useContextSelector-style hook, then we can implement a basic version of that in shared-runtime and this could be an end-to-end test with sprout.

Copy link
Contributor Author

@gsathya gsathya Aug 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The return value of the selector isn't returned from the useContext call. The runtime simply compares the return values from the selector to determine if the context is dirty. If the values are different, then it returns the context object from useContext call, not the values from the selector function.

Ideally we'd just compile the selector function with Forget and compare the result, rather than iterate over an array, but that's optimisation for the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahhhhh got it, makes sense


```

Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
// @enableLowerContextAccess
function App() {
const {foo, bar} = useContext(MyContext);
return <Bar foo={foo} bar={bar} />;
}
Loading