diff --git a/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoDerivedComputationsInEffects_exp.ts b/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoDerivedComputationsInEffects_exp.ts index 28d19fb1db..7eb820d87d 100644 --- a/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoDerivedComputationsInEffects_exp.ts +++ b/compiler/packages/babel-plugin-react-compiler/src/Validation/ValidateNoDerivedComputationsInEffects_exp.ts @@ -21,7 +21,6 @@ import { isUseStateType, BasicBlock, isUseRefType, - GeneratedSource, SourceLocation, } from '../HIR'; import {eachInstructionLValue, eachInstructionOperand} from '../HIR/visitors'; @@ -41,8 +40,8 @@ type ValidationContext = { readonly errors: CompilerError; readonly derivationCache: DerivationCache; readonly effects: Set; - readonly setStateCache: Map>; - readonly effectSetStateCache: Map>; + readonly setStateLoads: Map; + readonly setStateUsages: Map>; }; class DerivationCache { @@ -182,19 +181,16 @@ export function validateNoDerivedComputationsInEffects_exp( const errors = new CompilerError(); const effects: Set = new Set(); - const setStateCache: Map> = new Map(); - const effectSetStateCache: Map< - string | undefined | null, - Array - > = new Map(); + const setStateLoads: Map = new Map(); + const setStateUsages: Map> = new Map(); const context: ValidationContext = { functions, errors, derivationCache, effects, - setStateCache, - effectSetStateCache, + setStateLoads, + setStateUsages, }; if (fn.fnType === 'Hook') { @@ -284,11 +280,60 @@ function joinValue( return 'fromPropsAndState'; } +function getRootSetState( + key: IdentifierId, + loads: Map, + visited: Set = new Set(), +): IdentifierId | null { + if (visited.has(key)) { + return null; + } + visited.add(key); + + const parentId = loads.get(key); + + if (parentId === undefined) { + return null; + } + + if (parentId === null) { + return key; + } + + return getRootSetState(parentId, loads, visited); +} + +function maybeRecordSetState( + instr: Instruction, + loads: Map, + usages: Map>, +): void { + for (const operand of eachInstructionLValue(instr)) { + if (isSetStateType(operand.identifier)) { + if (instr.value.kind === 'LoadLocal') { + loads.set(operand.identifier.id, instr.value.place.identifier.id); + } else { + // this is a root setState + loads.set(operand.identifier.id, null); + } + + const rootSetState = getRootSetState(operand.identifier.id, loads); + if (rootSetState !== null && usages.get(rootSetState) === undefined) { + usages.set(rootSetState, new Set([operand.loc])); + } + } + } +} + function recordInstructionDerivations( instr: Instruction, context: ValidationContext, isFirstPass: boolean, ): void { + if (isFirstPass) { + maybeRecordSetState(instr, context.setStateLoads, context.setStateUsages); + } + let typeOfValue: TypeOfValue = 'ignored'; const sources: Set = new Set(); const {lvalue, value} = instr; @@ -323,15 +368,13 @@ function recordInstructionDerivations( } for (const operand of eachInstructionOperand(instr)) { - if ( - isSetStateType(operand.identifier) && - operand.loc !== GeneratedSource && - isFirstPass - ) { - if (context.setStateCache.has(operand.loc.identifierName)) { - context.setStateCache.get(operand.loc.identifierName)!.push(operand); - } else { - context.setStateCache.set(operand.loc.identifierName, [operand]); + if (isSetStateType(operand.identifier) && isFirstPass) { + const rootSetStateId = getRootSetState( + operand.identifier.id, + context.setStateLoads, + ); + if (rootSetStateId !== null) { + context.setStateUsages.get(rootSetStateId)?.add(operand.loc); } } @@ -482,11 +525,16 @@ function validateEffect( const effectDerivedSetStateCalls: Array<{ value: CallExpression; - loc: SourceLocation; + id: IdentifierId; sourceIds: Set; typeOfValue: TypeOfValue; }> = []; + const effectSetStateUsages: Map< + IdentifierId, + Set + > = new Map(); + const globals: Set = new Set(); for (const block of effectFunction.body.blocks.values()) { for (const pred of block.preds) { @@ -502,19 +550,16 @@ function validateEffect( return; } + maybeRecordSetState(instr, context.setStateLoads, effectSetStateUsages); + for (const operand of eachInstructionOperand(instr)) { - if ( - isSetStateType(operand.identifier) && - operand.loc !== GeneratedSource - ) { - if (context.effectSetStateCache.has(operand.loc.identifierName)) { - context.effectSetStateCache - .get(operand.loc.identifierName)! - .push(operand); - } else { - context.effectSetStateCache.set(operand.loc.identifierName, [ - operand, - ]); + if (isSetStateType(operand.identifier)) { + const rootSetStateId = getRootSetState( + operand.identifier.id, + context.setStateLoads, + ); + if (rootSetStateId !== null) { + effectSetStateUsages.get(rootSetStateId)?.add(operand.loc); } } } @@ -532,7 +577,7 @@ function validateEffect( if (argMetadata !== undefined) { effectDerivedSetStateCalls.push({ value: instr.value, - loc: instr.value.callee.loc, + id: instr.value.callee.identifier.id, sourceIds: argMetadata.sourcesIds, typeOfValue: argMetadata.typeOfValue, }); @@ -566,15 +611,17 @@ function validateEffect( } for (const derivedSetStateCall of effectDerivedSetStateCalls) { + const rootSetStateCall = getRootSetState( + derivedSetStateCall.id, + context.setStateLoads, + ); + if ( - derivedSetStateCall.loc !== GeneratedSource && - context.effectSetStateCache.has(derivedSetStateCall.loc.identifierName) && - context.setStateCache.has(derivedSetStateCall.loc.identifierName) && - context.effectSetStateCache.get(derivedSetStateCall.loc.identifierName)! - .length === - context.setStateCache.get(derivedSetStateCall.loc.identifierName)! - .length - - 1 + rootSetStateCall !== null && + effectSetStateUsages.has(rootSetStateCall) && + context.setStateUsages.has(rootSetStateCall) && + effectSetStateUsages.get(rootSetStateCall)!.size === + context.setStateUsages.get(rootSetStateCall)!.size - 1 ) { const allSourceIds = Array.from(derivedSetStateCall.sourceIds); const propsSet = new Set();