diff --git a/src/resolver/__tests__/findAllComponentDefinitions-test.js b/src/resolver/__tests__/findAllComponentDefinitions-test.js
index 70d637e944f..cac532294cd 100644
--- a/src/resolver/__tests__/findAllComponentDefinitions-test.js
+++ b/src/resolver/__tests__/findAllComponentDefinitions-test.js
@@ -189,4 +189,44 @@ describe('findAllComponentDefinitions', () => {
expect(result.length).toBe(0);
});
});
+
+ describe('forwardRef components', () => {
+ it('finds forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ const ColoredView = React.forwardRef((props, ref) => (
+
+ ));
+
+ extendStyles(ColoredView);
+ `;
+
+ const result = parse(source);
+ expect(Array.isArray(result)).toBe(true);
+ expect(result.length).toBe(1);
+ expect(result[0].value.type).toEqual('CallExpression');
+ });
+
+ it('finds none inline forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ function ColoredView(props, ref) {
+ return
+ }
+
+ const ForwardedColoredView = React.forwardRef(ColoredView);
+ `;
+
+ const result = parse(source);
+ expect(Array.isArray(result)).toBe(true);
+ expect(result.length).toBe(1);
+ expect(result[0].value.type).toEqual('CallExpression');
+ });
+ });
});
diff --git a/src/resolver/__tests__/findAllExportedComponentDefinitions-test.js b/src/resolver/__tests__/findAllExportedComponentDefinitions-test.js
index 1b0a56dc9e2..f03bd882bc9 100644
--- a/src/resolver/__tests__/findAllExportedComponentDefinitions-test.js
+++ b/src/resolver/__tests__/findAllExportedComponentDefinitions-test.js
@@ -143,6 +143,50 @@ describe('findAllExportedComponentDefinitions', () => {
});
});
+ describe('forwardRef components', () => {
+ it('finds forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ const ColoredView = React.forwardRef((props, ref) => (
+
+ ));
+
+ module.exports = extendStyles(ColoredView);
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+
+ it('finds none inline forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ function ColoredView(props, ref) {
+ return
+ }
+
+ const ForwardedColoredView = React.forwardRef(ColoredView);
+
+ module.exports = ForwardedColoredView
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+ });
+
describe('module.exports = ; / exports.foo = ;', () => {
describe('React.createClass', () => {
it('finds assignments to exports', () => {
@@ -486,6 +530,50 @@ describe('findAllExportedComponentDefinitions', () => {
expect(actual.length).toBe(1);
});
});
+
+ describe('forwardRef components', () => {
+ it('finds forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ const ColoredView = React.forwardRef((props, ref) => (
+
+ ));
+
+ export default extendStyles(ColoredView);
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+
+ it('finds none inline forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ function ColoredView(props, ref) {
+ return
+ }
+
+ const ForwardedColoredView = React.forwardRef(ColoredView);
+
+ export default ForwardedColoredView
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+ });
});
describe('export var foo = , ...;', () => {
@@ -734,6 +822,26 @@ describe('findAllExportedComponentDefinitions', () => {
expect(actual[0].node.type).toBe('FunctionExpression');
});
});
+
+ describe('forwardRef components', () => {
+ it('finds forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ export const ColoredView = extendStyles(React.forwardRef((props, ref) => (
+
+ )));
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+ });
});
describe('export {};', () => {
@@ -994,6 +1102,28 @@ describe('findAllExportedComponentDefinitions', () => {
expect(actual[0].node.type).toBe('ArrowFunctionExpression');
});
});
+
+ describe('forwardRef components', () => {
+ it('finds forwardRef components', () => {
+ const source = `
+ import React from 'react';
+ import PropTypes from 'prop-types';
+ import extendStyles from 'enhancers/extendStyles';
+
+ const ColoredView = extendStyles(React.forwardRef((props, ref) => (
+
+ )));
+
+ export { ColoredView }
+ `;
+
+ const parsed = parse(source);
+ const actual = findComponents(parsed);
+
+ expect(actual.length).toBe(1);
+ expect(actual[0].value.type).toEqual('CallExpression');
+ });
+ });
});
describe('export ;', () => {
diff --git a/src/resolver/findAllComponentDefinitions.js b/src/resolver/findAllComponentDefinitions.js
index 57a5dc88296..235b5d07680 100644
--- a/src/resolver/findAllComponentDefinitions.js
+++ b/src/resolver/findAllComponentDefinitions.js
@@ -12,6 +12,7 @@
import isReactComponentClass from '../utils/isReactComponentClass';
import isReactCreateClassCall from '../utils/isReactCreateClassCall';
+import isReactForwardRefCall from '../utils/isReactForwardRefCall';
import isStatelessComponent from '../utils/isStatelessComponent';
import normalizeClassDefinition from '../utils/normalizeClassDefinition';
import resolveToValue from '../utils/resolveToValue';
@@ -25,19 +26,19 @@ export default function findAllReactCreateClassCalls(
recast: Object,
): Array {
const types = recast.types.namedTypes;
- const definitions = [];
+ const definitions = new Set();
function classVisitor(path) {
if (isReactComponentClass(path)) {
normalizeClassDefinition(path);
- definitions.push(path);
+ definitions.add(path);
}
return false;
}
function statelessVisitor(path) {
if (isStatelessComponent(path)) {
- definitions.push(path);
+ definitions.add(path);
}
return false;
}
@@ -49,16 +50,21 @@ export default function findAllReactCreateClassCalls(
visitClassExpression: classVisitor,
visitClassDeclaration: classVisitor,
visitCallExpression: function(path) {
- if (!isReactCreateClassCall(path)) {
- return false;
- }
- const resolvedPath = resolveToValue(path.get('arguments', 0));
- if (types.ObjectExpression.check(resolvedPath.node)) {
- definitions.push(resolvedPath);
+ if (isReactForwardRefCall(path)) {
+ // If the the inner function was previously identified as a component
+ // replace it with the parent node
+ const inner = resolveToValue(path.get('arguments', 0));
+ definitions.delete(inner);
+ definitions.add(path);
+ } else if (isReactCreateClassCall(path)) {
+ const resolvedPath = resolveToValue(path.get('arguments', 0));
+ if (types.ObjectExpression.check(resolvedPath.node)) {
+ definitions.add(resolvedPath);
+ }
}
return false;
},
});
- return definitions;
+ return Array.from(definitions);
}
diff --git a/src/resolver/findAllExportedComponentDefinitions.js b/src/resolver/findAllExportedComponentDefinitions.js
index 9e3f561122d..e634845709d 100644
--- a/src/resolver/findAllExportedComponentDefinitions.js
+++ b/src/resolver/findAllExportedComponentDefinitions.js
@@ -12,6 +12,7 @@
import isExportsOrModuleAssignment from '../utils/isExportsOrModuleAssignment';
import isReactComponentClass from '../utils/isReactComponentClass';
import isReactCreateClassCall from '../utils/isReactCreateClassCall';
+import isReactForwardRefCall from '../utils/isReactForwardRefCall';
import isStatelessComponent from '../utils/isStatelessComponent';
import normalizeClassDefinition from '../utils/normalizeClassDefinition';
import resolveExportDeclaration from '../utils/resolveExportDeclaration';
@@ -26,7 +27,8 @@ function isComponentDefinition(path) {
return (
isReactCreateClassCall(path) ||
isReactComponentClass(path) ||
- isStatelessComponent(path)
+ isStatelessComponent(path) ||
+ isReactForwardRefCall(path)
);
}
@@ -40,7 +42,10 @@ function resolveDefinition(definition, types): ?NodePath {
} else if (isReactComponentClass(definition)) {
normalizeClassDefinition(definition);
return definition;
- } else if (isStatelessComponent(definition)) {
+ } else if (
+ isStatelessComponent(definition) ||
+ isReactForwardRefCall(definition)
+ ) {
return definition;
}
return null;
diff --git a/src/resolver/findExportedComponentDefinition.js b/src/resolver/findExportedComponentDefinition.js
index c48594f4f6e..f0ff0acad36 100644
--- a/src/resolver/findExportedComponentDefinition.js
+++ b/src/resolver/findExportedComponentDefinition.js
@@ -10,9 +10,9 @@
*/
import isExportsOrModuleAssignment from '../utils/isExportsOrModuleAssignment';
-import isReactForwardRefCall from '../utils/isReactForwardRefCall';
import isReactComponentClass from '../utils/isReactComponentClass';
import isReactCreateClassCall from '../utils/isReactCreateClassCall';
+import isReactForwardRefCall from '../utils/isReactForwardRefCall';
import isStatelessComponent from '../utils/isStatelessComponent';
import normalizeClassDefinition from '../utils/normalizeClassDefinition';
import resolveExportDeclaration from '../utils/resolveExportDeclaration';
@@ -45,9 +45,10 @@ function resolveDefinition(definition, types) {
} else if (isReactComponentClass(definition)) {
normalizeClassDefinition(definition);
return definition;
- } else if (isStatelessComponent(definition)) {
- return definition;
- } else if (isReactForwardRefCall(definition)) {
+ } else if (
+ isStatelessComponent(definition) ||
+ isReactForwardRefCall(definition)
+ ) {
return definition;
}
return null;
diff --git a/src/utils/resolveHOC.js b/src/utils/resolveHOC.js
index efb28520945..564aac28801 100644
--- a/src/utils/resolveHOC.js
+++ b/src/utils/resolveHOC.js
@@ -12,6 +12,7 @@
import recast from 'recast';
import isReactCreateClassCall from './isReactCreateClassCall';
+import isReactForwardRefCall from './isReactForwardRefCall';
const {
types: { NodePath, namedTypes: types },
@@ -25,7 +26,11 @@ const {
*/
export default function resolveHOC(path: NodePath): NodePath {
const node = path.node;
- if (types.CallExpression.check(node) && !isReactCreateClassCall(path)) {
+ if (
+ types.CallExpression.check(node) &&
+ !isReactCreateClassCall(path) &&
+ !isReactForwardRefCall(path)
+ ) {
if (node.arguments.length) {
return resolveHOC(path.get('arguments', node.arguments.length - 1));
}