Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 105 additions & 80 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1081,97 +1081,122 @@ void SemaHLSL::ActOnFinishRootSignatureDecl(
SemaRef.PushOnScopeChains(SignatureDecl, SemaRef.getCurScope());
}

namespace {

class RangeInfoMap {
using RangeInfo = llvm::hlsl::rootsig::RangeInfo;

llvm::SmallVector<RangeInfo> Infos;
llvm::SmallVector<const hlsl::RootSignatureElement *> ElemRefs;

void collectRangeInfos(ArrayRef<hlsl::RootSignatureElement> Elements) {
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
if (const auto *Descriptor =
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Descriptor->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class =
llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
Info.Space = Descriptor->Space;
Info.Visibility = Descriptor->Visibility;

Infos.push_back(Info);
ElemRefs.push_back(&RootSigElem);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Constants->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class = llvm::dxil::ResourceClass::CBuffer;
Info.Space = Constants->Space;
Info.Visibility = Constants->Visibility;

Infos.push_back(Info);
ElemRefs.push_back(&RootSigElem);
} else if (const auto *Sampler =
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Sampler->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class = llvm::dxil::ResourceClass::Sampler;
Info.Space = Sampler->Space;
Info.Visibility = Sampler->Visibility;

Infos.push_back(Info);
ElemRefs.push_back(&RootSigElem);
} else if (const auto *Clause =
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
&Elem)) {
RangeInfo Info;
Info.LowerBound = Clause->Reg.Number;
assert(0 < Clause->NumDescriptors && "Verified as part of TODO(#129940)");
Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
? RangeInfo::Unbounded
: Info.LowerBound + Clause->NumDescriptors -
1; // use inclusive ranges []

Info.Class = Clause->Type;
Info.Space = Clause->Space;

// Note: Clause does not hold visibility, is updated below
Infos.push_back(Info);
ElemRefs.push_back(&RootSigElem);
} else if (const auto *Table =
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
// Table holds the Visibility of all owned Clauses in Table, so iterate
// owned Clauses and update their corresponding RangeInfo
assert(Table->NumClauses <= Infos.size() && "RootElement");
// The last Table->NumClauses elements of Infos are the owned Clauses
// generated RangeInfo
auto TableInfos =
MutableArrayRef<RangeInfo>(Infos).take_back(Table->NumClauses);
for (RangeInfo &Info : TableInfos)
Info.Visibility = Table->Visibility;
}
}
}

public:
RangeInfoMap(ArrayRef<hlsl::RootSignatureElement> Elements) {
collectRangeInfos(Elements);
}

ArrayRef<RangeInfo> getInfos() const {
return Infos;
}

const hlsl::RootSignatureElement *getElement(const RangeInfo *Info) const {
const RangeInfo * const InfoStart = getInfos().data();
assert(InfoStart <= Info && "Out of range RangeInfo");
size_t Index = Info - InfoStart;
assert(Index < getInfos().size() && "Out of range RangeInfo");
return ElemRefs[Index];
}
};

} // namespace

