Skip to content
This repository was archived by the owner on Mar 2, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 66 additions & 35 deletions sycl/include/sycl/ext/oneapi/experimental/graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,45 @@ struct graph_impl {
MSchedule.clear();
}

template <typename T>
node_ptr add(graph_ptr impl, T cgf, const std::vector<node_ptr> &dep = {}) {
node_ptr nodeImpl = std::make_shared<node_impl>(impl, cgf);
if (!dep.empty()) {
for (auto n : dep) {
n->register_successor(nodeImpl); // register successor
this->remove_root(nodeImpl); // remove receiver from root node
// list
}
} else {
this->add_root(nodeImpl);
}
return nodeImpl;
}

graph_impl() : MFirst(true) {}
};

} // namespace detail

struct node {
detail::node_ptr MNode;
detail::graph_ptr MGraph;

class node {
public:
template <typename T>
node(detail::graph_ptr g, T cgf)
: MGraph(g), MNode(new detail::node_impl(g, cgf)){};
void register_successor(node n) { MNode->register_successor(n.MNode); }
void exec(sycl::queue q, sycl::event = sycl::event()) { MNode->exec(q); }
: MGraph(g), impl(new detail::node_impl(g, cgf)) {}
void register_successor(node n) { impl->register_successor(n.impl); }
void exec(sycl::queue q) { impl->exec(q); }

private:
node(detail::node_ptr Impl) : impl(Impl) {}

template <class Obj>
friend decltype(Obj::impl)
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
template <class T>
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

void set_root() { MGraph->add_root(MNode); }
detail::node_ptr impl;
detail::graph_ptr MGraph;
};

template <graph_state State = graph_state::modifiable> class command_graph {
Expand All @@ -165,60 +188,68 @@ template <graph_state State = graph_state::modifiable> class command_graph {
command_graph<graph_state::executable>
finalize(const sycl::context &syclContext) const;

command_graph() : MGraph(new detail::graph_impl()) {}
command_graph() : impl(new detail::graph_impl()) {}

private:
detail::graph_ptr MGraph;
command_graph(detail::graph_ptr Impl) : impl(Impl) {}

template <class Obj>
friend decltype(Obj::impl)
sycl::detail::getSyclObjImpl(const Obj &SyclObject);
template <class T>
friend T sycl::detail::createSyclObjFromImpl(decltype(T::impl) ImplObj);

detail::graph_ptr impl;
};

template <> class command_graph<graph_state::executable> {
public:
int MTag;
const sycl::context &MCtx;

void exec_and_wait(sycl::queue q);

command_graph() = delete;

command_graph(detail::graph_ptr g, const sycl::context &ctx)
: MGraph(g), MCtx(ctx), MTag(rand()) {}
: MTag(rand()), MCtx(ctx), impl(g) {}

private:
detail::graph_ptr MGraph;
int MTag;
const sycl::context &MCtx;
detail::graph_ptr impl;
};

template <>
template <typename T>
node command_graph<graph_state::modifiable>::add(T cgf,
const std::vector<node> &dep) {
node ret_val(MGraph, cgf);
if (!dep.empty()) {
for (auto n : dep)
this->make_edge(n, ret_val);
} else {
ret_val.set_root();
inline node
command_graph<graph_state::modifiable>::add(T cgf,
const std::vector<node> &dep) {
std::vector<detail::node_ptr> depImpls;
for (auto &d : dep) {
depImpls.push_back(sycl::detail::getSyclObjImpl(d));
}
return ret_val;

auto nodeImpl = impl->add(impl, cgf, depImpls);
return sycl::detail::createSyclObjFromImpl<node>(nodeImpl);
}

template <>
void command_graph<graph_state::modifiable>::make_edge(node sender,
node receiver) {
inline void command_graph<graph_state::modifiable>::make_edge(node sender,
node receiver) {
sender.register_successor(receiver); // register successor
MGraph->remove_root(receiver.MNode); // remove receiver from root node
// list
impl->remove_root(
sycl::detail::getSyclObjImpl(receiver)); // remove receiver from root node
// list
}

template <>
command_graph<graph_state::executable>
command_graph<graph_state::modifiable>::finalize(
const sycl::context &ctx) const {
return command_graph<graph_state::executable>{this->MGraph, ctx};
command_graph<graph_state::executable> inline command_graph<
graph_state::modifiable>::finalize(const sycl::context &ctx) const {
return command_graph<graph_state::executable>{this->impl, ctx};
}

void command_graph<graph_state::executable>::exec_and_wait(sycl::queue q) {
MGraph->exec_and_wait(q);
};
inline void
command_graph<graph_state::executable>::exec_and_wait(sycl::queue q) {
impl->exec_and_wait(q);
}

} // namespace experimental
} // namespace oneapi
Expand Down
12 changes: 7 additions & 5 deletions sycl/test/graph/graph-explicit-dotp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,8 @@ int main() {
sycl::ext::oneapi::property::queue::lazy_execution{}
};

sycl::gpu_selector device_selector;

sycl::queue q{device_selector, properties};

sycl::queue q{sycl::gpu_selector_v, properties};

sycl::ext::oneapi::experimental::command_graph g;

float *dotp = sycl::malloc_shared<float>(1, q);
Expand Down Expand Up @@ -80,7 +78,11 @@ int main() {
auto exec_graph = g.finalize(q.get_context());

exec_graph.exec_and_wait(q);


if (*dotp != host_gold_result()) {
std::cout << "Error unexpected result!\n";
}

sycl::free(dotp, q);
sycl::free(x, q);
sycl::free(y, q);
Expand Down