88import PropTypes , { type Validator } from 'prop-types' ;
99import React , {
1010 cloneElement ,
11+ useCallback ,
1112 useContext ,
1213 useEffect ,
1314 useRef ,
@@ -34,6 +35,7 @@ import { useMergedRefs } from '../../internal/useMergedRefs';
3435import { usePrefix } from '../../internal/usePrefix' ;
3536import { usePreviousValue } from '../../internal/usePreviousValue' ;
3637import { keys , match } from '../../internal/keyboard' ;
38+ import { selectorTabbable } from '../../internal/keyboard/navigation' ;
3739import { IconButton } from '../IconButton' ;
3840import { noopFn } from '../../internal/noopFn' ;
3941import { Text } from '../Text' ;
@@ -122,7 +124,7 @@ export interface ModalProps extends HTMLAttributes<HTMLDivElement> {
122124 /**
123125 * Provide a ref to return focus to once the modal is closed.
124126 */
125- launcherButtonRef ?: RefObject < HTMLButtonElement | null > ;
127+ launcherButtonRef ?: RefObject < HTMLElement | null > ;
126128
127129 /**
128130 * Specify the description for the loading text
@@ -585,6 +587,17 @@ const ModalDialog = React.forwardRef(function ModalDialog(
585587 }
586588 } , [ open , prefix , enableDialogElement ] ) ;
587589
590+ const focusLauncherButton = useCallback ( ( ) => {
591+ if ( ! launcherButtonRef || ! launcherButtonRef . current ) return ;
592+
593+ const { current : launcherButton } = launcherButtonRef ;
594+ const focusTarget = launcherButton . matches ( selectorTabbable )
595+ ? launcherButton
596+ : launcherButton . querySelector < HTMLElement > ( selectorTabbable ) ;
597+
598+ focusTarget ?. focus ( ) ;
599+ } , [ launcherButtonRef ] ) ;
600+
588601 useEffect ( ( ) => {
589602 if (
590603 ! enableDialogElement &&
@@ -594,23 +607,29 @@ const ModalDialog = React.forwardRef(function ModalDialog(
594607 launcherButtonRef
595608 ) {
596609 setTimeout ( ( ) => {
597- if ( 'current' in launcherButtonRef ) {
598- launcherButtonRef . current ?. focus ( ) ;
599- }
610+ focusLauncherButton ( ) ;
600611 } ) ;
601612 }
602- } , [ open , prevOpen , launcherButtonRef , enableDialogElement , enablePresence ] ) ;
613+ } , [
614+ open ,
615+ prevOpen ,
616+ launcherButtonRef ,
617+ enableDialogElement ,
618+ enablePresence ,
619+ focusLauncherButton ,
620+ ] ) ;
621+
603622 // Focus launcherButtonRef on unmount
604623 useEffect ( ( ) => {
605624 const launcherButton = launcherButtonRef ?. current ;
606625 return ( ) => {
607626 if ( enablePresence && launcherButton ) {
608627 setTimeout ( ( ) => {
609- launcherButton . focus ( ) ;
628+ focusLauncherButton ( ) ;
610629 } ) ;
611630 }
612631 } ;
613- } , [ enablePresence , launcherButtonRef ] ) ;
632+ } , [ enablePresence , launcherButtonRef , focusLauncherButton ] ) ;
614633
615634 useEffect ( ( ) => {
616635 if ( ! enableDialogElement ) {
@@ -964,16 +983,16 @@ Modal.propTypes = {
964983 PropTypes . func ,
965984 PropTypes . shape ( {
966985 current : PropTypes . oneOfType ( [
967- // `PropTypes.instanceOf(HTMLButtonElement )` alone won't work because
968- // `HTMLButtonElement ` is not defined in the test environment even
986+ // `PropTypes.instanceOf(HTMLElement )` alone won't work because
987+ // `HTMLElement ` is not defined in the test environment even
969988 // though `testEnvironment` is set to `jsdom`.
970- typeof HTMLButtonElement !== 'undefined'
971- ? PropTypes . instanceOf ( HTMLButtonElement )
989+ typeof HTMLElement !== 'undefined'
990+ ? PropTypes . instanceOf ( HTMLElement )
972991 : PropTypes . any ,
973992 PropTypes . oneOf ( [ null ] ) ,
974993 ] ) . isRequired ,
975994 } ) ,
976- ] ) as Validator < RefObject < HTMLButtonElement | null > > ,
995+ ] ) as Validator < RefObject < HTMLElement | null > > ,
977996
978997 /**
979998 * Specify the description for the loading text
0 commit comments