@@ -463,28 +463,53 @@ def test_append_tf(self):
463463 np .testing .assert_array_almost_equal (sys3c .A [:3 , 3 :], np .zeros ((3 , 2 )))
464464 np .testing .assert_array_almost_equal (sys3c .A [3 :, :3 ], np .zeros ((2 , 3 )))
465465
466- def test_array_access_ss (self ):
467-
466+ def test_array_access_ss_failure (self ):
468467 sys1 = StateSpace (
469468 [[1. , 2. ], [3. , 4. ]],
470469 [[5. , 6. ], [6. , 8. ]],
471470 [[9. , 10. ], [11. , 12. ]],
472471 [[13. , 14. ], [15. , 16. ]], 1 ,
473472 inputs = ['u0' , 'u1' ], outputs = ['y0' , 'y1' ])
473+ with pytest .raises (IOError ):
474+ sys1 [0 ]
475+
476+ @pytest .mark .parametrize ("outdx, inpdx" ,
477+ [(0 , 1 ),
478+ (slice (0 , 1 , 1 ), 1 ),
479+ (0 , slice (1 , 2 , 1 )),
480+ (slice (0 , 1 , 1 ), slice (1 , 2 , 1 )),
481+ (slice (None , None , - 1 ), 1 ),
482+ (0 , slice (None , None , - 1 )),
483+ (slice (None , 2 , None ), 1 ),
484+ (slice (None , None , 1 ), slice (None , None , 2 )),
485+ (0 , slice (1 , 2 , 1 )),
486+ (slice (0 , 1 , 1 ), slice (1 , 2 , 1 ))])
487+ def test_array_access_ss (self , outdx , inpdx ):
488+ sys1 = StateSpace (
489+ [[1. , 2. ], [3. , 4. ]],
490+ [[5. , 6. ], [7. , 8. ]],
491+ [[9. , 10. ], [11. , 12. ]],
492+ [[13. , 14. ], [15. , 16. ]], 1 ,
493+ inputs = ['u0' , 'u1' ], outputs = ['y0' , 'y1' ])
474494
475- sys1_01 = sys1 [0 , 1 ]
495+ sys1_01 = sys1 [outdx , inpdx ]
496+
497+ # Convert int to slice to ensure that numpy doesn't drop the dimension
498+ if isinstance (outdx , int ): outdx = slice (outdx , outdx + 1 , 1 )
499+ if isinstance (inpdx , int ): inpdx = slice (inpdx , inpdx + 1 , 1 )
500+
476501 np .testing .assert_array_almost_equal (sys1_01 .A ,
477502 sys1 .A )
478503 np .testing .assert_array_almost_equal (sys1_01 .B ,
479- sys1 .B [:, 1 : 2 ])
504+ sys1 .B [:, inpdx ])
480505 np .testing .assert_array_almost_equal (sys1_01 .C ,
481- sys1 .C [0 : 1 , :])
506+ sys1 .C [outdx , :])
482507 np .testing .assert_array_almost_equal (sys1_01 .D ,
483- sys1 .D [0 , 1 ])
508+ sys1 .D [outdx , inpdx ])
484509
485510 assert sys1 .dt == sys1_01 .dt
486- assert sys1_01 .input_labels == [ 'u1' ]
487- assert sys1_01 .output_labels == [ 'y0' ]
511+ assert sys1_01 .input_labels == sys1 . input_labels [ inpdx ]
512+ assert sys1_01 .output_labels == sys1 . output_labels [ outdx ]
488513 assert sys1_01 .name == sys1 .name + "$indexed"
489514
490515 def test_dc_gain_cont (self ):
0 commit comments