bool SemaHLSL::handleRootSignatureElements(
ArrayRef<hlsl::RootSignatureElement> Elements) {
using RangeInfo = llvm::hlsl::rootsig::RangeInfo;
using OverlappingRanges = llvm::hlsl::rootsig::OverlappingRanges;

// Each RangeInfo will contain an index back to its associated
// RootSignatureElement in our Elements ArrayRef
size_t InfoIndex = 0;

// 1. Collect RangeInfos
llvm::SmallVector<RangeInfo> Infos;
for (const hlsl::RootSignatureElement &RootSigElem : Elements) {
const llvm::hlsl::rootsig::RootElement &Elem = RootSigElem.getElement();
if (const auto *Descriptor =
std::get_if<llvm::hlsl::rootsig::RootDescriptor>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Descriptor->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class =
llvm::dxil::ResourceClass(llvm::to_underlying(Descriptor->Type));
Info.Space = Descriptor->Space;
Info.Visibility = Descriptor->Visibility;

Info.Index = InfoIndex;
Infos.push_back(Info);
} else if (const auto *Constants =
std::get_if<llvm::hlsl::rootsig::RootConstants>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Constants->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class = llvm::dxil::ResourceClass::CBuffer;
Info.Space = Constants->Space;
Info.Visibility = Constants->Visibility;

Info.Index = InfoIndex;
Infos.push_back(Info);
} else if (const auto *Sampler =
std::get_if<llvm::hlsl::rootsig::StaticSampler>(&Elem)) {
RangeInfo Info;
Info.LowerBound = Sampler->Reg.Number;
Info.UpperBound = Info.LowerBound; // use inclusive ranges []

Info.Class = llvm::dxil::ResourceClass::Sampler;
Info.Space = Sampler->Space;
Info.Visibility = Sampler->Visibility;

Info.Index = InfoIndex;
Infos.push_back(Info);
} else if (const auto *Clause =
std::get_if<llvm::hlsl::rootsig::DescriptorTableClause>(
&Elem)) {
RangeInfo Info;
Info.LowerBound = Clause->Reg.Number;
assert(0 < Clause->NumDescriptors && "Verified as part of TODO(#129940)");
Info.UpperBound = Clause->NumDescriptors == RangeInfo::Unbounded
? RangeInfo::Unbounded
: Info.LowerBound + Clause->NumDescriptors -
1; // use inclusive ranges []

Info.Class = Clause->Type;
Info.Space = Clause->Space;

// Note: Clause does not hold the visibility this will need to
Info.Index = InfoIndex;
Infos.push_back(Info);
} else if (const auto *Table =
std::get_if<llvm::hlsl::rootsig::DescriptorTable>(&Elem)) {
// Table holds the Visibility of all owned Clauses in Table, so iterate
// owned Clauses and update their corresponding RangeInfo
assert(Table->NumClauses <= Infos.size() && "RootElement");
// The last Table->NumClauses elements of Infos are the owned Clauses
// generated RangeInfo
auto TableInfos =
MutableArrayRef<RangeInfo>(Infos).take_back(Table->NumClauses);
for (RangeInfo &Info : TableInfos)
Info.Visibility = Table->Visibility;
}

InfoIndex++;
}
const RangeInfoMap InfoMap(Elements);

// Helper to report diagnostics
auto ReportOverlap = [this, &Elements](OverlappingRanges Overlap) {
auto ReportOverlap = [this, &InfoMap](OverlappingRanges Overlap) {
const RangeInfo *Info = Overlap.A;
const RangeInfo *OInfo = Overlap.B;
auto CommonVis = Info->Visibility == llvm::dxbc::ShaderVisibility::All
? OInfo->Visibility
: Info->Visibility;
SourceLocation InfoLoc = Elements[Info->Index].getLocation();
SourceLocation InfoLoc = InfoMap.getElement(Info)->getLocation();
this->Diag(InfoLoc, diag::err_hlsl_resource_range_overlap)
<< llvm::to_underlying(Info->Class) << Info->LowerBound
<< /*unbounded=*/(Info->UpperBound == RangeInfo::Unbounded)
Expand All @@ -1180,12 +1205,12 @@ bool SemaHLSL::handleRootSignatureElements(
<< /*unbounded=*/(OInfo->UpperBound == RangeInfo::Unbounded)
<< OInfo->UpperBound << Info->Space << CommonVis;

SourceLocation OInfoLoc = Elements[OInfo->Index].getLocation();
SourceLocation OInfoLoc = InfoMap.getElement(OInfo)->getLocation();
this->Diag(OInfoLoc, diag::note_hlsl_resource_range_here);
};

llvm::SmallVector<OverlappingRanges> Overlaps =
llvm::hlsl::rootsig::findOverlappingRanges(Infos);
llvm::hlsl::rootsig::findOverlappingRanges(InfoMap.getInfos());
for (OverlappingRanges Overlap : Overlaps)
ReportOverlap(Overlap);

Expand Down
21 changes: 9 additions & 12 deletions llvm/include/llvm/Frontend/HLSL/RootSignatureValidations.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,6 @@ struct RangeInfo {
llvm::dxil::ResourceClass Class;
uint32_t Space;
llvm::dxbc::ShaderVisibility Visibility;

// The index retains its original position before being sorted by group.
size_t Index;
};

class ResourceRange {
Expand All @@ -70,7 +67,7 @@ class ResourceRange {
// Returns a reference to the first RangeInfo that overlaps with
// [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap
LLVM_ABI std::optional<const RangeInfo *>
getOverlapping(const RangeInfo &Info) const;
getOverlapping(const RangeInfo *Info) const;

// Return the mapped RangeInfo at X or nullptr if no mapping exists
LLVM_ABI const RangeInfo *lookup(uint32_t X) const;
Expand All @@ -86,25 +83,25 @@ class ResourceRange {
// intervals denoting the Lower/Upper-bounds:
//
// A = [0;2]
// insert(A) -> false
// insert(&A) -> false
// intervals: [0;2] -> &A
// B = [5;7]
// insert(B) -> false
// insert(&B) -> false
// intervals: [0;2] -> &A, [5;7] -> &B
// C = [4;7]
// insert(C) -> true
// insert(&C) -> true
// intervals: [0;2] -> &A, [4;7] -> &C
// D = [1;5]
// insert(D) -> true
// insert(&D) -> true
// intervals: [0;2] -> &A, [3;3] -> &D, [4;7] -> &C
// E = [0;unbounded]
// insert(E) -> true
// intervals: [0;unbounded] -> E
// insert(&E) -> true
// intervals: [0;unbounded] -> &E
//
// Returns a reference to the first RangeInfo that overlaps with
// [Info.LowerBound;Info.UpperBound], or, std::nullopt if there is no overlap
// (equivalent to getOverlapping)
LLVM_ABI std::optional<const RangeInfo *> insert(const RangeInfo &Info);
LLVM_ABI std::optional<const RangeInfo *> insert(const RangeInfo *Info);
};

struct OverlappingRanges {
Expand Down Expand Up @@ -137,7 +134,7 @@ struct OverlappingRanges {
/// ResourceRange
/// B: Check for overlap with any overlapping Visibility ResourceRange
llvm::SmallVector<OverlappingRanges>
findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos);
findOverlappingRanges(ArrayRef<RangeInfo> Infos);

} // namespace rootsig
} // namespace hlsl
Expand Down
46 changes: 26 additions & 20 deletions llvm/lib/Frontend/HLSL/RootSignatureValidations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ bool verifyBorderColor(uint32_t BorderColor) {
bool verifyLOD(float LOD) { return !std::isnan(LOD); }

std::optional<const RangeInfo *>
ResourceRange::getOverlapping(const RangeInfo &Info) const {
MapT::const_iterator Interval = Intervals.find(Info.LowerBound);
if (!Interval.valid() || Info.UpperBound < Interval.start())
ResourceRange::getOverlapping(const RangeInfo *Info) const {
MapT::const_iterator Interval = Intervals.find(Info->LowerBound);
if (!Interval.valid() || Info->UpperBound < Interval.start())
return std::nullopt;
return Interval.value();
}
Expand All @@ -194,9 +194,9 @@ const RangeInfo *ResourceRange::lookup(uint32_t X) const {

void ResourceRange::clear() { return Intervals.clear(); }

std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
uint32_t LowerBound = Info.LowerBound;
uint32_t UpperBound = Info.UpperBound;
std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo *Info) {
uint32_t LowerBound = Info->LowerBound;
uint32_t UpperBound = Info->UpperBound;

std::optional<const RangeInfo *> Res = std::nullopt;
MapT::iterator Interval = Intervals.begin();
Expand Down Expand Up @@ -239,25 +239,31 @@ std::optional<const RangeInfo *> ResourceRange::insert(const RangeInfo &Info) {
}

assert(LowerBound <= UpperBound && "Attempting to insert an empty interval");
Intervals.insert(LowerBound, UpperBound, &Info);
Intervals.insert(LowerBound, UpperBound, Info);
return Res;
}

llvm::SmallVector<OverlappingRanges>
findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) {
findOverlappingRanges(ArrayRef<RangeInfo> Infos) {
// 1. The user has provided the corresponding range information
llvm::SmallVector<OverlappingRanges> Overlaps;
using GroupT = std::pair<dxil::ResourceClass, /*Space*/ uint32_t>;

llvm::SmallVector<const RangeInfo *> InfoRefs;
for (const RangeInfo &Info : Infos)
InfoRefs.push_back(&Info);

// 2. Sort the RangeInfo's by their GroupT to form groupings
std::sort(Infos.begin(), Infos.end(), [](RangeInfo A, RangeInfo B) {
return std::tie(A.Class, A.Space) < std::tie(B.Class, B.Space);
});
std::sort(InfoRefs.begin(), InfoRefs.end(),
[](const RangeInfo *A, const RangeInfo *B) {
return std::tie(A->Class, A->Space) <
std::tie(B->Class, B->Space);
});

// 3. First we will init our state to track:
if (Infos.size() == 0)
if (InfoRefs.size() == 0)
return Overlaps; // No ranges to overlap
GroupT CurGroup = {Infos[0].Class, Infos[0].Space};
GroupT CurGroup = {InfoRefs[0]->Class, InfoRefs[0]->Space};

// Create a ResourceRange for each Visibility
ResourceRange::MapT::Allocator Allocator;
Expand All @@ -278,19 +284,19 @@ findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) {
Range.clear();
};

// 3: Iterate through collected RangeInfos
for (const RangeInfo &Info : Infos) {
GroupT InfoGroup = {Info.Class, Info.Space};
// 3: Iterate through collected RangeInfoRefs
for (const RangeInfo *Info : InfoRefs) {
GroupT InfoGroup = {Info->Class, Info->Space};
// Reset our ResourceRanges when we enter a new group
if (CurGroup != InfoGroup) {
ClearRanges();
CurGroup = InfoGroup;
}

// 3A: Insert range info into corresponding Visibility ResourceRange
ResourceRange &VisRange = Ranges[llvm::to_underlying(Info.Visibility)];
ResourceRange &VisRange = Ranges[llvm::to_underlying(Info->Visibility)];
if (std::optional<const RangeInfo *> Overlapping = VisRange.insert(Info))
Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
Overlaps.push_back(OverlappingRanges(Info, Overlapping.value()));

// 3B: Check for overlap in all overlapping Visibility ResourceRanges
//
Expand All @@ -303,14 +309,14 @@ findOverlappingRanges(llvm::SmallVector<RangeInfo> &Infos) {
// ResourceRanges in the former case and it will be an ArrayRef to just the
// all visiblity ResourceRange in the latter case.
ArrayRef<ResourceRange> OverlapRanges =
Info.Visibility == llvm::dxbc::ShaderVisibility::All
Info->Visibility == llvm::dxbc::ShaderVisibility::All
? ArrayRef<ResourceRange>{Ranges}.drop_front()
: ArrayRef<ResourceRange>{Ranges}.take_front();

for (const ResourceRange &Range : OverlapRanges)
if (std::optional<const RangeInfo *> Overlapping =
Range.getOverlapping(Info))
Overlaps.push_back(OverlappingRanges(&Info, Overlapping.value()));
Overlaps.push_back(OverlappingRanges(Info, Overlapping.value()));
}

return Overlaps;
Expand Down
Loading