
Add Plotly-like interpolation algorithm to optuna.visualization.matplotlib.plot_contour #2810

Merged
himkt merged 14 commits into optuna:master from xadrianzetx:feature-contour-algorithm
Jul 30, 2021
Conversation

@xadrianzetx
Collaborator

@xadrianzetx xadrianzetx commented Jul 15, 2021

This PR introduces a new interpolation algorithm to be used when creating contour plots with optuna.visualization.matplotlib.plot_contour(), improving on the linear interpolation offered by scipy.interpolate.griddata().

The algorithm is a port of the Plotly.js version.

This addresses known issues in the current implementation:

  • contour lines which are not smooth
  • blank areas on edges of the plot

Additionally, tweaks to xrange and yrange are introduced in order to correctly visualize trial values at the edges of the param grid.

Motivation

This PR resolves #2738 and addresses #2712 (review).

Description of the changes

  • Implement interpolation algorithm and helpers.
  • Switch to the new algorithm in optuna.visualization.matplotlib.plot_contour().
  • Address xrange and yrange padding.
  • Write tests.
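As background for the list above, the core of the Plotly-style scheme is to iteratively replace each missing grid value with the mean of its neighbors until the surface smooths out (the patch itself uses a fixed constant of 100 iterations). A minimal sketch of that idea, with an illustrative function name rather than the actual ported code:

```python
import numpy as np


def interpolate_zmatrix(z: np.ndarray, n_iters: int = 100) -> np.ndarray:
    # z is a 2-D array: known trial values are finite, missing grid
    # points are NaN (illustrative sketch, not the merged implementation).
    z = z.copy()
    missing = list(zip(*np.where(np.isnan(z))))
    # Seed missing cells with the mean of the known values.
    z[np.isnan(z)] = np.nanmean(z)
    rows, cols = z.shape
    for _ in range(n_iters):
        for r, c in missing:
            # Average the in-bounds von Neumann neighbors.
            neighbors = [
                z[nr, nc]
                for nr, nc in ((r - 1, c), (r + 1, c), (r, c - 1), (r, c + 1))
                if 0 <= nr < rows and 0 <= nc < cols
            ]
            z[r, c] = sum(neighbors) / len(neighbors)
    return z
```

Repeated neighbor-averaging is what yields the smooth contours shown in the examples below, as opposed to the piecewise-linear surface produced by griddata.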

Examples

All plots were generated using the example from #2595

Current algorithm
[screenshot]

New algorithm
[screenshot]

Plotly
[screenshot]

And a very cool gif showing how this algorithm converges :^)
[gif: plot_iter2]

This patch introduces the _create_zmatrix() function, which is used to
transform trial params and values into an irregularly spaced grid which
forms the base for the interpolation algorithm.
This patch introduces an algorithm which is used to determine the order
in which the z-matrix should be interpolated.
This patch introduces a new z-matrix interpolation algorithm for
matplotlib contour plots. This algorithm is a port of the one used in
Plotly.
@github-actions github-actions bot added the optuna.visualization label (Related to the `optuna.visualization` submodule. This is automatically labeled by github-actions.) Jul 15, 2021
Numba has to be specified as an install requirement, since jit is a
decorator, and as such has to be present at import time.
@codecov-commenter

codecov-commenter commented Jul 16, 2021

Codecov Report

Merging #2810 (b830410) into master (e0fd830) will decrease coverage by 0.36%.
The diff coverage is 90.62%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #2810      +/-   ##
==========================================
- Coverage   91.79%   91.43%   -0.37%     
==========================================
  Files         146      147       +1     
  Lines       11192    11352     +160     
==========================================
+ Hits        10274    10380     +106     
- Misses        918      972      +54     
| Impacted Files | Coverage Δ |
|---|---|
| optuna/visualization/matplotlib/_contour.py | 83.21% <90.62%> (+3.53%) ⬆️ |
| optuna/cli.py | 26.97% <0.00%> (-2.60%) ⬇️ |
| optuna/study/_optimize.py | 97.77% <0.00%> (-0.04%) ⬇️ |
| optuna/storages/__init__.py | 100.00% <0.00%> (ø) |
| optuna/storages/_heartbeat.py | 100.00% <0.00%> (ø) |

Continue to review full report at Codecov.

Legend
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update e0fd830...b830410. Read the comment docs.

@xadrianzetx
Collaborator Author

blackdoc itself has failed in CI (not the files it was checking). Strange.

This patch introduces an axes padding algorithm similar to the one used
in Plotly. With it, trial values close to the edge of the param grid are
rendered correctly.
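The padding idea from this commit can be sketched roughly as follows; the 5% ratio and the helper name are illustrative (the actual patch mirrors Plotly's padding and also has to handle log-scaled axes):

```python
def pad_range(low: float, high: float, pad_ratio: float = 0.05) -> tuple:
    # Widen the axis range by a fraction of its span so that trials
    # sitting at the very edge of the param grid are rendered rather
    # than clipped (illustrative sketch, not the merged code).
    span = high - low
    return low - span * pad_ratio, high + span * pad_ratio
```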
@xadrianzetx
Collaborator Author

Waiting for #2806 to merge so I can proceed with tests.

@hvy
Member

hvy commented Jul 19, 2021

A friendly ping that #2806 has been merged 🙂

@hvy
Member

hvy commented Jul 19, 2021

Assigned @himkt although the work is still ongoing.

And, thanks a lot for the great work! At a quick glance, looks like a definite improvement. I'm just curious how critical Numba is to this functionality. Does the visualization become unusable without the JIT compilation (i.e. how much faster does it become with Numba)? We want to be careful introducing additional install requirements.

@xadrianzetx
Collaborator Author

xadrianzetx commented Jul 19, 2021

@hvy thank you for the question! Let me share a small benchmark I did while developing this.

| Version | Runtime |
|---|---|
| Matplotlib (no JIT) | 1 loop, best of 5: 1min 17s per loop |
| Matplotlib (JIT) | 1 loop, best of 5: 880 ms per loop |
| Plotly | 1 loop, best of 5: 304 ms per loop |

I understand your point about introducing additional install requirements though. As an alternative, I can move the jitted code to a separate module, try_import it, and fall back to the default griddata solution if the dependency is missing. This would move Numba out of the install requirements and back to optional. Then I could add a note in the docstring of plot_contour informing users about this functionality. I'm just worried that many users would then see the default behavior at first go (although popular environments such as Google Colab or Kaggle notebooks include Numba by default). Please let me know what you think of this solution.

@hvy
Member

hvy commented Jul 20, 2021

Thanks for the numbers! With a separate module, do you mean independent of Optuna or within the Optuna library?

Interesting that Plotly is significantly faster than the Python implementation without the JIT compilation. I haven't taken a deeper look at the algorithm yet, but I'm curious whether there is room for optimizing the Python implementation without relying on Numba (e.g. by simple vectorization). If there is, we could perhaps leave it as future work and introduce this algorithm as the default one (given that the matplotlib visualizations are all still experimental). If not, try_import might be an option. But it'll complicate the behavior (and/or the API) so I'd like to hear inputs from others if possible.

@hvy
Member

hvy commented Jul 20, 2021

@nzw0301 let me assign you as the author of the issue. Let me know if you're busy and I can reassign someone else. Do you have any thoughts or inputs on how to continue forward with this?

@hvy hvy added the enhancement label (Change that does not break compatibility and does not affect public interfaces, but improves performance.) Jul 20, 2021
@xadrianzetx
Collaborator Author

xadrianzetx commented Jul 20, 2021

@hvy, sorry for not being very clear on that one. I meant a module within Optuna, more specifically within optuna.visualization.matplotlib, with imports working in a similar way to _matplotlib_imports.py.

I think Plotly's speed comes from the fact that there is not much heavy lifting done on the Python side. From what I know, execution is passed almost immediately to one of the renderers.

I'm not sure if vectorization is feasible in this instance, since the algorithm relies on a very specific order in which missing values are interpolated (i.e. when imputing v[1] with the mean of its neighbors, there might be a case where it's necessary to fill v[0] first, when those values are close to each other on the grid). There are two ways I think this algorithm could be sped up if we were to drop Numba:

  • Interpolate the z-matrix first and upsample to smooth the contours later (currently it's done the other way around) - this would mean we are iterating over smaller matrices, in most cases at least (and looking at it again, I think this is what Plotly does).
  • Instead of extensively using array indexing/slicing, base this algorithm on a hashmap of coordinates and iterate over those (building the z-matrix out of it only as the very last step).

I will explore both possibilities before marking this as ready for review.

@xadrianzetx
Collaborator Author

@hvy, so I've managed to cut the no-JIT time from 1m 17s to 4.5s (on an Intel Core i5) just by using hashmaps. This comes with no downsides to the quality of the plot and does not use any new dependencies. It's still a bit slower than Plotly or jitted NumPy, but I think we are within reasonable runtime for a function that produces a plot. I don't think there is much more to gain with pure Python without diverging from the Plotly algorithm. Overall I'm happy to exclude Numba and switch to the new approach, so please let me know if this is the level of performance we are OK with.
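As a rough illustration of the hashmap idea (helper names here are mine, not the exact merged code): known and imputed cells live in a dict keyed by complex coordinates, so a neighbor lookup is a single complex addition plus a hash probe, and the dense z-matrix is only materialized as the very last step:

```python
import math

# One complex addition covers both the x and y offset simultaneously.
NEIGHBOR_OFFSETS = (1 + 0j, -1 + 0j, 0 + 1j, 0 - 1j)


def neighbor_mean(zmap: dict, coord: complex) -> float:
    # Average only the neighbors that are present in the map; cells
    # outside the grid are simply absent, so no bounds checks needed.
    values = [zmap[coord + o] for o in NEIGHBOR_OFFSETS if coord + o in zmap]
    return sum(values) / len(values)


def to_zmatrix(zmap: dict, n_rows: int, n_cols: int) -> list:
    # Build the dense z-matrix only as the final step.
    z = [[math.nan] * n_cols for _ in range(n_rows)]
    for coord, value in zmap.items():
        z[int(coord.real)][int(coord.imag)] = value
    return z
```

Dict access on small hashable keys is fast in CPython, which is what closes most of the gap with the jitted NumPy version without adding a dependency.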

@hvy
Member

hvy commented Jul 21, 2021

That sounds great!
Could you push the change (or create a separate PR if you like), so we can take a look at the implementation? I'm a bit curious about the time complexity and how it scales, but in any case it sounds like a significant improvement that we should continue with.

This patch changes the implementation of the Plotly z-matrix
interpolation algorithm to a hashmap-based one in order to avoid taking
a dependency on Numba.
@xadrianzetx xadrianzetx marked this pull request as ready for review July 22, 2021 18:49
@nzw0301
Member

nzw0301 commented Jul 26, 2021

Hi @xadrianzetx, sorry for the delayed response. We have had national holidays since last Thursday.

I'll take a look at this PR soon. Please give me a little bit more time.

Best regards,

Member

@nzw0301 nzw0301 left a comment


Thank you for your great PR! The separated functions are really helpful for following the logic and the tests! I don't fully understand the interpolation logic yet, but let me leave some minor comments :)

@xadrianzetx
Collaborator Author

Thank you for the comments @nzw0301! Just FYI, it might be a couple of days before I'm able to address them, due to some hardware failures I'm currently resolving. Sorry for any delays!

@nzw0301
Member

nzw0301 commented Jul 27, 2021

No worries! It is totally fine. I hope your hardware issue is resolved🤞🏻

@nzw0301
Member

nzw0301 commented Jul 27, 2021

Hi, let me give some additional comments.

Personally, using complex numbers looks cool! However, it is slightly hard for other committers to maintain the code. So could you replace the complex numbers with tuples or an np.ndarray of two integers, if the implementation does not make the computing performance severely worse? Related to this, zmap can be replaced with an np.ndarray whose values are initialised to np.nan, like np.full((len(xi), len(yi)), np.nan), in _create_zmap.

In this complex-number-free implementation, one drawback is that we need to take care of boundary access with NEIGHBOR_OFFSETS in _find_coordinates_where_empty and _run_iteration, like np.maximum(np.minimum(coord + offset, contour_point_num - 1), 0).

What do you think about it?
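For reference, the clamping suggested above could be sketched like this, with NEIGHBOR_OFFSETS expressed as integer vectors (contour_point_num is the grid resolution mentioned in the comment; the helper name is illustrative):

```python
import numpy as np

# Integer-vector variant of the neighbor offsets.
NEIGHBOR_OFFSETS = np.array([[1, 0], [-1, 0], [0, 1], [0, -1]])


def clamped_neighbors(coord, contour_point_num: int) -> np.ndarray:
    # Clamp each neighbor coordinate into [0, contour_point_num - 1]
    # so cells on the grid boundary never index out of range.
    neighbors = np.asarray(coord) + NEIGHBOR_OFFSETS
    return np.maximum(np.minimum(neighbors, contour_point_num - 1), 0)
```

Note the drawback mentioned: clamping makes a boundary cell effectively its own neighbor, which the complex-keyed hashmap avoids by simply omitting out-of-grid keys.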

Member

@himkt himkt left a comment


Huge thanks for the awesome improvements @xadrianzetx!
Let me leave some suggestions about type checking.



NEIGHBOR_OFFSETS = [1 + 0j, -1 + 0j, 0 + 1j, 0 - 1j]
NUM_OPTIMIZATION_ITERATIONS = 100

@xadrianzetx
Collaborator Author

@nzw0301,

zmap can be replaced with np.ndarray whose values are initialised at np.nan like np.full((len(xi), len(yi)), np.NaN) in _create_zmap.

That was the design of the initial version introduced in a27cdf4 and f7b343f. Unfortunately, without JIT, performance really suffered due to the iterative nature of this algorithm and NumPy element access being somewhat slow. Without Numba, we needed a data structure that offers fast access to its elements, preferably one included in the Python standard library, hence the use of a hashmap.

So could you replace complex numbers with tuples or np.ndarray of two integers if the implementation does not make the computing performance severely worse?

The reason behind coordinates being expressed as complex numbers is that it offers the fastest way to calculate the locations of neighbors, with just one simple addition covering both the x offset and the y offset simultaneously. This is one of the most critical parts of the algorithm, which, assuming 200 Optuna trials, gets hit 9800 missing values * 100 optimization iterations * 4 neighbors times (excluding the NaN discovery part). I did a quick benchmark for both tuples and NumPy arrays (just the core addition operation looped over 3 million iterations) and saw:

  • a tuple of coordinates (defined as tuple(sum(c) for c in zip(coord, offset))) being ~2x slower than complex
  • NumPy arrays (defined as coord_arr + offset_arr) being ~1.4x slower than complex

Given those numbers, I would be in favor of staying with the current implementation. I do agree with you on the maintainability point though; this algorithm is not the easiest thing to understand. To help with this, I think additional comments explaining the grid system and the algorithm steps could be introduced.
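The micro-benchmark described above can be reproduced roughly like this (loop count reduced here; absolute timings vary by machine, the ranking is the point):

```python
import timeit

n = 100_000  # reduced from the 3 million iterations mentioned above

# Complex coordinate plus complex offset: one scalar addition.
t_complex = timeit.timeit("c + o", setup="c = 3 + 4j; o = 1 + 0j", number=n)

# Tuple coordinates, added element-wise via a generator expression.
t_tuple = timeit.timeit(
    "tuple(sum(p) for p in zip(c, o))",
    setup="c = (3, 4); o = (1, 0)",
    number=n,
)

# Two-element NumPy arrays: vectorized add, but with per-call overhead.
t_numpy = timeit.timeit(
    "c + o",
    setup="import numpy as np; c = np.array([3, 4]); o = np.array([1, 0])",
    number=n,
)

print(f"complex={t_complex:.4f}s tuple={t_tuple:.4f}s numpy={t_numpy:.4f}s")
```

For tiny fixed-size coordinates, NumPy's per-operation dispatch overhead dominates, which is why the plain complex scalar wins despite NumPy's vectorized add.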

@xadrianzetx
Collaborator Author

Thanks for reviews @nzw0301, @himkt! I've addressed all points in 34a887e, and added some additional comments in b830410.

@himkt
Member

himkt commented Jul 29, 2021

@xadrianzetx

Thank you so much for the revision! Basically LGTM. ⭐
I have one question.

There is still a little bit of difference between Plotly's plot and matplotlib's plot.
Do you know the cause of this difference?

Plotly vs. matplotlib:
[screenshots]

@xadrianzetx
Collaborator Author

@himkt,

It might be due to the fact that for the Plotly contour we are not doing any upsampling of trial params, so in your case a matrix of shape 30x30 (roughly :)) is interpolated and contour smoothing is used, while in the matplotlib implementation we always resample to 100x100, either with linspace or logspace. This could result in the differences we see, especially for a small number of trials.

Other than that, there might be slight differences in implementation between Plotly and matplotlib when calculating the number and position of contours (looking at your example, matplotlib seems to be more detailed even though we are not specifying the number of contour lines to use anywhere). To the best of my knowledge they both use marching squares, but I did not look up the specific implementations.
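The fixed-resolution resampling described above can be sketched as follows; the 100-point constant matches the number quoted in the comment, while the helper name is illustrative:

```python
import numpy as np

CONTOUR_POINT_NUM = 100  # fixed grid resolution mentioned above


def make_axis(low: float, high: float, is_log: bool) -> np.ndarray:
    # Log-scaled params get a logspace grid, all others a linspace
    # grid, so the z-matrix is always CONTOUR_POINT_NUM wide per axis.
    if is_log:
        return np.logspace(np.log10(low), np.log10(high), CONTOUR_POINT_NUM)
    return np.linspace(low, high, CONTOUR_POINT_NUM)
```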

Member

@nzw0301 nzw0301 left a comment


Thank you for your comments and for updating the PR! I understand the reason for using complex numbers, so I'm happy with the current implementation!

@himkt
Member

himkt commented Jul 30, 2021

It might be due to the fact that for the Plotly contour we are not doing any upsampling of trial params, so in your case a matrix of shape 30x30 (roughly :)) is interpolated and contour smoothing is used, while in the matplotlib implementation we always resample to 100x100, either with linspace or logspace.

It makes sense. I agree with you that it causes the difference. And I also think it is not a bad point. 👍

Member

@himkt himkt left a comment


Again, thank you so much @xadrianzetx for the awesome improvement! LGTM.

@himkt himkt added this to the v2.10.0 milestone Jul 30, 2021
@himkt himkt merged commit 1dc9e4c into optuna:master Jul 30, 2021
@xadrianzetx xadrianzetx deleted the feature-contour-algorithm branch July 30, 2021 16:47


Development

Successfully merging this pull request may close these issues.

Make optuna.visualization.matplotlib.plot_contour more beautiful
