diff --git a/packages/react/src/ReactForwardRef.js b/packages/react/src/ReactForwardRef.js index 4f891d609f..f11ce4deba 100644 --- a/packages/react/src/ReactForwardRef.js +++ b/packages/react/src/ReactForwardRef.js @@ -9,6 +9,10 @@ import {REACT_FORWARD_REF_TYPE, REACT_MEMO_TYPE} from 'shared/ReactSymbols'; +export function isForwardRef(type: mixed): boolean %checks { + return typeof type === 'object' && type !== null && type.$$typeof === REACT_FORWARD_REF_TYPE; +} + export function forwardRef( render: ( props: Props, diff --git a/packages/react/src/ReactMemo.js b/packages/react/src/ReactMemo.js index 8f2c0f5382..fcc3450d22 100644 --- a/packages/react/src/ReactMemo.js +++ b/packages/react/src/ReactMemo.js @@ -7,11 +7,12 @@ * @noflow */ -import {REACT_MEMO_TYPE} from 'shared/ReactSymbols'; +import {REACT_MEMO_TYPE, REACT_FORWARD_REF_TYPE} from 'shared/ReactSymbols'; +import shallowEqual from 'shared/shallowEqual'; export function memo( type: React$ElementType, - compare?: (oldProps: Props, newProps: Props) => boolean, + compare?: (oldProps: Props, newProps: Props, oldRef: mixed, newRef: mixed) => boolean, ) { if (__DEV__) { if (type == null) { @@ -22,10 +23,22 @@ export function memo( ); } } + + // Create custom compare function that includes ref for forwardRef components + const isForwardRefComponent = typeof type === 'object' && type !== null && type.$$typeof === REACT_FORWARD_REF_TYPE; + let finalCompare = compare; + + if (isForwardRefComponent && compare === undefined) { + // Default compare for forwardRef: shallow equal props + strict equal ref + finalCompare = function compareWithRef(oldProps, newProps, oldRef, newRef) { + return shallowEqual(oldProps, newProps) && oldRef === newRef; + }; + } + const elementType = { $$typeof: REACT_MEMO_TYPE, type, - compare: compare === undefined ? null : compare, + compare: finalCompare === undefined ? null : finalCompare, }; if (__DEV__) { let ownName;