Skip to content

Commit 6b4266c

Browse files
committed
Added a new quantity_asanyarray function to properly deal with lists of quantities, and use this to fix the use of models with units when n_models > 1
1 parent e39dd6f commit 6b4266c

2 files changed

Lines changed: 14 additions & 2 deletions

File tree

astropy/modeling/core.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from ..extern.six.moves import copyreg, zip
3636
from ..table import Table
3737
from ..units import Quantity, UnitBase, UnitsError, dimensionless_unscaled
38+
from ..units.utils import quantity_asanyarray
3839
from ..utils import (sharedmethod, find_current_module,
3940
InheritDocstrings, OrderedDescriptorContainer,
4041
check_broadcast, IncompatibleShapeError)
@@ -1455,7 +1456,8 @@ def _initialize_parameters(self, args, kwargs):
14551456
if arg is None:
14561457
# A value of None implies using the default value, if exists
14571458
continue
1458-
params[self.param_names[idx]] = np.asanyarray(arg, dtype=np.float)
1459+
params[self.param_names[idx]] = quantity_asanyarray(arg, dtype=np.float)
1460+
14591461

14601462
# At this point the only remaining keyword arguments should be
14611463
# parameter names; any others are in error.
@@ -1469,7 +1471,7 @@ def _initialize_parameters(self, args, kwargs):
14691471
if value is None:
14701472
continue
14711473
else:
1472-
params[param_name] = np.asanyarray(value, dtype=np.float)
1474+
params[param_name] = quantity_asanyarray(value, dtype=np.float)
14731475

14741476
if kwargs:
14751477
# If any keyword arguments were left over at this point they are

astropy/units/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from numpy import finfo
2020

2121
from ..extern import six
22+
from .. import units as u
2223

2324
_float_finfo = finfo(float)
2425
# take float here to ensure comparison with another float is fast
@@ -219,3 +220,12 @@ def resolve_fractions(a, b):
219220
elif not a_is_fraction and b_is_fraction:
220221
a = Fraction(a)
221222
return a, b
223+
224+
225+
def quantity_asanyarray(a, dtype=None):
226+
if isinstance(a, np.ndarray):
227+
return np.asanyarray(a, dtype=dtype)
228+
elif any(isinstance(x, u.Quantity) for x in a):
229+
return u.Quantity(a, dtype=dtype)
230+
else:
231+
raise ValueError("Unexpected object passed to quantity_asanyarray: {0}".format(a))

0 commit comments

Comments
 (0)