[SPARK-49513][SS] Add Support for timer in transformWithStateInPandas API #47878
jingz-db wants to merge 17 commits into apache:master
Conversation
```python
batch_timestamp = statefulProcessorApiClient.get_batch_timestamp()
watermark_timestamp = statefulProcessorApiClient.get_watermark_timestamp()
```
Can we move these 2 API calls inside the if-else clause below and only call them with a supported time mode?

We will need some values to initialize the TimerValues in handleInputRows. On the Scala side, we always pass the real timestamp into TimerValues even if no timer is defined: https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala#L250
It is unlikely that users will call TimerValues if they do not have a timer registered, but my original intention was to align the behavior with the Scala side. I guess we need to decide between saving a call and aligning with the Scala side. I don't have a strong opinion on which is better - which approach do you prefer?

IIUC these 2 values are only used when the time mode is not none; I meant that for the none time mode, we don't need these 2 extra API calls since they aren't needed anyway.
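A minimal sketch of the guard being suggested - only fetch the timestamps for time modes that need them. The `StubClient` below is a stand-in for `statefulProcessorApiClient`, and the `-1` defaults are an assumption:

```python
class StubClient:
    # Stand-in for statefulProcessorApiClient with the two getters above.
    def get_batch_timestamp(self):
        return 1_000

    def get_watermark_timestamp(self):
        return 500

def fetch_timer_values(client, time_mode):
    # For the "none" time mode, skip the two extra API calls entirely
    # and fall back to sentinel values.
    if time_mode == "none":
        return (-1, -1)
    return (client.get_batch_timestamp(), client.get_watermark_timestamp())
```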
```python
if timeMode == "processingtime" and expiry_timestamp < batch_timestamp:
    result_iter_list.append(statefulProcessor.handleInputRows(
        (key_obj,), iter([]),
        TimerValues(batch_timestamp, watermark_timestamp),
```
is watermark_timestamp needed for the processingTime time mode, and vice versa?

Same as above, this is to align with the behavior on the Scala side.
```python
        ExpiredTimerInfo(True, expiry_timestamp)))

# TODO(SPARK-49603) set the handle state in the lazily initialized iterator
"""
```
If we have a TODO here, we can remove the commented code.
```python
if len(response_message[2]) == 0:
    return -1
# TODO: can we simply parse from utf8 string here?
timestamp = int(response_message[2])
```
Just curious: would this return the correct value?

Passing a row schema and using CPickleSerializer seems a bit heavyweight. Modified this to pass a byte buffer of exactly 8 bytes and read exactly 8 bytes on the Python client.
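The fixed-width encoding described here can be sketched with `struct`. Big-endian signed 64-bit matches a Java long; the exact byte order used in the PR is an assumption:

```python
import struct

def serialize_long(value: int) -> bytes:
    # Exactly 8 bytes, big-endian signed 64-bit (Java long layout).
    return struct.pack(">q", value)

def deserialize_long(buf: bytes) -> int:
    # Read exactly 8 bytes back into a Python int.
    assert len(buf) == 8, "expected exactly 8 bytes"
    return struct.unpack(">q", buf)[0]
```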
```python
    return []
elif status == 0:
    iterator = self._read_arrow_state()
    batch = next(iterator)
```
Do we expect all the timers to fit within a single arrow batch? If not, should we handle it properly here?

We don't. It now returns an iterator of lists. Does this API make sense to you?
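The iterator-of-lists shape being discussed could look roughly like this. Batches are mocked as plain lists here; the real code would read Arrow batches from the state server:

```python
def iter_timer_batches(batches):
    # Each Arrow batch becomes one list of (key, expiry_timestamp) tuples,
    # so the caller can consume an arbitrary number of batches lazily.
    for batch in batches:
        yield [tuple(row) for row in batch]

# Two mocked batches of registered timers.
mock_batches = [[("k1", 100), ("k2", 200)], [("k3", 300)]]
all_timers = [t for batch in iter_timer_batches(mock_batches) for t in batch]
```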
```scala
while (iter.hasNext) {
  val timestamp = iter.next()
  val internalRow = InternalRow(timestamp)
  arrowStreamWriter.writeRow(internalRow)
```
Same as the python side comment: here we don't limit how many arrow batches we construct for timers; if the user sets a fairly low value for arrowTransformWithStateInPandasMaxRecordsPerBatch, we would send multiple arrow batches and the client side needs to handle this properly as well.
Question: should we have a lower limit on how many records we send through a single batch (e.g. the default value 10000)? IIUC, each timer record is very small and should not consume a lot of memory. The user also doesn't care how many records each batch contains, since they always get a single list from this API.

I guess I'll rebase on your ListState PR change, and this arrowTransformWithStateInPandasMaxRecordsPerBatch will be passed as the new config you'll add here: https://github.com/apache/spark/pull/47933/files#diff-0b0aaf91850194b6980b75d47bc166148566cbdc1b17b3da16faff1f0740e0f4R107.
But your concern above still holds. Should we pass a different default value for transmitting the list[Int] here? If so, should we add a new config, or shall we just assign a fixed value for it?
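The batching concern can be illustrated with a simple chunking sketch (the config name is from the discussion above; the chunk size here is arbitrary):

```python
def chunk_timers(timestamps, max_records_per_batch=10000):
    # The server writes one Arrow batch per chunk; a small configured value
    # for arrowTransformWithStateInPandasMaxRecordsPerBatch means the client
    # must be prepared to read several batches for one timer request.
    for i in range(0, len(timestamps), max_records_per_batch):
        yield timestamps[i:i + max_records_per_batch]

batches = list(chunk_timers(list(range(25)), max_records_per_batch=10))
```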
```scala
val allocator = ArrowUtils.rootAllocator.newChildAllocator(
  s"stdout writer for transformWithStateInPandas state socket", 0, Long.MaxValue)
val root = VectorSchemaRoot.create(arrowSchema, allocator)
new BaseStreamingArrowWriter(root, new ArrowStreamWriter(root, null, outputStream),
```
Does it make sense to abstract this logic out since it's being used in multiple places?

Done. Created a util object and put all Python-related writer functions into it.
```python
batch = next(iterator)
result_list = []
key_fields = [field.name for field in self.key_schema.fields]
# TODO any better way to restore a grouping object from a batch?
```
@bogao007 Is there any common practice for deserializing data from a batch object to the Python object for the grouping key?

Maybe take a look at how load_stream is implemented in ApplyInPandasWithStateSerializer and TransformWithStateInPandasSerializer in pyspark/sql/pandas/serializers.py (and maybe some other custom serializers in the same file).

Maybe try something like below?
`df.itertuples(index=False, name=None)`
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.itertuples.html
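For reference, a small sketch of the `itertuples` suggestion, assuming the batch has already been converted to a pandas DataFrame of grouping-key columns:

```python
import pandas as pd

# Stand-in for one deserialized Arrow batch of grouping keys.
df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})

# index=False drops the index column; name=None yields plain tuples
# instead of namedtuples, which is what a grouping key needs.
key_tuples = [row for row in df.itertuples(index=False, name=None)]
```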
Btw, what if multiple batches are being sent from the JVM - are we handling that correctly?

Discussed with Bo offline: the JVM will return the Row type to Python and we can directly convert it into a Tuple.
```scala
outputStream.write(responseMessageBytes)
}

def serializeLongToByteString(longValue: Long): ByteString = {
```
I think it may bring some extra complexity to do serde between long and ByteString. Since this is only used in TimerValueRequests, maybe we could add a dedicated response message for it which returns a long value? That way we can just use read_long on the Python side.

Added a new type of StateResponse to transmit the Long value directly in the proto message.
bogao007 left a comment
LGTM overall, left one minor comment regarding arrow resources clean up. Thanks for making the changes!
```scala
arrowStreamWriter.writeRow(internalRow)
}
arrowStreamWriter.finalizeCurrentArrowBatch()
writer.end()
```
Minor: We might need to do something similar to what PythonArrowInput does to ensure we don't see unexpected errors
HeartSaVioR left a comment
Looks OK overall. There are some small corrections, but mostly minors and nits.
I might have lost track of how TWS (for PySpark) works, but given we get the iterator of expired timers based on the timestamp, isn't this if statement already covered by the API call? In other words, shouldn't the API cover this?
Please let me know if there is a specific reason - no need to change the code if there is one. I just wanted to understand and possibly refresh my memory.
You are correct about this. Thanks for noticing the redundant check. Removed.
nit: is this indentation correct? Looks a bit odd compared to others - params start from the same indentation as the first _.
Shall we let GitHub resolve the review comments rather than manually marking them as resolved? I don't see any new commit to resolve these style comments. I guess you've addressed them but missed pushing the commit; it's easier to track if we resolve the comment as "outdated".

(I tend to use reactions to distinguish comments I agree to address, especially style comments.)
nit: a method doc in Python is placed "after" the definition of the method.
For other functionality, we add the spying instance as a param to test this. Do we test this in e2e instead? I'm OK with that - just wanted to check.
Same as above. This is also tested in e2e suites by assertions on the output of handling expired timer rows.
nit: any specific reason to implement this here separately rather than calling _prepare_input_data?
Refactored by calling _prepare_input_data.
This is not true - the watermark for eviction is 5 but the watermark for late records is 0, hence ("a", 4) is not dropped. This is exactly the reason you still see an event for "a". Otherwise you wouldn't have ("a", 20).
You might wonder how this works differently in Scala tests - AddData() & CheckNewAnswer() will trigger a no-data batch, hence executing two batches.
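The eviction vs. late-record distinction can be sketched numerically. This is a simplified model, assuming the late-record watermark lags one batch behind the eviction watermark (matching the numbers in the comment above: eviction 5, late-record 0 in the first batch):

```python
def watermarks(batch_max_event_times, delay=5):
    # Simplified model: the watermark used to drop late records in batch N
    # is the eviction watermark computed at the end of batch N-1.
    late_record_wm = 0
    for max_ts in batch_max_event_times:
        eviction_wm = max(0, max_ts - delay)
        yield (late_record_wm, eviction_wm)
        late_record_wm = eviction_wm

# Two batches whose max event times are 10 and 20, with a 5s delay.
wms = list(watermarks([10, 20]))
```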
Thanks for leaving the comments! Reading your comments, I realized I did not quite understand the difference between the watermark for eviction and the watermark for late records before.
The test case should still be fine; I just deleted the comments. Dropping late records will be tested more thoroughly in the chaining-of-operators PR.
HeartSaVioR left a comment
Looks good to me except nits.
This seems to be missed.
Thanks. Would you like to leave getExpiredTimers() as it is? Then let's leave a code comment that requests to getExpiredTimers won't be interleaved, hence this is safe.
test failure seems to be unrelated, but lint seems to be either related or simply broken. https://github.com/jingz-db/spark/actions/runs/11526119046/job/32090806651 @HyukjinKwon Do you happen to know the reason for the failure? I guess the generated py file should be excluded from the linter, and I thought we did that, as I didn't see a linter failure in prior PRs. Was anything changed around the pyspark linter?

@jingz-db mind updating your master branch to the latest, rebasing this branch against it, and pushing?
@jingz-db
I did - I rebased on the latest master branch a few hours ago. Let me add type-checking imports and see if it passes.

Hey @HyukjinKwon, do we have any place where we could manually skip the python style check for certain files? Currently the linter check is only failing on the auto-generated file created by
Let's add `# noqa: E501` back to ignore the length check.
Could you please try modifying the mypy.ini file to ignore errors on proto-generated python files? You'll need to move the generated file to a proto directory (create a new directory) and add the exclusion (see lines 183 to 185 in 413242b).
Also please rebase to incorporate the removal of the generated code for Java.
Thanks for the pointer! Moved the proto-generated py file under the sql/streaming/proto directory and added the entry in the mypy.ini file.
Thanks! Merging to master.


What changes were proposed in this pull request?
Adds support for timers in the TransformWithStateInPandas Python API.
Why are the changes needed?
To keep parity with the Scala API, TransformWithStateInPandas should also support processing-time/event-time timers for arbitrary state.
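Based on the parameters shown in this PR's diffs, a user-defined processor handling expired timers might look roughly like the sketch below. The `TimerValues` and `ExpiredTimerInfo` classes here are simplified stand-ins written for this example, not the real PySpark implementations:

```python
class TimerValues:
    # Stand-in: carries the current batch's processing time and watermark.
    def __init__(self, current_processing_time_ms, current_watermark_ms):
        self.current_processing_time_ms = current_processing_time_ms
        self.current_watermark_ms = current_watermark_ms

class ExpiredTimerInfo:
    # Stand-in: valid only when the call is for an expired timer.
    def __init__(self, is_valid, expiry_timestamp_ms=-1):
        self._is_valid = is_valid
        self._expiry_timestamp_ms = expiry_timestamp_ms

    def is_valid(self):
        return self._is_valid

    def get_expiry_time_in_ms(self):
        return self._expiry_timestamp_ms

class CountProcessor:
    # Sketch of the two additional handleInputRows parameters this PR adds.
    def handleInputRows(self, key, rows, timer_values, expired_timer_info):
        if expired_timer_info.is_valid():
            # Timer fired: emit a record for the expired key.
            return [(key, "expired", expired_timer_info.get_expiry_time_in_ms())]
        count = sum(1 for _ in rows)
        return [(key, "count", count)]

p = CountProcessor()
out = p.handleInputRows(
    ("k",), iter([1, 2]), TimerValues(1000, 500), ExpiredTimerInfo(False))
```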
Does this PR introduce any user-facing change?
Yes. Users can now interact with timers from `handleInputRows` via two additional parameters: a newly introduced `TimerValues` to get the processing/event time for the current batch, and an `expired_timer_info` to get the timestamp for expired timers.

How was this patch tested?

Unit tests in `TransformWithStateInPandasStateServerSuite` and integration tests in `test_pandas_transform_with_state.py`.

Was this patch authored or co-authored using generative AI tooling?
No.