diff --git a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp index 07a9266dcd1fa..caa17ea57b950 100644 --- a/mlir/lib/Bytecode/Reader/BytecodeReader.cpp +++ b/mlir/lib/Bytecode/Reader/BytecodeReader.cpp @@ -1503,6 +1503,7 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, // Try parsing with callbacks first if available. for (const auto &callback : parserConfig.getBytecodeReaderConfig().getTypeCallbacks()) { + size_t savedWorklistSize = deferredWorklist.size(); if (failed( callback->read(dialectReader, entry.dialect->name, entry.entry))) return failure(); @@ -1510,14 +1511,17 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, if (!!entry.entry) return success(); - // Reset the reader if we failed to parse, so we can fall through the - // other parsing functions. + // The callback fell through without consuming the encoding. Reset the + // reader and restore the deferred worklist: any entries added during the + // callback's partial read are stale and must not persist. + deferredWorklist.resize(savedWorklistSize); reader = EncodingReader(entry.data, reader.getLoc()); } } else { // Try parsing with callbacks first if available. for (const auto &callback : parserConfig.getBytecodeReaderConfig().getAttributeCallbacks()) { + size_t savedWorklistSize = deferredWorklist.size(); if (failed( callback->read(dialectReader, entry.dialect->name, entry.entry))) return failure(); @@ -1525,8 +1529,10 @@ LogicalResult AttrTypeReader::parseCustomEntry(Entry &entry, if (!!entry.entry) return success(); - // Reset the reader if we failed to parse, so we can fall through the - // other parsing functions. + // The callback fell through without consuming the encoding. Reset the + // reader and restore the deferred worklist: any entries added during the + // callback's partial read are stale and must not persist. + deferredWorklist.resize(savedWorklistSize); reader = EncodingReader(entry.data, reader.getLoc()); } } diff --git a/mlir/test/Bytecode/bytecode_callback_with_spirv_and_custom_attr.mlir b/mlir/test/Bytecode/bytecode_callback_with_spirv_and_custom_attr.mlir new file mode 100644 index 0000000000000..e2bf0453f3f45 --- /dev/null +++ b/mlir/test/Bytecode/bytecode_callback_with_spirv_and_custom_attr.mlir @@ -0,0 +1,30 @@ +// RUN: mlir-opt %s --test-bytecode-roundtrip="test-kind=4" | FileCheck %s + +// Regression test for https://github.com/llvm/llvm-project/issues/163337 +// +// When using test-kind=4, the attribute callback calls the builtin dialect +// reader for each attribute. For dialect-specific attributes (e.g., spirv.*), +// the builtin reader fails and the callback falls through to the regular reader. +// During the failing read, the callback may add entries to the deferred parsing +// worklist. These stale entries must be discarded when the reader position is +// reset, otherwise the `deferredWorklist.empty()` assertion fires in debug +// builds and may corrupt subsequent attribute resolution in release builds. + +spirv.module Logical GLSL450 { + spirv.func @callee() -> () "None" { + spirv.Kill + } + spirv.func @caller() -> () "None" { + spirv.FunctionCall @callee() : () -> () + spirv.Return + } +} + +// CHECK-LABEL: spirv.module Logical GLSL450 +// CHECK: spirv.func @callee +// CHECK: spirv.func @caller + +// Verify the custom test attribute roundtrips correctly through the callback. +"test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> () + +// CHECK: "test.versionedC"() <{attribute = #test.attr_params<42, 24>}> : () -> ()