File tree Expand file tree Collapse file tree
test/distributed/_shard/sharded_tensor
torch/distributed/_shard/sharded_tensor Expand file tree Collapse file tree Original file line number Diff line number Diff line change 1+ # Owner(s): ["oncall: distributed"]
2+
3+ import logging
4+
5+ from torch .distributed ._shard .sharded_tensor .logger import _get_or_create_logger
6+ from torch .testing ._internal .common_utils import (
7+ TestCase ,
8+ run_tests ,
9+ )
10+
11+
12+ class ShardingSpecLoggerTest (TestCase ):
13+ def test_get_or_create_logger (self ):
14+ logger = _get_or_create_logger ()
15+ self .assertIsNotNone (logger )
16+ self .assertEqual (1 , len (logger .handlers ))
17+ self .assertIsInstance (logger .handlers [0 ], logging .NullHandler )
18+
19+
20+ if __name__ == "__main__" :
21+ run_tests ()
Original file line number Diff line number Diff line change 1+ #!/usr/bin/env python3
2+
3+ # Copyright (c) Facebook, Inc. and its affiliates.
4+ # All rights reserved.
5+ #
6+ # This source code is licensed under the BSD-style license found in the
7+ # LICENSE file in the root directory of this source tree.
8+
9+ import logging
10+ from typing import List , Tuple
11+
12+ from torch .distributed ._shard .sharded_tensor .logging_handlers import (
13+ _log_handlers ,
14+ )
15+
16+ __all__ : List [str ] = []
17+
18+
19+ def _get_or_create_logger () -> logging .Logger :
20+ logging_handler , log_handler_name = _get_logging_handler ()
21+ logger = logging .getLogger (f"sharding-spec-{ log_handler_name } " )
22+ logger .setLevel (logging .DEBUG )
23+ formatter = logging .Formatter (
24+ "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
25+ )
26+ logging_handler .setFormatter (formatter )
27+ logger .propagate = False
28+ logger .addHandler (logging_handler )
29+ return logger
30+
31+
32+ def _get_logging_handler (
33+ destination : str = "default" ,
34+ ) -> Tuple [logging .Handler , str ]:
35+ log_handler = _log_handlers [destination ]
36+ log_handler_name = type (log_handler ).__name__
37+ return (log_handler , log_handler_name )
Original file line number Diff line number Diff line change 1+ #!/usr/bin/env python3
2+
3+ # Copyright (c) Facebook, Inc. and its affiliates.
4+ # All rights reserved.
5+ #
6+ # This source code is licensed under the BSD-style license found in the
7+ # LICENSE file in the root directory of this source tree.
8+
9+ import logging
10+ from typing import Dict , List
11+
12+ __all__ : List [str ] = []
13+
14+ _log_handlers : Dict [str , logging .Handler ] = {
15+ "default" : logging .NullHandler (),
16+ }
You can’t perform that action at this time.
0 commit comments