2 changes: 1 addition & 1 deletion cpp/src/arrow/compute/exec/exec_plan.cc
@@ -101,7 +101,7 @@ struct ExecPlanImpl : public ExecPlan {
futures.push_back(node->finished());
}

finished_ = AllComplete(std::move(futures));
finished_ = AllFinished(futures);
return st;
}

20 changes: 20 additions & 0 deletions cpp/src/arrow/compute/exec/options.h
@@ -112,6 +112,26 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions {
std::function<Future<util::optional<ExecBatch>>()>* generator;
};

class ARROW_EXPORT SinkNodeConsumer {
public:
virtual ~SinkNodeConsumer() = default;
/// \brief Consume a batch of data
virtual Status Consume(ExecBatch batch) = 0;
/// \brief Signal to the consumer that the last batch has been delivered
///
/// The returned future should only finish when all outstanding tasks have completed
virtual Future<> Finish() = 0;
};

/// \brief Add a sink node which consumes data within the exec plan run
class ARROW_EXPORT ConsumingSinkNodeOptions : public ExecNodeOptions {
public:
explicit ConsumingSinkNodeOptions(std::shared_ptr<SinkNodeConsumer> consumer)
: consumer(std::move(consumer)) {}

std::shared_ptr<SinkNodeConsumer> consumer;
};

/// \brief Make a node which sorts rows passed through it
///
/// All batches pushed to this node will be accumulated, then sorted, by the given
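For quick reference while reviewing the header, here is a minimal sketch of what an implementation of the new interface could look like and how it would attach to a plan. RowCountingConsumer, the include list, and the wiring comment are illustrative only, not part of this PR; the contract it follows is the one documented above (Consume once per batch, Finish gating plan completion).

#include <atomic>
#include <cstdint>
#include <memory>

#include "arrow/compute/exec.h"          // ExecBatch
#include "arrow/compute/exec/options.h"  // SinkNodeConsumer, ConsumingSinkNodeOptions
#include "arrow/status.h"
#include "arrow/util/future.h"

// A hypothetical consumer that tallies rows. Consume() is invoked once per
// batch; Finish() is invoked after the last batch, and plan->finished() will
// not complete until the future it returns completes.
class RowCountingConsumer : public arrow::compute::SinkNodeConsumer {
 public:
  arrow::Status Consume(arrow::compute::ExecBatch batch) override {
    row_count_ += batch.length;
    return arrow::Status::OK();
  }
  arrow::Future<> Finish() override {
    // No outstanding async work in this sketch, so finish immediately.
    return arrow::Future<>::MakeFinished();
  }
  int64_t row_count() const { return row_count_.load(); }

 private:
  std::atomic<int64_t> row_count_{0};
};

// Wiring (same pattern the tests below use):
//   auto consumer = std::make_shared<RowCountingConsumer>();
//   MakeExecNode("consuming_sink", plan.get(), {source},
//                ConsumingSinkNodeOptions(consumer));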
102 changes: 102 additions & 0 deletions cpp/src/arrow/compute/exec/plan_test.cc
@@ -374,6 +374,108 @@ TEST(ExecPlanExecution, SourceSinkError) {
Finishes(Raises(StatusCode::Invalid, HasSubstr("Artificial"))));
}

TEST(ExecPlanExecution, SourceConsumingSink) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");

for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel" : "single threaded");
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
std::atomic<uint32_t> batches_seen{0};
Future<> finish = Future<>::Make();
struct TestConsumer : public SinkNodeConsumer {
TestConsumer(std::atomic<uint32_t>* batches_seen, Future<> finish)
: batches_seen(batches_seen), finish(std::move(finish)) {}

Status Consume(ExecBatch batch) override {
(*batches_seen)++;
return Status::OK();
}

Future<> Finish() override { return finish; }

std::atomic<uint32_t>* batches_seen;
Future<> finish;
};
std::shared_ptr<TestConsumer> consumer =
std::make_shared<TestConsumer>(&batches_seen, finish);

auto basic_data = MakeBasicBatches();
ASSERT_OK_AND_ASSIGN(
auto source, MakeExecNode("source", plan.get(), {},
SourceNodeOptions(basic_data.schema,
basic_data.gen(parallel, slow))));
ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
ConsumingSinkNodeOptions(consumer)));
ASSERT_OK(plan->StartProducing());
// Source should finish fairly quickly
ASSERT_FINISHES_OK(source->finished());
SleepABit();
ASSERT_EQ(2, batches_seen);
// Consumer isn't finished and so plan shouldn't have finished
AssertNotFinished(plan->finished());
// Mark consumption complete, plan should finish
finish.MarkFinished();
ASSERT_FINISHES_OK(plan->finished());
}
}
}

TEST(ExecPlanExecution, ConsumingSinkError) {
struct ConsumeErrorConsumer : public SinkNodeConsumer {
Status Consume(ExecBatch batch) override { return Status::Invalid("XYZ"); }
Future<> Finish() override { return Future<>::MakeFinished(); }
};
struct FinishErrorConsumer : public SinkNodeConsumer {
Status Consume(ExecBatch batch) override { return Status::OK(); }
Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
};
std::vector<std::shared_ptr<SinkNodeConsumer>> consumers{
std::make_shared<ConsumeErrorConsumer>(), std::make_shared<FinishErrorConsumer>()};

for (auto& consumer : consumers) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
auto basic_data = MakeBasicBatches();
ASSERT_OK(Declaration::Sequence(
{{"source",
SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))},
{"consuming_sink", ConsumingSinkNodeOptions(consumer)}})
.AddToPlan(plan.get()));
ASSERT_OK_AND_ASSIGN(
auto source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
ConsumingSinkNodeOptions(consumer)));
ASSERT_OK(plan->StartProducing());
ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
}
}

TEST(ExecPlanExecution, ConsumingSinkErrorFinish) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
struct FinishErrorConsumer : public SinkNodeConsumer {
Status Consume(ExecBatch batch) override { return Status::OK(); }
Future<> Finish() override { return Future<>::MakeFinished(Status::Invalid("XYZ")); }
};
std::shared_ptr<FinishErrorConsumer> consumer = std::make_shared<FinishErrorConsumer>();

auto basic_data = MakeBasicBatches();
ASSERT_OK(
Declaration::Sequence(
{{"source", SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))},
{"consuming_sink", ConsumingSinkNodeOptions(consumer)}})
.AddToPlan(plan.get()));
ASSERT_OK_AND_ASSIGN(
auto source,
MakeExecNode("source", plan.get(), {},
SourceNodeOptions(basic_data.schema, basic_data.gen(false, false))));
ASSERT_OK(MakeExecNode("consuming_sink", plan.get(), {source},
ConsumingSinkNodeOptions(consumer)));
ASSERT_OK(plan->StartProducing());
ASSERT_FINISHES_AND_RAISES(Invalid, plan->finished());
}

TEST(ExecPlanExecution, StressSourceSink) {
for (bool slow : {false, true}) {
SCOPED_TRACE(slow ? "slowed" : "unslowed");
100 changes: 100 additions & 0 deletions cpp/src/arrow/compute/exec/sink_node.cc
@@ -31,6 +31,7 @@
#include "arrow/result.h"
#include "arrow/table.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/async_util.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/future.h"
#include "arrow/util/logging.h"
@@ -132,6 +133,104 @@ class SinkNode : public ExecNode {
PushGenerator<util::optional<ExecBatch>>::Producer producer_;
};

// A sink node that owns consuming the data and will not finish until the consumption
// is finished. Use SinkNode if you are transferring the ownership of the data to another
// system. Use ConsumingSinkNode if the data is being consumed within the exec plan (i.e.
// the exec plan should not complete until the consumption has completed).
class ConsumingSinkNode : public ExecNode {
Member: Should we add an explicit test of this node in plan_test.cc, as with the other nodes?

Member Author: I added a few unit tests to plan_test.cc.

public:
ConsumingSinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::shared_ptr<SinkNodeConsumer> consumer)
: ExecNode(plan, std::move(inputs), {"to_consume"}, {},
/*num_outputs=*/0),
consumer_(std::move(consumer)) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 1, "SinkNode"));

const auto& sink_options = checked_cast<const ConsumingSinkNodeOptions&>(options);
return plan->EmplaceNode<ConsumingSinkNode>(plan, std::move(inputs),
std::move(sink_options.consumer));
}

const char* kind_name() const override { return "ConsumingSinkNode"; }

Status StartProducing() override {
finished_ = Future<>::Make();
return Status::OK();
}

// sink nodes have no outputs from which to feel backpressure
[[noreturn]] static void NoOutputs() {
Unreachable("no outputs; this should never be called");
}
[[noreturn]] void ResumeProducing(ExecNode* output) override { NoOutputs(); }
[[noreturn]] void PauseProducing(ExecNode* output) override { NoOutputs(); }
[[noreturn]] void StopProducing(ExecNode* output) override { NoOutputs(); }

void StopProducing() override {
Finish(Status::Invalid("ExecPlan was stopped early"));
inputs_[0]->StopProducing(this);
}

Future<> finished() override { return finished_; }

void InputReceived(ExecNode* input, ExecBatch batch) override {
DCHECK_EQ(input, inputs_[0]);

// This can happen if an error was received and the source hasn't yet stopped. Since
// we have already called consumer_->Finish we don't want to call consumer_->Consume
if (input_counter_.Completed()) {
return;
}

Status consumption_status = consumer_->Consume(std::move(batch));
if (!consumption_status.ok()) {
if (input_counter_.Cancel()) {
Finish(std::move(consumption_status));
}
inputs_[0]->StopProducing(this);
return;
}

if (input_counter_.Increment()) {
Finish(Status::OK());
}
}

void ErrorReceived(ExecNode* input, Status error) override {
DCHECK_EQ(input, inputs_[0]);

if (input_counter_.Cancel()) {
Finish(std::move(error));
}

inputs_[0]->StopProducing(this);
}

void InputFinished(ExecNode* input, int total_batches) override {
if (input_counter_.SetTotal(total_batches)) {
Finish(Status::OK());
}
}

protected:
virtual void Finish(const Status& finish_st) {
consumer_->Finish().AddCallback([this, finish_st](const Status& st) {
// Prefer the plan error over the consumer error
Status final_status = finish_st & st;
finished_.MarkFinished(std::move(final_status));
});
}

AtomicCounter input_counter_;

Future<> finished_ = Future<>::MakeFinished();
std::shared_ptr<SinkNodeConsumer> consumer_;
};

// A sink node that accumulates inputs, then sorts them before emitting them.
struct OrderBySinkNode final : public SinkNode {
OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
std::unique_ptr<OrderByImpl> impl,
@@ -226,6 +325,7 @@ namespace internal {
void RegisterSinkNode(ExecFactoryRegistry* registry) {
DCHECK_OK(registry->AddFactory("select_k_sink", OrderBySinkNode::MakeSelectK));
DCHECK_OK(registry->AddFactory("order_by_sink", OrderBySinkNode::MakeSort));
DCHECK_OK(registry->AddFactory("consuming_sink", ConsumingSinkNode::Make));
DCHECK_OK(registry->AddFactory("sink", SinkNode::Make));
}

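To summarize the control flow above in one place, here is a sketch of driving a plan that ends in the new node. It mirrors the tests and the registered "consuming_sink" factory, but RunConsumingPlan, its parameters, and the include list are illustrative only.

#include <memory>

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/util/async_generator.h"

// Sketch: source -> consuming_sink, then block until the consumer-gated
// finish. `schema`, `gen`, and `consumer` are assumed to be built the way
// the tests build them.
arrow::Status RunConsumingPlan(
    std::shared_ptr<arrow::Schema> schema,
    arrow::AsyncGenerator<arrow::util::optional<arrow::compute::ExecBatch>> gen,
    std::shared_ptr<arrow::compute::SinkNodeConsumer> consumer) {
  using namespace arrow::compute;  // NOLINT
  ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make());
  ARROW_RETURN_NOT_OK(
      Declaration::Sequence(
          {{"source", SourceNodeOptions(std::move(schema), std::move(gen))},
           {"consuming_sink", ConsumingSinkNodeOptions(std::move(consumer))}})
          .AddToPlan(plan.get()));
  ARROW_RETURN_NOT_OK(plan->StartProducing());
  // The plan only finishes after the consumer's Finish() future completes;
  // an error from Consume() or Finish() surfaces through this status
  // (with the node's Finish() preferring the plan error over the consumer error).
  return plan->finished().status();
}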
105 changes: 52 additions & 53 deletions cpp/src/arrow/compute/exec/source_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,60 +79,59 @@ struct SourceNode : ExecNode {
options.executor = executor;
options.should_schedule = ShouldSchedule::IfDifferentExecutor;
}
finished_ = Loop([this, executor, options] {
std::unique_lock<std::mutex> lock(mutex_);
int total_batches = batch_count_++;
if (stop_requested_) {
return Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
finished_ =
Loop([this, executor, options] {
std::unique_lock<std::mutex> lock(mutex_);
int total_batches = batch_count_++;
if (stop_requested_) {
return Future<ControlFlow<int>>::MakeFinished(Break(total_batches));
}
lock.unlock();

return generator_().Then(
[=](const util::optional<ExecBatch>& maybe_batch) -> ControlFlow<int> {
std::unique_lock<std::mutex> lock(mutex_);
if (IsIterationEnd(maybe_batch) || stop_requested_) {
stop_requested_ = true;
return Break(total_batches);
}
lock.unlock();
ExecBatch batch = std::move(*maybe_batch);

if (executor) {
auto status =
task_group_.AddTask([this, executor, batch]() -> Result<Future<>> {
return executor->Submit([=]() {
outputs_[0]->InputReceived(this, std::move(batch));
return Status::OK();
});
});
if (!status.ok()) {
outputs_[0]->ErrorReceived(this, std::move(status));
return Break(total_batches);
}
lock.unlock();

return generator_().Then(
[=](const util::optional<ExecBatch>& batch) -> ControlFlow<int> {
std::unique_lock<std::mutex> lock(mutex_);
if (IsIterationEnd(batch) || stop_requested_) {
stop_requested_ = true;
return Break(total_batches);
}
lock.unlock();

if (executor) {
auto maybe_future = executor->Submit([=]() {
outputs_[0]->InputReceived(this, *batch);
return Status::OK();
});
if (!maybe_future.ok()) {
outputs_[0]->ErrorReceived(this, maybe_future.status());
return Break(total_batches);
}
auto status =
task_group_.AddTask(maybe_future.MoveValueUnsafe());
if (!status.ok()) {
outputs_[0]->ErrorReceived(this, std::move(status));
return Break(total_batches);
}
} else {
outputs_[0]->InputReceived(this, *batch);
}
return Continue();
},
[=](const Status& error) -> ControlFlow<int> {
// NB: ErrorReceived is independent of InputFinished, but
// ErrorReceived will usually prompt StopProducing which will
// prompt InputFinished. ErrorReceived may still be called from a
// node which was requested to stop (indeed, the request to stop
// may prompt an error).
std::unique_lock<std::mutex> lock(mutex_);
stop_requested_ = true;
lock.unlock();
outputs_[0]->ErrorReceived(this, error);
return Break(total_batches);
},
options);
}).Then([&](int total_batches) {
outputs_[0]->InputFinished(this, total_batches);
return task_group_.WaitForTasksToFinish();
});
} else {
outputs_[0]->InputReceived(this, std::move(batch));
}
return Continue();
},
[=](const Status& error) -> ControlFlow<int> {
// NB: ErrorReceived is independent of InputFinished, but
// ErrorReceived will usually prompt StopProducing which will
// prompt InputFinished. ErrorReceived may still be called from a
// node which was requested to stop (indeed, the request to stop
// may prompt an error).
std::unique_lock<std::mutex> lock(mutex_);
stop_requested_ = true;
lock.unlock();
outputs_[0]->ErrorReceived(this, error);
return Break(total_batches);
},
options);
}).Then([&](int total_batches) {
outputs_[0]->InputFinished(this, total_batches);
return task_group_.End();
});

return Status::OK();
}
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/exec/util.h
@@ -235,6 +235,9 @@ class AtomicCounter {
// return true if the counter has not already been completed
bool Cancel() { return DoneOnce(); }

// return true if the counter has finished or been cancelled
bool Completed() { return complete_.load(); }

private:
// ensure there is only one true return from Increment(), SetTotal(), or Cancel()
bool DoneOnce() {
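For context on the new accessor, here is an illustrative condensation of the finish-exactly-once pattern ConsumingSinkNode follows. CounterUser and the Finish* placeholders are not part of this PR.

#include <utility>

#include "arrow/compute/exec/util.h"
#include "arrow/status.h"

// Illustrative only: Increment()/SetTotal()/Cancel() still guarantee a single
// completion, while the new Completed() lets late batches be dropped cheaply.
struct CounterUser {
  arrow::compute::AtomicCounter counter;

  void OnBatch() {
    if (counter.Completed()) return;      // already finished or cancelled: drop the batch
    if (counter.Increment()) FinishOk();  // this was the last expected batch
  }
  void OnTotalKnown(int total_batches) {
    if (counter.SetTotal(total_batches)) FinishOk();  // total learned after the last batch
  }
  void OnError(arrow::Status st) {
    if (counter.Cancel()) FinishError(std::move(st));  // only the first completion wins
  }

  void FinishOk() {}                  // placeholder
  void FinishError(arrow::Status) {}  // placeholder
};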
1 change: 1 addition & 0 deletions cpp/src/arrow/dataset/CMakeLists.txt
@@ -26,6 +26,7 @@ set(ARROW_DATASET_SRCS
file_base.cc
file_ipc.cc
partition.cc
plan.cc
projector.cc
scanner.cc)
