Add function to carry out parallel fitting on N-dimensional datasets #16696
nden merged 58 commits into astropy:main from
Conversation
Thank you for your contribution to Astropy! 🌌 This checklist is meant to remind the package maintainers who will review this pull request of some common things to look for.
👋 Thank you for your draft pull request! Did you know that you can use
pllim left a comment:
I don't know how I feel about adding Python scripts and data files into docs.
As discussed "offline," would be nice to have a section explaining the subtle differences with the existing n_models framework.
Thanks!
(force-pushed bd511a5 to 359c24f)
I've added a TODO in the description at the top to remind me to do this.

Please use the new

I don't think this falls under perf; it is a new feature.
(force-pushed 292a3b0 to b725df2, then to 149322d)
This is now ready for review, and test coverage should be 100%. For now, it includes the commits for #16710 too, to make sure everything works properly; once that PR is merged we can rebase this one. Given the size of the present PR, it makes sense to start the review process, and once/if #16710 is merged it will be a simple rebase. The changelog error in CI is due to that other PR being included here and will be resolved after the rebase.
Hoping to make the difference clearer between one-model-multiple-data (this PR) vs. multiple-models-single-data (model sets). I think the naming "parallel fitting" is too ambiguous, since I also consider model sets to be parallelizable. I don't have a suggestion for what to rename the API to, but could the sidebar navigation be more explicit?
On this...
I think it's worth preserving the FittableModel's fit kwargs rather than forcing everyone to migrate to a different one when they need to do parallel fitting. Notice that all the FittableModels are just callable/function factories, so could your function instead just be a wrapper?

```python
single_fit = LMLSQFitter(...)
single_fit(g_init, coords, data[:, 0, 0], maxiter=10)

# Wrapper outputs a function that redefines args and adds new kwargs,
# but ultimately supports the original kwargs
parallel_fit = parallel_fit_wrapper(single_fit)
parallel_fit(
    g_init, wcs=..., data=data,
    # new kwargs
    fitting_axes=0,
    # original kwargs preserved
    maxiter=10,
)
```
Hi @ketozhang, thanks for the feedback. I am not sure I am following the distinction you are making between this and model sets. With this PR it is possible to have different initial parameters for each model you fit along your fittable axes (i.e. a different mean for your Gaussian for each spectrum) by specifying your initial model with parameters the same shape as the non-fittable axes. We didn't use model sets for this because they only support one axis as the "model set axis", whereas here you can have more than one non-fittable axis. See this section of the docs @astrofrog wrote: https://github.com/astropy/astropy/pull/16696/files#diff-b4da8df1c755ce661147c86ee4a3f45ee200fd6ad62e09f59c3fe40d3972baa9R25-R29

I agree this is all a little confusing, and updating the documentation is something we should have another look at once we get close to a final version of this.
Thanks for the clarification!
Ah, I was not aware you could do this; fair enough that you're not advertising it, since it's discouraged. Sounds good on the docs update, thanks!
(force-pushed 07761bd to 488659f)
perrygreenfield left a comment:
This adds a lot of important functionality and is a major piece of work. I don't think my review does it justice at the depth it should get. My focus was mostly on the code and less so on the tests. I just had one significant comment and a couple of minor issues.
astropy/modeling/fitting_parallel.py (Outdated)

```python
    *pixel_nd,
    wcs=deepcopy(wcs),
    new_axis=0,
    chunks=(3,) + chunks,
```
Does this presume that there are 3 axes to the world coordinates? If so, isn't that too presumptive?
This was a (non-critical) mistake and would not prevent an N-dimensional WCS from being passed; it would just make it not optimally efficient. I have fixed this to use the actual number of dimensions instead of 3.
```python
    # up and for map_blocks to work properly as noted in
    # https://github.com/dask/dask/issues/11188. However, rather than use dask
    # operations to broadcast these up, which creates a complex graph and
    # results in high memory usage, we use a ParameterContainer which does the
```
This information should also be in the docstring for the ParameterContainer class.
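The idea being discussed can be sketched roughly as follows. This is an assumed illustration, not the PR's actual `ParameterContainer` class: instead of broadcasting parameter arrays up front through dask operations, a small wrapper broadcasts lazily per requested block, so memory usage stays bounded by the block size.

```python
import numpy as np

class ParameterContainer:
    """Sketch of the idea (hypothetical implementation): hold the original
    parameter values and broadcast them lazily when a block is requested,
    instead of building a large broadcast graph up front."""

    def __init__(self, values, full_shape):
        self.values = np.asarray(values)
        self.full_shape = full_shape

    def __getitem__(self, view):
        # Broadcast to the full data shape only for the requested slice;
        # np.broadcast_to returns a read-only view, so no copy is made.
        return np.broadcast_to(self.values, self.full_shape)[view]

pc = ParameterContainer(np.array([1.0, 2.0]), full_shape=(4, 2))
print(pc[1:3].shape)  # (2, 2)
```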
astropy/modeling/fitting_parallel.py (Outdated)

```
    The axes to keep for the fitting (other axes will be sliced/iterated over)
world : `None` or dict or APE-14-WCS
    This can be specified either as a dictionary mapping fitting axes to
    world axis values, or as a WCS for the whole cube. If the former, then
```
Here and elsewhere there seems to be an implicit assumption that the WCS is for a cube. But I don't recall seeing that explicitly stated in other parts. Can this be clarified?
This supports any dimensionality for the WCS so long as it matches the data. We've tested it with a 4D WCS (time/spectral/celestial ×2).
astropy/modeling/fitting_parallel.py (Outdated)

```python
    index_abs = np.array(index) + np.array(
        [block_info[0]["array-location"][idx][0] for idx in iterating_axes]
    )
    # index = tuple(int(idx) for idx in np.unravel_index(i_abs, iterating_shape))
```
@perrygreenfield - thanks for the review! I think all your comments should be addressed.

This is a collaborative PR between myself and @Cadair
Background
A commonly requested feature these days (at least one that @Cadair and I have heard a lot) is the ability to fit a model in parallel to many parts of a dataset, for example fitting a model to all the individual spectra in a spectral cube. This PR implements a helper function that can fit 1- or 2-dimensional models to N-dimensional datasets. The documentation and test suite could be expanded, but we are opening this pull request so that people can already try this out and give early feedback.
The goal of the function is two-fold:
This is written in a very generic way so as to work with N-dimensional datasets, and any astropy.modeling model or fitter, so we hope that you agree that this belongs in astropy.modeling. This goes towards several of the items on the 2024 roadmap:
Simple example
A simple example of this in action, fitting all the spectra in a spectral cube with a simple Gaussian model:
Features so far
You can take a look at the RTD preview as well as the docstring of the function to see what can be done.
Discussion points

- The implementation is dask-based, so the function could be named `parallel_fit_dask` to make this clear, and also to allow other parallel implementations to be added in future. For example, we could imagine having `parallel_fit_multiprocessing`, which could use the built-in multiprocessing library. Adding these shouldn't be too much work, but we would prefer to defer this to a follow-up pull request to keep things simple. A reasonable amount of the code in this PR could be re-used in a multiprocessing implementation. We could imagine renaming the function to `parallel_fit` and having the method be a kwarg, but this is not ideal as different methods might require different kwargs, and the function would end up taking many arguments and be messy.
- Currently, if `n_chunk_max` (the number of fits to carry out in a chunk) is not specified, the default is related to the default dask chunk size, which isn't necessarily optimal: if the array is not very large, the data will just be broken up into a single chunk, which won't benefit from parallelization. We could instead have `n_chunk_max` default to e.g. 100 or 1000, which would probably be better than the dask default. Alternatively, we could simply make it a required argument to force the user to think about this. (Update: I've set the default to 500 instead of using the dask chunk size. I think this should be reasonable?)

Detailed TODOs
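The effect of capping the number of fits per chunk can be sketched as follows (a hypothetical helper, not the PR's code — it just shows why a fixed `n_chunk_max` splits even modest arrays into several chunks):

```python
import math

def n_chunks_for(iterating_shape, n_chunk_max=500):
    # One fit per position along the iterating (non-fitting) axes;
    # split so each chunk holds at most n_chunk_max fits.
    n_fits = math.prod(iterating_shape)
    return max(1, math.ceil(n_fits / n_chunk_max))

print(n_chunks_for((100, 20)))  # 2000 fits -> 4 chunks
print(n_chunks_for((10, 10)))   # 100 fits -> 1 chunk
```

With dask's size-based default chunking, the same 100x20 grid could easily land in a single chunk and gain no parallelism.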
- `__name__ == '__main__'` being needed for multiprocessing
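The TODO above refers to the standard multiprocessing caveat, which can be illustrated like this (a generic sketch, unrelated to the PR's API):

```python
import multiprocessing

def square(x):
    # Trivial stand-in for a per-pixel model fit
    return x * x

# With the "spawn" start method (the default on Windows and macOS),
# worker processes re-import this module, so pool creation must be
# guarded by the __main__ check or each worker would try to start
# its own pool.
if __name__ == "__main__":
    with multiprocessing.Pool(processes=2) as pool:
        print(pool.map(square, [1, 2, 3]))  # [1, 4, 9]
```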