
Add function to carry out parallel fitting on N-dimensional datasets#16696

Merged
nden merged 58 commits into astropy:main from astrofrog:parallel-fitting
Sep 3, 2024
Conversation

@astrofrog
Member

@astrofrog astrofrog commented Jul 10, 2024

This is a collaborative PR between myself and @Cadair

Background

A commonly requested feature (at least one that @Cadair and I have heard a lot about) these days is the ability to fit a model in parallel to many parts of a dataset, for example fitting a model to each individual spectrum in a spectral cube. This PR implements a helper function that can fit 1- or 2-dimensional models to N-dimensional datasets. The documentation and test suite could be expanded, but we are opening this pull request now so that people can try it out and give early feedback.

The goal of the function is two-fold:

  • To make it easy to carry out many fits in a data cube
  • To provide easy parallelization, by default using multiple processes to speed things up

This is written in a very generic way so that it works with N-dimensional datasets and any astropy.modeling model or fitter, so we hope you agree that this belongs in astropy.modeling. This goes towards several of the items on the 2024 roadmap:

  • Improve and/or maintain interoperability with performant I/O file formats and libraries such as HDF5 and Dask. (this adds integration of astropy.modeling with dask)
  • Improve support for using Astropy tools in heterogeneous computing environments such as cloud environments or GPU systems. (this allows model fits to be carried out on any kind of distributed environment supported by dask, including clusters in the cloud)

Simple example

A simple example of this in action, fitting all the spectra in a spectral cube with a simple Gaussian model:

import matplotlib.pyplot as plt

from astropy.io import fits
from astropy.modeling.fitting import LMLSQFitter
from astropy.modeling.fitting_parallel import parallel_fit_dask
from astropy.modeling.models import Gaussian1D

if __name__ == "__main__":

    data = fits.getdata("l1448_13co.fits")

    g_init = Gaussian1D(mean=25, stddev=10, amplitude=1)
    g_init.mean.bounds = (0, 53)

    g_fit = parallel_fit_dask(
        model=g_init,
        fitter=LMLSQFitter(),
        data=data,
        fitting_axes=0,
    )

    fig = plt.figure(figsize=(10, 5))
    ax = fig.add_subplot(1, 3, 1)
    ax.imshow(g_fit.amplitude.value, vmin=0, vmax=5)
    ax = fig.add_subplot(1, 3, 2)
    ax.imshow(g_fit.mean.value, vmin=0, vmax=50)
    ax = fig.add_subplot(1, 3, 3)
    ax.imshow(g_fit.stddev.value, vmin=0, vmax=20)
    fig.savefig("results_no_world.png")

[Image: results_no_world.png, maps of the fitted amplitude, mean, and stddev]

Features so far

  • Fit 1- or 2-dimensional models to any of the axes of an N-dimensional data cube.
  • The dataset is broken up efficiently into chunks, each of which is processed in, e.g., a separate process or thread.
  • No understanding of dask required (it just needs to be installed)
  • Designed to be able to work with large datasets, including dask datasets that might be larger than memory. We've tested this on larger cubes with over a million compound model fits to be carried out, and it worked nicely.
  • Ability to specify the world coordinates (used for the x or x/y axes in the fitting) via 1D arrays, ND arrays, or any APE-14-compliant WCS, or just use pixel coordinates
  • Ability to output information about failed fits or fits that emit warnings
  • Ability to use this with any dask scheduler, including dask.distributed. We've tested this out with a cloud-based dask cluster (via https://www.coiled.io/) and it worked, although of course the efficiency of this depends on how fast one can send/receive data from the cloud.
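As a sketch of the world-coordinate option above, the dict form maps each fitting axis to its world values (the function also accepts an APE-14 WCS). This is illustrative only: the velocity values and cube shape are made up, and the commented-out call assumes the same names as the example earlier in this description.

```python
import numpy as np

# Hypothetical values: a (53, ny, nx) cube with a velocity axis along axis 0.
velocity = np.linspace(-50.0, 50.0, 53)  # world value for each spectral channel

# Dict form: map the fitting axis index to its 1D array of world values.
world = {0: velocity}

# g_fit = parallel_fit_dask(model=g_init, fitter=LMLSQFitter(),
#                           data=data, fitting_axes=0, world=world)
```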

You can take a look at the RTD preview as well as the docstring of the function to see what can be done.

Discussion points

  • This implementation makes use of dask, and requires dask to be installed (it is an optional astropy dependency). We have named the function parallel_fit_dask to make this clear, and also to allow other parallel implementations to be added in the future. For example, we could imagine having parallel_fit_multiprocessing, which could use the built-in multiprocessing library. Adding these shouldn't be too much work, but we would prefer to defer it to a follow-up pull request to keep things simple. A reasonable amount of the code in this PR could be re-used in a multiprocessing implementation. We could also imagine renaming the function to parallel_fit and having the method be a kwarg, but this is not ideal, as different methods might require different kwargs, and the function would end up taking many arguments and be messy.
  • Currently, if n_chunk_max (the number of fits to carry out in a chunk) is not specified, the default is derived from the default dask chunk size, which isn't necessarily optimal: if the array is not very large, the data will end up in a single chunk, which won't benefit from parallelization. We could instead have n_chunk_max default to e.g. 100 or 1000, which would probably be better than the dask default. Alternatively, we could make it a required argument to force the user to think about it. I have since set the default to 500 instead of using the dask chunk size; I think this should be reasonable.
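To make the chunking discussion concrete, here is a rough stand-alone sketch, not the PR's actual code: one simple way to cap the chunk size is to repeatedly halve the iterating axes until each chunk holds at most n_chunk_max fits. The halving strategy is a made-up illustration.

```python
import math

def chunk_shape(iterating_shape, n_chunk_max=500):
    # Repeatedly halve the largest iterating axis until a chunk
    # contains at most n_chunk_max fits.
    chunks = list(iterating_shape)
    while math.prod(chunks) > n_chunk_max:
        i = max(range(len(chunks)), key=lambda j: chunks[j])
        if chunks[i] == 1:
            break
        chunks[i] = math.ceil(chunks[i] / 2)
    return tuple(chunks)

print(chunk_shape((100, 100)))  # (13, 25): 325 fits per chunk, under the 500 cap
```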

Detailed TODOs

  • By checking this box, the PR author has requested that maintainers do NOT use the "Squash and Merge" button. Maintainers should respect this when possible; however, the final decision is at the discretion of the maintainer that merges the PR.

@github-actions
Contributor

Thank you for your contribution to Astropy! 🌌 This checklist is meant to remind the package maintainers who will review this pull request of some common things to look for.

  • Do the proposed changes actually accomplish desired goals?
  • Do the proposed changes follow the Astropy coding guidelines?
  • Are tests added/updated as required? If so, do they follow the Astropy testing guidelines?
  • Are docs added/updated as required? If so, do they follow the Astropy documentation guidelines?
  • Is rebase and/or squash necessary? If so, please provide the author with appropriate instructions. Also see instructions for rebase and squash.
  • Did the CI pass? If no, are the failures related? If you need to run daily and weekly cron jobs as part of the PR, please apply the "Extra CI" label. Codestyle issues can be fixed by the bot.
  • Is a change log needed? If yes, did the change log check pass? If no, add the "no-changelog-entry-needed" label. If this is a manual backport, use the "skip-changelog-checks" label unless special changelog handling is necessary.
  • Is this a big PR that makes a "What's new?" entry worthwhile and if so, is (1) a "what's new" entry included in this PR and (2) the "whatsnew-needed" label applied?
  • At the time of adding the milestone, if the milestone set requires a backport to release branch(es), apply the appropriate "backport-X.Y.x" label(s) before merge.

@github-actions
Contributor

👋 Thank you for your draft pull request! Do you know that you can use [ci skip] or [skip ci] in your commit messages to skip running continuous integration tests until you are ready?


@pllim pllim left a comment


I don't know how I feel about adding Python scripts and data files into docs.

As discussed "offline," would be nice to have a section explaining the subtle differences with the existing n_models framework.

Thanks!

@astrofrog
Member Author

@pllim - I've removed the separate script and sped up the docs page. Combined with #16673, the docs page here takes <10s to generate all the plots, which I think is acceptable.

@astrofrog
Member Author

As discussed "offline," would be nice to have a section explaining the subtle differences with the existing n_models framework.

I've added a TODO in the description at the top to remind me to do this.

@pllim
Member

pllim commented Jul 11, 2024

Please use the new perf change log category, see:

@astrofrog
Member Author

I don't think this falls under perf; it is a new feature.

@astrofrog
Member Author

astrofrog commented Jul 12, 2024

I've included the bug fix from #16710 to show how with that fix we can avoid all the hackery around copying models with different parameter shapes (ace17a3)

@pllim
Member

pllim commented Jul 12, 2024

Did someone say hackery?

@astrofrog
Member Author

astrofrog commented Jul 22, 2024

This is now ready for review, and test coverage should be 100%. For now, it includes the commits for #16710 too to make sure everything works properly, but once that PR is merged we can rebase this one. But given the size of the present PR, it makes sense to start the review process, and once/if #16710 is merged it will be a simple rebase. The changelog error in CI is due to that other PR being included here, and will be resolved after the rebase.

@astrofrog astrofrog marked this pull request as ready for review July 22, 2024 11:05
@astrofrog astrofrog requested a review from a team as a code owner July 22, 2024 11:05
@astrofrog astrofrog changed the title Add helper function to carry out parallel fitting on N-dimensional datasets Add function to carry out parallel fitting on N-dimensional datasets Jul 22, 2024
@ketozhang

ketozhang commented Jul 22, 2024

Hoping to make clearer the difference between one-model-multiple-data (this PR) vs. multiple-models-single-data (model sets). I think the name "parallel fitting" is too ambiguous, since I also consider model sets to be parallelizable.

I don't have a suggestion for what to rename the API to, but could the sidebar navigation be more explicit?

  • Compound Models
  • Parallel Fitting
  • Fitting a line
  • Fitting with constraints
  • Fitting Model Sets

vs.

  • Fitting a line
  • Fitting with constraints
  • Fitting compound models
  • Fitting many independent models on same data (Model sets)
  • Fitting same model on multiple data (this PR)

On this...

We could imagine renaming the function to parallel_fit and having the method be a kwarg, but this is not ideal as different methods might require different kwargs, and the function will end up taking many arguments and be messy.

I think it's worth preserving the FittableModel's fit kwargs rather than forcing everyone to migrate to a different set when they need to do parallel fitting. Note that all the FittableModels are just callables/function factories, so could your function instead just be a wrapper on FittableModel.__call__ (or any other callable)? You could use type hinting to propagate the callable's kwargs with Protocol as the return type (but types are controversial in astropy 🤷).

single_fit = LMLSQFitter(...)
single_fit(g_init, coords, data[:, 0, 0], maxiter=10)

# Wrapper outputs a function that redefines args, add in new kwargs, but ultimately supports the orig kwargs
parallel_fit = parallel_fit_wrapper(single_fit)
parallel_fit(
  g_init, wcs=..., data=data, 
  # new kwargs
  fitting_axes=0, 
  # original kwargs preserved
  maxiter=10
)

@Cadair
Member

Cadair commented Jul 23, 2024

Hi @ketozhang thanks for the feedback.

I am not sure I am following the distinction you are making between this and model sets. With this PR it is possible to have different initial parameters for each model you fit along your fitting axes (i.e. a different mean for your Gaussian for each spectrum) by specifying your initial model with parameters the same shape as the non-fitting axes. We didn't use model sets for this because they only support one axis as the "model set axis", whereas here you can have more than one non-fitting axis. See this section of the docs @astrofrog wrote: https://github.com/astropy/astropy/pull/16696/files#diff-b4da8df1c755ce661147c86ee4a3f45ee200fd6ad62e09f59c3fe40d3972baa9R25-R29
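To illustrate the point about per-pixel initial parameters, here is a hypothetical sketch (not code from the PR): the initial parameter arrays just need the shape of the non-fitting axes, here seeded from the argmax of each spectrum. The cube shape and commented-out model construction are made up.

```python
import numpy as np

# Made-up (nspec, ny, nx) cube, fit along axis 0.
data = np.random.default_rng(0).random((53, 4, 5))

# One initial Gaussian mean per spatial pixel: shape (ny, nx) = (4, 5),
# matching the non-fitting axes of the cube.
initial_mean = data.argmax(axis=0).astype(float)
assert initial_mean.shape == data.shape[1:]

# g_init = Gaussian1D(mean=initial_mean, stddev=10, amplitude=1)
```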

I agree this is all a little confusing and doing some documentation updates is something we should have another look over once we get close to a final version of this.

@ketozhang

Thanks for the clarification!

this PR to have different initial parameters for each model you fit along your fittable axes

Ah, I was not aware you could do this; fair enough that you're not advertising it since it's discouraged. Sounds good on the docs update, thanks!


@perrygreenfield perrygreenfield left a comment


This adds a lot of important functionality and is a major piece of work. I don't think my review does it justice at the depth it deserves. My focus was mostly on the code and less so on the tests. I just had one significant comment and a couple of minor issues.

*pixel_nd,
wcs=deepcopy(wcs),
new_axis=0,
chunks=(3,) + chunks,
Member


Does this presume that there are 3 axes to the world coordinates? If so, isn't that too presumptive?

Member Author

@astrofrog astrofrog Aug 30, 2024


This was a (non-critical) mistake and would not prevent N-dimensional WCS from being passed, but would just make it not be optimally efficient. I have fixed this to use the actual number of dimensions instead of 3.
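As a sketch of the fix described here (names mirror the quoted snippet, but the code is a made-up stand-in, not the PR's): stacking N pixel-coordinate arrays gives a leading axis of length N, so the chunk spec should use that length rather than a hard-coded 3.

```python
import numpy as np

chunks = (10, 10)
# e.g. a 4-dimensional WCS: one pixel-coordinate array per dimension
pixel_nd = [np.zeros(chunks) for _ in range(4)]

ndim = len(pixel_nd)
stacked = np.stack(pixel_nd, axis=0)
assert stacked.shape == (ndim,) + chunks  # chunk spec: (ndim,) + chunks, not (3,) + chunks
```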

# up and for map_blocks to work properly as noted in
# https://github.com/dask/dask/issues/11188. However, rather than use dask
# operations to broadcast these up, which creates a complex graph and
# results in high memory usage, we use a ParameterContainer which does the
Member


This information should also be in the docstring for the ParameterContainer class.

Member Author


Done

The axes to keep for the fitting (other axes will be sliced/iterated over)
world : `None` or dict or APE-14-WCS
This can be specified either as a dictionary mapping fitting axes to
world axis values, or as a WCS for the whole cube. If the former, then
Member


Here and elsewhere there seems to be an implicit assumption that the WCS is for a cube. But I don't recall seeing that explicitly stated in other parts. Can this be clarified?

Member Author


This supports any dimensionality for the WCS so long as it matches the data. We've tested it with 4D WCS (time/spectral/celestial x2)

index_abs = np.array(index) + np.array(
[block_info[0]["array-location"][idx][0] for idx in iterating_axes]
)
# index = tuple(int(idx) for idx in np.unravel_index(i_abs, iterating_shape))
Member


Commented out code.

Member Author


Removed

@astrofrog
Member Author

@perrygreenfield - thanks for the review! I think all your comments should be addressed.


@perrygreenfield perrygreenfield left a comment


LGTM
