Skip to content

Commit f997244

Browse files
samestep authored and facebook-github-bot committed
Abort node in fast_nvcc if ancestor fails
Summary: This PR makes `fast_nvcc` stop at failing commands, rather than continuing on to run commands that would otherwise run after those commands. It is still possible for `fast_nvcc` to run more commands than `nvcc` would run if there's no dependency between them, but this should still help to reduce noise from failing `fast_nvcc` runs.

Test Plan: Unfortunately the test suite for this script is FB-internal. It would probably be a good idea to move it into the PyTorch GitHub repo, but I'm not entirely sure how to do so, since I don't believe we currently have a good place to put tests for things in `tools`.

Reviewed By: malfet

Differential Revision: D26007788

fbshipit-source-id: 3ab7bf623e9b89e940bbb6bb92a0cf0ec9f28193
1 parent f7b339d commit f997244

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

tools/fast_nvcc/fast_nvcc.py

Lines changed: 11 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -337,7 +337,10 @@ async def run_command(command, *, env, deps, gather_data, i, save):
337337
Run the command with the given env after waiting for deps.
338338
"""
339339
for task in deps:
340-
await task
340+
dep_result = await task
341+
# abort if a previous step failed
342+
if 'exit_code' not in dep_result or dep_result['exit_code'] != 0:
343+
return {}
341344
if gather_data:
342345
t1 = time.monotonic()
343346
proc = await asyncio.create_subprocess_shell(
@@ -368,7 +371,7 @@ async def run_command(command, *, env, deps, gather_data, i, save):
368371
return results
369372

370373

371-
async def run_graph(*, env, commands, graph, gather_data, save):
374+
async def run_graph(*, env, commands, graph, gather_data=False, save=None):
372375
"""
373376
Return outputs/errors (and optionally time/file info) from commands.
374377
"""
@@ -391,8 +394,8 @@ def print_command_outputs(command_results):
391394
Print captured stdout and stderr from commands.
392395
"""
393396
for result in command_results:
394-
sys.stdout.write(result['stdout'].decode('ascii'))
395-
sys.stderr.write(result['stderr'].decode('ascii'))
397+
sys.stdout.write(result.get('stdout', b'').decode('ascii'))
398+
sys.stderr.write(result.get('stderr', b'').decode('ascii'))
396399

397400

398401
def write_log_csv(command_parts, command_results, *, filename):
@@ -401,23 +404,23 @@ def write_log_csv(command_parts, command_results, *, filename):
401404
"""
402405
tmp_files = []
403406
for result in command_results:
404-
tmp_files.extend(result['files'].keys())
407+
tmp_files.extend(result.get('files', {}).keys())
405408
with open(filename, 'w', newline='') as csvfile:
406409
fieldnames = ['command', 'seconds'] + list(dict.fromkeys(tmp_files))
407410
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
408411
writer.writeheader()
409412
for i, result in enumerate(command_results):
410413
command = f'{i} {os.path.basename(command_parts[i][0])}'
411-
row = {'command': command, 'seconds': result['time']}
412-
writer.writerow({**row, **result['files']})
414+
row = {'command': command, 'seconds': result.get('time', 0)}
415+
writer.writerow({**row, **result.get('files', {})})
413416

414417

415418
def exit_code(results):
416419
"""
417420
Aggregate individual exit codes into a single code.
418421
"""
419422
for result in results:
420-
code = result['exit_code']
423+
code = result.get('exit_code', 0)
421424
if code != 0:
422425
return code
423426
return 0

0 commit comments

Comments
 (0)