Skip to content

Commit ca8625f

Browse files
wz337pytorchmergebot
authored andcommitted
[BE][1/N]Add sharding spec logger for ShardedTensor (#99748)
Set up a nullHandler() on the OSS side. Next step is to set up the counterpart in internal. This is part of the effort for ShardedTensor deprecation. We want to log internal use cases for different sharding spec. Pull Request resolved: #99748 Approved by: https://github.com/H-Huang, https://github.com/fegin
1 parent bd71911 commit ca8625f

3 files changed

Lines changed: 74 additions & 0 deletions

File tree

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Owner(s): ["oncall: distributed"]
2+
3+
import logging
4+
5+
from torch.distributed._shard.sharded_tensor.logger import _get_or_create_logger
6+
from torch.testing._internal.common_utils import (
7+
TestCase,
8+
run_tests,
9+
)
10+
11+
12+
class ShardingSpecLoggerTest(TestCase):
13+
def test_get_or_create_logger(self):
14+
logger = _get_or_create_logger()
15+
self.assertIsNotNone(logger)
16+
self.assertEqual(1, len(logger.handlers))
17+
self.assertIsInstance(logger.handlers[0], logging.NullHandler)
18+
19+
20+
if __name__ == "__main__":
21+
run_tests()
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.
8+
9+
import logging
10+
from typing import List, Tuple
11+
12+
from torch.distributed._shard.sharded_tensor.logging_handlers import (
13+
_log_handlers,
14+
)
15+
16+
__all__: List[str] = []
17+
18+
19+
def _get_or_create_logger() -> logging.Logger:
20+
logging_handler, log_handler_name = _get_logging_handler()
21+
logger = logging.getLogger(f"sharding-spec-{log_handler_name}")
22+
logger.setLevel(logging.DEBUG)
23+
formatter = logging.Formatter(
24+
"%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s"
25+
)
26+
logging_handler.setFormatter(formatter)
27+
logger.propagate = False
28+
logger.addHandler(logging_handler)
29+
return logger
30+
31+
32+
def _get_logging_handler(
33+
destination: str = "default",
34+
) -> Tuple[logging.Handler, str]:
35+
log_handler = _log_handlers[destination]
36+
log_handler_name = type(log_handler).__name__
37+
return (log_handler, log_handler_name)
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) Facebook, Inc. and its affiliates.
4+
# All rights reserved.
5+
#
6+
# This source code is licensed under the BSD-style license found in the
7+
# LICENSE file in the root directory of this source tree.
8+
9+
import logging
10+
from typing import Dict, List
11+
12+
__all__: List[str] = []
13+
14+
_log_handlers: Dict[str, logging.Handler] = {
15+
"default": logging.NullHandler(),
16+
}

0 commit comments

Comments
 (0)