Skip to content

Add a line at ctx_mp_python.py line 679 so that numpy ndarrays of length 1 are also converted #824

@sven0schuierer

Description

@sven0schuierer

Hi,

I would suggest adding the line

    if isinstance(x, np.ndarray) and len(x) == 1: x = x[0]

as line 679 of ctx_mp_python.py (inside `npconvert`), so that a numpy ndarray of length 1 is converted to its single element instead of being rejected.

Justification:
Without this change, the following code fails:

from mpmath import *
from scipy.optimize import fsolve
fsolve (lambda x: exp(x) - 5, 1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 170, in fsolve
    res = _root_hybr(_wrapped_func, x0, args, jac=fprime, **options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 238, in _root_hybr
    shape, dtype = _check_func('fsolve', 'func', func, x0, args, n, (n,))
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 23, in _check_func
    res = atleast_1d(thefunc(*((x0[:numinputs],) + args)))
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 158, in _wrapped_func
    return func(*fargs)
           ^^^^^^^^^^^^
  File "<stdin>", line 1, in <lambda>
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/ctx_mp_python.py", line 991, in f
    x = ctx.convert(x)
        ^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/ctx_mp_python.py", line 650, in convert
    if type(x).__module__ == 'numpy': return ctx.npconvert(x)
                                             ^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/ctx_mp_python.py", line 683, in npconvert
    raise TypeError("cannot create mpf from " + repr(x))
TypeError: cannot create mpf from array([1])

Of course, one could use findroot instead of fsolve in the example above. However, there are cases in which fsolve works and findroot does not:

n_1, n_obs_1, n_2, n_obs_2 = 62380000, 34539224.0, 62390000, 34542106.0
def f_1(N0, N1):
    """Return the residual pair (r1, r2) of the two-equation system in N0, N1.

    Both equations share the denominator N0 + 2*N1. ``exp``, ``n_1``,
    ``n_obs_1``, ``n_2`` and ``n_obs_2`` are looked up in the enclosing
    (module) scope at call time.
    """
    total = N0 + 2 * N1
    first = N0 * (1 - exp(-1 * n_1 / total)) + N1 * (1 - exp(-2 * n_1 / total)) - n_obs_1
    second = N0 * (1 - exp(-1 * n_2 / total)) + N1 * (1 - exp(-2 * n_2 / total)) - n_obs_2
    return first, second

findroot (f_1, (27109937, 28500053))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/calculus/optimization.py", line 969, in findroot
    for x, error in iterations:
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/calculus/optimization.py", line 660, in __iter__
    s = self.ctx.lu_solve(Jx, fxn)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/matrices/linalg.py", line 224, in lu_solve
    A, p = ctx.LU_decomp(A)
           ^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/mpmath/matrices/linalg.py", line 149, in LU_decomp
    raise ZeroDivisionError('matrix is numerically singular')
ZeroDivisionError: matrix is numerically singular

def f_2(param):
    """Return the residual list [r1, r2] for the unknowns packed in ``param``.

    ``param`` must unpack into exactly two values (N0, N1), e.g. the array
    that fsolve passes in. Each residual uses the shared denominator
    N0 + 2*N1 together with its own (n_k, n_obs_k) pair, which is read from
    module scope along with ``exp``.
    """
    N0, N1 = param
    total = N0 + 2 * N1
    return [
        N0 * (1 - exp(-1 * size / total)) + N1 * (1 - exp(-2 * size / total)) - observed
        for size, observed in ((n_1, n_obs_1), (n_2, n_obs_2))
    ]

fsolve (f_2, [27109937, 28500053])
<stdin>:1: RuntimeWarning: The iteration is not making good progress, as measured by the 
  improvement from the last ten iterations.
array([36079713.35263181, 14271606.52089586])

Finally, there are cases in which fsolve runs into an overflow when exp is imported from math, but not when it is imported from mpmath:

from math import *
from scipy.optimize import fsolve

def f(param):
    """Return the residual list [r1, r2, r3] for the unknowns packed in ``param``.

    ``param`` must unpack into exactly three values (N0, N1, N2). With
    D = N0 + 2*N1 + 3*N2, row k is

        N0*(1 - exp(-n_k/D)) + N1*(1 - exp(-2*n_k/D)) + N2*(1 - exp(-3*n_k/D)) - n_obs_k

    ``exp`` and the module-level ``n_k`` / ``n_obs_k`` values are looked up
    at call time.
    """
    N0, N1, N2 = param
    denom = N0 + 2 * N1 + 3 * N2
    # Each row must use its own sample size n_k in all three terms, including
    # the N2 term; building the rows from one template keeps them consistent.
    return [
        N0 * (1 - exp(-1 * size / denom))
        + N1 * (1 - exp(-2 * size / denom))
        + N2 * (1 - exp(-3 * size / denom))
        - observed
        for size, observed in ((n_1, n_obs_1), (n_2, n_obs_2), (n_3, n_obs_3))
    ]

n_1, n_obs_1, n_2, n_obs_2, n_3, n_obs_3 = 62380000, 34539224.0, 62390000, 34542106.0, 62400000, 34544973.0
initial_estimate = [27376700.0, 18251133.333333332, 9125566.666666666]
fsolve (f, initial_estimate)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 170, in fsolve
    res = _root_hybr(_wrapped_func, x0, args, jac=fprime, **options)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 249, in _root_hybr
    retval = _minpack._hybrd(func, x0, args, 1, xtol, maxfev,
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/da/cbt/GX/schuisv1/python/envs/library-complexity-test/lib/python3.12/site-packages/scipy/optimize/_minpack_py.py", line 158, in _wrapped_func
    return func(*fargs)
           ^^^^^^^^^^^^
  File "<stdin>", line 3, in f
OverflowError: math range error
from mpmath import *
>>> fsolve (f, initial_estimate)
<stdin>:3: RuntimeWarning: invalid value encountered in multiply
<stdin>:4: RuntimeWarning: invalid value encountered in multiply
<stdin>:5: RuntimeWarning: invalid value encountered in multiply
<stdin>:1: RuntimeWarning: The iteration is not making good progress, as measured by the 
  improvement from the last ten iterations.
array([27376700.        , 18251133.33333333,  9125566.66666667])

Best regards,
Sven

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (an unexpected problem or unintended behavior)

    Type

    No type

    Projects

    No projects

    Milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions