Skip to content

Commit f807b33

Browse files
committed
updated optimal constraint handling + unit tests
1 parent be660ce commit f807b33

File tree

2 files changed

+91
-7
lines changed

2 files changed

+91
-7
lines changed

control/optimal.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def __init__(
146146
else:
147147
self.trajectory_constraints = trajectory_constraints
148148

149+
# Make sure that we recognize all of the constraint types
150+
for ctype, fun, lb, ub in self.trajectory_constraints:
151+
if not ctype in [opt.LinearConstraint, opt.NonlinearConstraint]:
152+
raise TypeError(f"unknown constraint type {ctype}")
153+
149154
# Process terminal constraints
150155
if isinstance(terminal_constraints, tuple):
151156
self.terminal_constraints = [terminal_constraints]
@@ -154,6 +159,11 @@ def __init__(
154159
else:
155160
self.terminal_constraints = terminal_constraints
156161

162+
# Make sure that we recognize all of the constraint types
163+
for ctype, fun, lb, ub in self.terminal_constraints:
164+
if not ctype in [opt.LinearConstraint, opt.NonlinearConstraint]:
165+
raise TypeError(f"unknown constraint type {ctype}")
166+
157167
#
158168
# Compute and store constraints
159169
#
@@ -401,7 +411,8 @@ def _constraint_function(self, coeffs):
401411
value.append(fun @ np.hstack([states[:, i], inputs[:, i]]))
402412
elif ctype == opt.NonlinearConstraint:
403413
value.append(fun(states[:, i], inputs[:, i]))
404-
else:
414+
else: # pragma: no cover
415+
# Checked above => we should never get here
405416
raise TypeError(f"unknown constraint type {ctype}")
406417

407418
# Evaluate the terminal constraint functions
@@ -413,7 +424,8 @@ def _constraint_function(self, coeffs):
413424
value.append(fun @ np.hstack([states[:, i], inputs[:, i]]))
414425
elif ctype == opt.NonlinearConstraint:
415426
value.append(fun(states[:, i], inputs[:, i]))
416-
else:
427+
else: # pragma: no cover
428+
# Checked above => we should never get here
417429
raise TypeError(f"unknown constraint type {ctype}")
418430

419431
# Update statistics
@@ -485,7 +497,8 @@ def _eqconst_function(self, coeffs):
485497
value.append(fun @ np.hstack([states[:, i], inputs[:, i]]))
486498
elif ctype == opt.NonlinearConstraint:
487499
value.append(fun(states[:, i], inputs[:, i]))
488-
else:
500+
else: # pragma: no cover
501+
# Checked above => we should never get here
489502
raise TypeError(f"unknown constraint type {ctype}")
490503

491504
# Evaluate the terminal constraint functions
@@ -497,7 +510,8 @@ def _eqconst_function(self, coeffs):
497510
value.append(fun @ np.hstack([states[:, i], inputs[:, i]]))
498511
elif ctype == opt.NonlinearConstraint:
499512
value.append(fun(states[:, i], inputs[:, i]))
500-
else:
513+
else: # pragma: no cover
514+
# Checked above => we should never get here
501515
raise TypeError("unknown constraint type {ctype}")
502516

503517
# Update statistics
@@ -844,7 +858,7 @@ def __init__(
844858

845859
# Compute the input for a nonlinear, (constrained) optimal control problem
846860
def solve_ocp(
847-
sys, horizon, X0, cost, constraints=[], terminal_cost=None,
861+
sys, horizon, X0, cost, trajectory_constraints=[], terminal_cost=None,
848862
terminal_constraints=[], initial_guess=None, basis=None, squeeze=None,
849863
transpose=None, return_states=False, log=False, **kwargs):
850864

@@ -865,7 +879,7 @@ def solve_ocp(
865879
Function that returns the integral cost given the current state
866880
and input. Called as `cost(x, u)`.
867881
868-
constraints : list of tuples, optional
882+
trajectory_constraints : list of tuples, optional
869883
List of constraints that should hold at each point in the time vector.
870884
Each element of the list should consist of a tuple with first element
871885
given by :meth:`scipy.optimize.LinearConstraint` or
@@ -943,13 +957,18 @@ def solve_ocp(
943957
:func:`OptimalControlProblem` for more information.
944958
945959
"""
960+
# Process keyword arguments
961+
if trajectory_constraints is None:
962+
# Backwards compatibility
963+
trajectory_constraints = kwargs.pop('constraints', None)
964+
946965
# Allow 'return_x` as a synonym for 'return_states'
947966
return_states = ct.config._get_param(
948967
'optimal', 'return_x', kwargs, return_states, pop=True)
949968

950969
# Set up the optimal control problem
951970
ocp = OptimalControlProblem(
952-
sys, horizon, cost, trajectory_constraints=constraints,
971+
sys, horizon, cost, trajectory_constraints=trajectory_constraints,
953972
terminal_cost=terminal_cost, terminal_constraints=terminal_constraints,
954973
initial_guess=initial_guess, basis=basis, log=log, **kwargs)
955974

control/tests/optimal_test.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,17 @@ def test_ocp_argument_errors():
441441
res = opt.solve_ocp(
442442
sys, time, x0, cost, constraints, terminal_constraint=None)
443443

444+
# Unrecognized trajectory constraint type
445+
constraints = [(None, np.eye(3), [0, 0, 0], [0, 0, 0])]
446+
with pytest.raises(TypeError, match="unknown constraint type"):
447+
res = opt.solve_ocp(
448+
sys, time, x0, cost, trajectory_constraints=constraints)
449+
450+
# Unrecognized terminal constraint type
451+
with pytest.raises(TypeError, match="unknown constraint type"):
452+
res = opt.solve_ocp(
453+
sys, time, x0, cost, terminal_constraints=constraints)
454+
444455

445456
def test_optimal_basis_simple():
446457
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
@@ -484,3 +495,57 @@ def test_optimal_basis_simple():
484495
basis=flat.BezierFamily(4, Tf), return_x=True, log=True)
485496
assert res3.success
486497
np.testing.assert_almost_equal(res3.inputs, res1.inputs, decimal=3)
498+
499+
500+
def test_equality_constraints():
501+
"""Test out the ability to handle equality constraints"""
502+
# Create the system (double integrator, continuous time)
503+
sys = ct.ss2io(ct.ss(np.zeros((2, 2)), np.eye(2), np.eye(2), 0))
504+
505+
# Shortest path to a point is a line
506+
Q = np.zeros((2, 2))
507+
R = np.eye(2)
508+
cost = opt.quadratic_cost(sys, Q, R)
509+
510+
# Set up the terminal constraint to be the origin
511+
final_point = [opt.state_range_constraint(sys, [0, 0], [0, 0])]
512+
513+
# Create the optimal control problem
514+
time = np.arange(0, 3, 1)
515+
optctrl = opt.OptimalControlProblem(
516+
sys, time, cost, terminal_constraints=final_point)
517+
518+
# Find a path to the origin
519+
x0 = np.array([4, 3])
520+
res = optctrl.compute_trajectory(x0, squeeze=True, return_x=True)
521+
t, u1, x1 = res.time, res.inputs, res.states
522+
523+
# Bug prior to SciPy 1.6 will result in incorrect results
524+
if NumpyVersion(sp.__version__) < '1.6.0':
525+
pytest.xfail("SciPy 1.6 or higher required")
526+
527+
np.testing.assert_almost_equal(x1[:,-1], 0, decimal=4)
528+
529+
# Set up terminal constraints as a nonlinear constraint
530+
def final_point_eval(x, u):
531+
return x
532+
final_point = [
533+
(sp.optimize.NonlinearConstraint, final_point_eval, [0, 0], [0, 0])]
534+
535+
optctrl = opt.OptimalControlProblem(
536+
sys, time, cost, terminal_constraints=final_point)
537+
538+
# Find a path to the origin
539+
x0 = np.array([4, 3])
540+
res = optctrl.compute_trajectory(x0, squeeze=True, return_x=True)
541+
t, u2, x2 = res.time, res.inputs, res.states
542+
np.testing.assert_almost_equal(x2[:,-1], 0, decimal=4)
543+
np.testing.assert_almost_equal(u1, u2)
544+
np.testing.assert_almost_equal(x1, x2)
545+
546+
# Try passing and unknown constraint type
547+
final_point = [(None, final_point_eval, [0, 0], [0, 0])]
548+
with pytest.raises(TypeError, match="unknown constraint type"):
549+
optctrl = opt.OptimalControlProblem(
550+
sys, time, cost, terminal_constraints=final_point)
551+
res = optctrl.compute_trajectory(x0, squeeze=True, return_x=True)

0 commit comments

Comments
 (0)