Skip to content

Commit 143a7da

Browse files
authored
GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
### Rationale for this change It's impossible to use the current bindings with PollFlightInfo. Required for apache/arrow-adbc#1457. ### What changes are included in this PR? Add new methods that expose PollFlightInfo. ### Are these changes tested? Yes ### Are there any user-facing changes? Adds new methods. * Closes: #39574 Authored-by: David Li <li.davidm96@gmail.com> Signed-off-by: David Li <li.davidm96@gmail.com>
1 parent 55afcf0 commit 143a7da

3 files changed

Lines changed: 206 additions & 0 deletions

File tree

go/arrow/flight/flightsql/client.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op
8282
return cl.getFlightInfo(ctx, desc, opts...)
8383
}
8484

85+
func pollInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
86+
if retryDescriptor != nil {
87+
return cl.Client.PollFlightInfo(ctx, retryDescriptor, opts...)
88+
}
89+
desc, err := descForCommand(cmd)
90+
if err != nil {
91+
return nil, err
92+
}
93+
return cl.Client.PollFlightInfo(ctx, desc, opts...)
94+
}
95+
8596
func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
8697
desc, err := descForCommand(cmd)
8798
if err != nil {
@@ -123,6 +134,14 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt
123134
return flightInfoForCommand(ctx, c, &cmd, opts...)
124135
}
125136

137+
// ExecutePoll idempotently starts execution of a query/checks for completion.
138+
// To check for completion, pass the FlightDescriptor from the previous call
139+
// to ExecutePoll as the retryDescriptor.
140+
func (c *Client) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
141+
cmd := pb.CommandStatementQuery{Query: query}
142+
return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
143+
}
144+
126145
// GetExecuteSchema gets the schema of the result set of a query without
127146
// executing the query itself.
128147
func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
@@ -136,6 +155,12 @@ func (c *Client) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts
136155
return flightInfoForCommand(ctx, c, &cmd, opts...)
137156
}
138157

158+
func (c *Client) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
159+
cmd := pb.CommandStatementSubstraitPlan{
160+
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
161+
return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
162+
}
163+
139164
func (c *Client) GetExecuteSubstraitSchema(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
140165
cmd := pb.CommandStatementSubstraitPlan{
141166
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
@@ -606,6 +631,15 @@ func (tx *Txn) Execute(ctx context.Context, query string, opts ...grpc.CallOptio
606631
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
607632
}
608633

634+
func (tx *Txn) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
635+
if !tx.txn.IsValid() {
636+
return nil, ErrInvalidTxn
637+
}
638+
// The server should encode the transaction into the retry descriptor
639+
cmd := &pb.CommandStatementQuery{Query: query, TransactionId: tx.txn}
640+
return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
641+
}
642+
609643
func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
610644
if !tx.txn.IsValid() {
611645
return nil, ErrInvalidTxn
@@ -616,6 +650,18 @@ func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ..
616650
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
617651
}
618652

653+
func (tx *Txn) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
654+
if !tx.txn.IsValid() {
655+
return nil, ErrInvalidTxn
656+
}
657+
// The server should encode the transaction into the retry descriptor
658+
cmd := &pb.CommandStatementSubstraitPlan{
659+
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version},
660+
TransactionId: tx.txn,
661+
}
662+
return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
663+
}
664+
619665
func (tx *Txn) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
620666
if !tx.txn.IsValid() {
621667
return nil, ErrInvalidTxn
@@ -981,6 +1027,52 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
9811027
return p.client.getFlightInfo(ctx, desc, opts...)
9821028
}
9831029

1030+
// ExecutePoll executes the prepared statement on the server and returns a PollInfo
1031+
// indicating the progress of execution.
1032+
//
1033+
// Will error if already closed.
1034+
func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
1035+
if p.closed {
1036+
return nil, errors.New("arrow/flightsql: prepared statement already closed")
1037+
}
1038+
1039+
cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}
1040+
1041+
desc := retryDescriptor
1042+
var err error
1043+
1044+
if desc == nil {
1045+
desc, err = descForCommand(cmd)
1046+
if err != nil {
1047+
return nil, err
1048+
}
1049+
}
1050+
1051+
if retryDescriptor == nil {
1052+
if p.hasBindParameters() {
1053+
pstream, err := p.client.Client.DoPut(ctx, opts...)
1054+
if err != nil {
1055+
return nil, err
1056+
}
1057+
1058+
wr, err := p.writeBindParameters(pstream, desc)
1059+
if err != nil {
1060+
return nil, err
1061+
}
1062+
if err = wr.Close(); err != nil {
1063+
return nil, err
1064+
}
1065+
pstream.CloseSend()
1066+
1067+
// wait for the server to ack the result
1068+
if _, err = pstream.Recv(); err != nil && err != io.EOF {
1069+
return nil, err
1070+
}
1071+
}
1072+
}
1073+
return p.client.Client.PollFlightInfo(ctx, desc, opts...)
1074+
}
1075+
9841076
// ExecuteUpdate executes the prepared statement update query on the server
9851077
// and returns the number of rows affected. If SetParameters was called,
9861078
// the parameter bindings will be sent with the request to execute.

go/arrow/flight/flightsql/server.go

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,6 +524,22 @@ func (BaseServer) RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpoi
524524
return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not implemented")
525525
}
526526

527+
func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) {
528+
return nil, status.Error(codes.Unimplemented, "PollFlightInfo not implemented")
529+
}
530+
531+
func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
532+
return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement not implemented")
533+
}
534+
535+
func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) {
536+
return nil, status.Error(codes.Unimplemented, "PollFlightInfoSubstraitPlan not implemented")
537+
}
538+
539+
func (BaseServer) PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
540+
return nil, status.Error(codes.Unimplemented, "PollFlightInfoPreparedStatement not implemented")
541+
}
542+
527543
func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest) error {
528544
return status.Error(codes.Unimplemented, "EndTransaction not implemented")
529545
}
@@ -652,6 +668,14 @@ type Server interface {
652668
CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error)
653669
// RenewFlightEndpoint attempts to extend the expiration of a FlightEndpoint
654670
RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error)
671+
// PollFlightInfo is a generic handler for PollFlightInfo requests.
672+
PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error)
673+
// PollFlightInfoStatement handles polling for query execution.
674+
PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
675+
// PollFlightInfoSubstraitPlan handles polling for query execution.
676+
PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error)
677+
// PollFlightInfoPreparedStatement handles polling for query execution.
678+
PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
655679

656680
mustEmbedBaseServer()
657681
}
@@ -729,6 +753,36 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli
729753
return nil, status.Error(codes.InvalidArgument, "requested command is invalid")
730754
}
731755

756+
func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.PollInfo, error) {
757+
var (
758+
anycmd anypb.Any
759+
cmd proto.Message
760+
err error
761+
)
762+
// If we can't parse things, be friendly and defer to the server
763+
// implementation. This is especially important for this method since
764+
// the server returns a custom FlightDescriptor for future requests.
765+
if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil {
766+
return f.srv.PollFlightInfo(ctx, request)
767+
}
768+
769+
if cmd, err = anycmd.UnmarshalNew(); err != nil {
770+
return f.srv.PollFlightInfo(ctx, request)
771+
}
772+
773+
switch cmd := cmd.(type) {
774+
case *pb.CommandStatementQuery:
775+
return f.srv.PollFlightInfoStatement(ctx, cmd, request)
776+
case *pb.CommandStatementSubstraitPlan:
777+
return f.srv.PollFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request)
778+
case *pb.CommandPreparedStatementQuery:
779+
return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request)
780+
}
781+
// XXX: for now we won't support the other methods
782+
783+
return f.srv.PollFlightInfo(ctx, request)
784+
}
785+
732786
func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) {
733787
var (
734788
anycmd anypb.Any

go/arrow/flight/flightsql/server_test.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,36 @@ func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.State
5656
}, nil
5757
}
5858

59+
func (*testServer) PollFlightInfo(ctx context.Context, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
60+
return &flight.PollInfo{
61+
Info: &flight.FlightInfo{
62+
FlightDescriptor: fd,
63+
Endpoint: []*flight.FlightEndpoint{{
64+
Ticket: &flight.Ticket{Ticket: []byte{}},
65+
}, {
66+
Ticket: &flight.Ticket{Ticket: []byte{}},
67+
}},
68+
},
69+
FlightDescriptor: nil,
70+
}, nil
71+
}
72+
73+
func (*testServer) PollFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
74+
ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery()))
75+
if err != nil {
76+
return nil, err
77+
}
78+
return &flight.PollInfo{
79+
Info: &flight.FlightInfo{
80+
FlightDescriptor: fd,
81+
Endpoint: []*flight.FlightEndpoint{{
82+
Ticket: &flight.Ticket{Ticket: ticket},
83+
}},
84+
},
85+
FlightDescriptor: &flight.FlightDescriptor{Cmd: []byte{}},
86+
}, nil
87+
}
88+
5989
func (*testServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan flight.StreamChunk, err error) {
6090
handle := string(ticket.GetStatementHandle())
6191
switch handle {
@@ -189,6 +219,20 @@ func (s *FlightSqlServerSuite) TestExecuteChunkError() {
189219
}
190220
}
191221

222+
func (s *FlightSqlServerSuite) TestExecutePoll() {
223+
poll, err := s.cl.ExecutePoll(context.TODO(), "1", nil)
224+
s.NoError(err)
225+
s.NotNil(poll)
226+
s.NotNil(poll.GetFlightDescriptor())
227+
s.Len(poll.GetInfo().Endpoint, 1)
228+
229+
poll, err = s.cl.ExecutePoll(context.TODO(), "1", poll.GetFlightDescriptor())
230+
s.NoError(err)
231+
s.NotNil(poll)
232+
s.Nil(poll.GetFlightDescriptor())
233+
s.Len(poll.GetInfo().Endpoint, 2)
234+
}
235+
192236
type UnimplementedFlightSqlServerSuite struct {
193237
suite.Suite
194238

@@ -314,6 +358,22 @@ func (s *UnimplementedFlightSqlServerSuite) TestGetTypeInfo() {
314358
s.Nil(info)
315359
}
316360

361+
func (s *UnimplementedFlightSqlServerSuite) TestPoll() {
362+
poll, err := s.cl.ExecutePoll(context.TODO(), "", nil)
363+
st, ok := status.FromError(err)
364+
s.True(ok)
365+
s.Equal(codes.Unimplemented, st.Code())
366+
s.Equal("PollFlightInfoStatement not implemented", st.Message())
367+
s.Nil(poll)
368+
369+
poll, err = s.cl.ExecuteSubstraitPoll(context.TODO(), flightsql.SubstraitPlan{}, nil)
370+
st, ok = status.FromError(err)
371+
s.True(ok)
372+
s.Equal(codes.Unimplemented, st.Code())
373+
s.Equal("PollFlightInfoSubstraitPlan not implemented", st.Message())
374+
s.Nil(poll)
375+
}
376+
317377
func getTicket(cmd proto.Message) *flight.Ticket {
318378
var anycmd anypb.Any
319379
anycmd.MarshalFrom(cmd)

0 commit comments

Comments
 (0)