Add Plotly-like interpolation algorithm to optuna.visualization.matplotlib.plot_contour#2810
Conversation
This patch introduces _create_zmatrix() function, which is used to transform trial params and values into irregularly spaced grid which forms a base for interpolation algorithm.
This patch introduces algorithm, which is used to determine order in which z-matrix should be interpolated.
This patch introduces new z-matrix interpolation algorithm for matplotlib contour plots. This algorithm is a port of one used in Plotly
Numba has to be specified as install requirement, since jit is a decorator, and as such has to be present at start time
Codecov Report
@@ Coverage Diff @@
## master #2810 +/- ##
==========================================
- Coverage 91.79% 91.43% -0.37%
==========================================
Files 146 147 +1
Lines 11192 11352 +160
==========================================
+ Hits 10274 10380 +106
- Misses 918 972 +54
Continue to review full report at Codecov.
|
|
|
This patch introduces axes padding algorithm similar to the one used in plotly. With it, trial values close to the edge of param grid are rendered correctly
|
Waiting for #2806 to merge so I can proceed with tests. |
|
A friendly ping that #2806 has been merged 🙂 |
|
Assigned @himkt although the work is still ongoing. And, thanks a lot for the great work! At a quick glance, looks like a definite improvement. I'm just curious how critical Numba is to this functionality. Does the visualization become unusable without the JIT compilation (i.e. how much faster does it become with Numba)? We want to be careful introducing additional install requirements. |
|
@hvy thank you for the question! Let me share a small bench test i did while developing this.
I understand your point about introducing additional install requirements though. As an alternative, i can move jitted code to separate module, |
|
Thanks for the numbers! With a separate module, do you mean independent of Optuna or within the Optuna library? Interesting that Plotly is significantly faster than the Python implementation without the JIT compilation. I haven't taken a deeper look at the algorithm yet, but I'm curious whether there is room for optimizing the Python implementation without relying on Numba (e.g. by simple vectorization). If there is, we could perhaps leave it as future work and introduce this algorithm as the default one (given that the matplotlib visualizations are all still experimental). If not, |
|
@nzw0301 let me assign you as the author of the issue. Let me know if you're busy and I can reassign someone else. Do you have any thoughts or inputs on how to continue forward with this? |
|
@hvy, sorry for not being very clear on that one. I meant module within Optuna, more specifically within I think Plotly speed comes from the fact that there is not much heavy lifting done python side. From what i know, execution is passed almost immediately to one of renderers. I'm not sure if vectorization is feasible in this instance, since algorithm relies on very specific order in which missing values are interpolated (i. e. when imputing
I will explore both possibilities, before marking this as ready to review. |
|
@hvy, so I've managed to cut no-JIT time from 1m 17s to 4.5s (on Intel Core i5) just by using hashmaps. This comes with no downsides to the quality of the plot and does not use any new dependencies. It's still a bit slower than Plotly or jitted numpy, but I think we are within reasonable runtime for a function that produces a plot. I don't think there is much to gain anymore with pure python or without diverging from Plotly algorithm. Overall im happy to exclude Numba and switch to the new approach, so please let me know if this is the level of performance we are ok with. |
|
That sounds great! |
This patch changes the implementation of Plotly zmatrix interpolation algorithm to a hashmap based one in order to avoid taking dependency on Numba
|
hi, @xadrianzetx, sorry for the delayed response. We have had national holidays since last Thursday. I'll take a look your this PR soon. Please give me a little bit more time. best regards, |
nzw0301
left a comment
There was a problem hiding this comment.
Thank you for your great PR! The separated functions are really helpful to follow the logic and tests! I do not still understand the interpolation logic yet, but let me leave minor comments :)
|
Thank you for the comments @nzw0301! Just FYI, it might be couple days before I'm able to address them, due to some hardware failures I'm currently resolving. Sorry for any delays! |
|
No worries! It is totally fine. I hope your hardware issue is resolved🤞🏻 |
|
hi, let me give additional comments. Personally using complex numbers looks cool! However, it is slightly hard for other committers to maintain the codes. So could you replace complex numbers with tuples or In this complex number-free implementation, one drawback is that we need to care the boundary access with What do you think about it? |
himkt
left a comment
There was a problem hiding this comment.
Huge thanks for the awesome improvements @xadrianzetx!
Let me leave some suggestions about type checking.
|
|
||
|
|
||
| NEIGHBOR_OFFSETS = [1 + 0j, -1 + 0j, 0 + 1j, 0 - 1j] | ||
| NUM_OPTIMIZATION_ITERATIONS = 100 |
There was a problem hiding this comment.
[memo] Same iteration number as https://github.com/plotly/plotly.js/blob/bbeeda93efb56a50bafdf6f396b0abf345b5c5cb/src/traces/heatmap/interp2d.js#L45.
That was the design of initial version introduced in a27cdf4 and f7b343f. Unfortunately, without JIT, performance really suffered due to iterative nature of this algorithm, and Numpy element access being somewhat slow. Without Numba, we needed a data structure that offers fast access to its elements, preferably included in Python standard library, hence usage of hashmap.
The reason behind coordinates being expressed as complex numbers is, that it offers the fastest way to calculate locations of neighbors, with just one simple
Given those, I would be in favor of staying with current implementation. I do agree with you on maintainability point though, this algorithm is not the easiest thing to understand. To help with this, I think additional comments explaining grid system and algorithm steps could be introduced. |
|
Thank you so much for the revise! Basically LGTM. ⭐ There still be little bit difference between plotly's plot and matplotlib's plot.
|
|
It might be due to the fact that for Plotly contour we are not doing any upsampling to trial params, so in your case matrix of shape 30x30 (roughly :)) is interpolated and contour smoothing is used, while in matplotlib implementation we are always resampling to 100x100 either with linspace or logspace. This could result in differences we see, especially for small number of trials. Other than that, there might be slight differences in implementation between Plotly and matplotlib as far as calculating number and position of contours (looking at your example, matplotlib seems to be more detailed even though we are not specifying number of contour lines to use anywhere). To my best knowledge they both use marching squares, but I did not look up specific implementations. |
nzw0301
left a comment
There was a problem hiding this comment.
Thank you for your comments and update PR! I understand the reason for using complex numbers, so I'm happy with the current implementation!
It makes sense. I agree with you that it results the difference. And I also think it is not bad point. 👍 |
himkt
left a comment
There was a problem hiding this comment.
Again, thank you so much @xadrianzetx for the awesome improvement! LGTM.


This PR introduces new interpolation algorithm to be used when creating contour plots with
optuna.visualization.matplotlib.plot_contour(), improving on linear interpolation offered byscipy.griddata().The algorithm is a port of Plotly.js version.
This addresses known issues in current implementation:
Additionally, tweaks to
xrangeandyrangeare introduced in order to correctly visualize trial values at the edges of param grid.Motivation
This PR resolves #2738 and addresses #2712 (review).
Description of the changes
optuna.visualization.matplotlib.plot_contour()xrangeandyrangepaddingExamples
All plots were generated using example from #2595
Current algorithm

New algorithm

Plotly

And a very cool gif showing how this algorithm converges :^)
