1212import warnings
1313from math import pi
1414
15+ import matplotlib as mpl
1516import matplotlib .pyplot as plt
1617import numpy as np
1718import pytest
@@ -138,29 +139,39 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
138139
139140 # Use callable form, with parameters (if not correct, will get /0 error)
140141 ct .phase_plane_plot (
141- invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )})
142+ invpend_ode , [- 5 , 5 , - 2 , 2 ], params = {'args' : (1 , 1 , 0.2 , 1 )},
143+ plot_streamlines = True )
142144
143145 # Linear I/O system
144146 ct .phase_plane_plot (
145- ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ))
147+ ct .ss ([[0 , 1 ], [- 1 , - 1 ]], [[0 ], [1 ]], [[1 , 0 ]], 0 ),
148+ plot_streamlines = True )
146149
147150
148151@pytest .mark .usefixtures ('mplcleanup' )
149152def test_phaseplane_errors ():
150153 with pytest .raises (ValueError , match = "invalid grid specification" ):
151- ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridspec = 'bad' )
154+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridspec = 'bad' ,
155+ plot_streamlines = True )
152156
153157 with pytest .raises (ValueError , match = "unknown grid type" ):
154- ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' )
158+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), gridtype = 'bad' ,
159+ plot_streamlines = True )
155160
156161 with pytest .raises (ValueError , match = "system must be planar" ):
157- ct .phase_plane_plot (ct .rss (3 , 1 , 1 ))
162+ ct .phase_plane_plot (ct .rss (3 , 1 , 1 ),
163+ plot_streamlines = True )
158164
159165 with pytest .raises (ValueError , match = "params must be dict with key" ):
160166 def invpend_ode (t , x , m = 0 , l = 0 , b = 0 , g = 0 ):
161167 return (x [1 ], - b / m * x [1 ] + (g * l / m ) * np .sin (x [0 ]))
162168 ct .phase_plane_plot (
163- invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )})
169+ invpend_ode , [- 5 , 5 , 2 , 2 ], params = {'stuff' : (1 , 1 , 0.2 , 1 )},
170+ plot_streamlines = True )
171+
172+ with pytest .raises (ValueError , match = "gridtype must be 'meshgrid' when using streamplot" ):
173+ ct .phase_plane_plot (ct .rss (2 , 1 , 1 ), plot_streamlines = False ,
174+ plot_streamplot = True , gridtype = 'boxgrid' )
164175
165176 # Warning messages for invalid solutions: nonlinear spring mass system
166177 sys = ct .nlsys (
@@ -171,14 +182,87 @@ def invpend_ode(t, x, m=0, l=0, b=0, g=0):
171182 UserWarning , match = r"initial_state=\[.*\], solve_ivp failed" ):
172183 ct .phase_plane_plot (
173184 sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
174- plot_separatrices = False )
185+ plot_separatrices = False , plot_streamlines = True )
175186
176187 # Turn warnings off
177188 with warnings .catch_warnings ():
178189 warnings .simplefilter ("error" )
179190 ct .phase_plane_plot (
180191 sys , [- 12 , 12 , - 10 , 10 ], 15 , gridspec = [2 , 9 ],
181- plot_separatrices = False , suppress_warnings = True )
192+ plot_streamlines = True , plot_separatrices = False ,
193+ suppress_warnings = True )
194+
195+ @pytest .mark .usefixtures ('mplcleanup' )
196+ def test_phase_plot_zorder ():
197+ # some of these tests are a bit akward since the streamlines and separatrices
198+ # are stored in the same list, so we separate them by color
199+ key_color = "tab:blue" # must not be 'k', 'r', 'b' since they are used by separatrices
200+
201+ def get_zorders (cplt ):
202+ max_zorder = lambda items : max ([line .get_zorder () for line in items ])
203+ assert isinstance (cplt .lines [0 ], list )
204+ streamline_lines = [line for line in cplt .lines [0 ] if line .get_color () == key_color ]
205+ separatrice_lines = [line for line in cplt .lines [0 ] if line .get_color () != key_color ]
206+ streamlines = max_zorder (streamline_lines ) if streamline_lines else None
207+ separatrices = max_zorder (separatrice_lines ) if separatrice_lines else None
208+ assert cplt .lines [1 ] == None or isinstance (cplt .lines [1 ], mpl .quiver .Quiver )
209+ quiver = cplt .lines [1 ].get_zorder () if cplt .lines [1 ] else None
210+ assert cplt .lines [2 ] == None or isinstance (cplt .lines [2 ], list )
211+ equilpoints = max_zorder (cplt .lines [2 ]) if cplt .lines [2 ] else None
212+ assert cplt .lines [3 ] == None or isinstance (cplt .lines [3 ], mpl .streamplot .StreamplotSet )
213+ streamplot = max (cplt .lines [3 ].lines .get_zorder (), cplt .lines [3 ].arrows .get_zorder ()) if cplt .lines [3 ] else None
214+ return streamlines , quiver , streamplot , separatrices , equilpoints
215+
216+ def assert_orders (streamlines , quiver , streamplot , separatrices , equilpoints ):
217+ print (streamlines , quiver , streamplot , separatrices , equilpoints )
218+ if streamlines is not None :
219+ assert streamlines < separatrices < equilpoints
220+ if quiver is not None :
221+ assert quiver < separatrices < equilpoints
222+ if streamplot is not None :
223+ assert streamplot < separatrices < equilpoints
224+
225+ def sys (t , x ):
226+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
227+
228+ # ensure correct zordering for all three flow types
229+ res_streamlines = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color ))
230+ assert_orders (* get_zorders (res_streamlines ))
231+ res_vectorfield = ct .phase_plane_plot (sys , plot_vectorfield = True )
232+ assert_orders (* get_zorders (res_vectorfield ))
233+ res_streamplot = ct .phase_plane_plot (sys , plot_streamplot = True )
234+ assert_orders (* get_zorders (res_streamplot ))
235+
236+ # ensure that zorder can still be overwritten
237+ res_reversed = ct .phase_plane_plot (sys , plot_streamlines = dict (color = key_color , zorder = 50 ), plot_vectorfield = dict (zorder = 40 ),
238+ plot_streamplot = dict (zorder = 30 ), plot_separatrices = dict (zorder = 20 ), plot_equilpoints = dict (zorder = 10 ))
239+ streamlines , quiver , streamplot , separatrices , equilpoints = get_zorders (res_reversed )
240+ assert streamlines > quiver > streamplot > separatrices > equilpoints
241+
242+
243+ @pytest .mark .usefixtures ('mplcleanup' )
244+ def test_stream_plot_magnitude ():
245+ def sys (t , x ):
246+ return np .array ([4 * x [1 ], - np .sin (4 * x [0 ])])
247+
248+ # plt context with linewidth
249+ with plt .rc_context ({'lines.linewidth' : 4 }):
250+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_linewidth = True ))
251+ linewidths = res .lines [3 ].lines .get_linewidths ()
252+ # linewidths are scaled to be between 0.25 and 2 times default linewidth
253+ # but the extremes may not exist if there is no line at that point
254+ assert min (linewidths ) < 2 and max (linewidths ) > 7
255+
256+ # make sure changing the colormap works
257+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'viridis' ))
258+ assert res .lines [3 ].lines .get_cmap ().name == 'viridis'
259+ res = ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , cmap = 'turbo' ))
260+ assert res .lines [3 ].lines .get_cmap ().name == 'turbo'
261+
262+ # make sure changing the norm at least doesn't throw an error
263+ ct .phase_plane_plot (sys , plot_streamplot = dict (vary_color = True , norm = mpl .colors .LogNorm ()))
264+
265+
182266
183267
184268@pytest .mark .usefixtures ('mplcleanup' )
@@ -190,7 +274,7 @@ def test_basic_phase_plots(savefigs=False):
190274 plt .figure ()
191275 axis_limits = [- 1 , 1 , - 1 , 1 ]
192276 T = 8
193- ct .phase_plane_plot (sys , axis_limits , T )
277+ ct .phase_plane_plot (sys , axis_limits , T , plot_streamlines = True )
194278 if savefigs :
195279 plt .savefig ('phaseplot-dampedosc-default.png' )
196280
@@ -203,7 +287,7 @@ def invpend_update(t, x, u, params):
203287 ct .phase_plane_plot (
204288 invpend , [- 2 * pi , 2 * pi , - 2 , 2 ], 5 ,
205289 gridtype = 'meshgrid' , gridspec = [5 , 8 ], arrows = 3 ,
206- plot_separatrices = {'gridspec' : [12 , 9 ]},
290+ plot_separatrices = {'gridspec' : [12 , 9 ]}, plot_streamlines = True ,
207291 params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 })
208292 plt .xlabel (r"$\theta$ [rad]" )
209293 plt .ylabel (r"$\dot\theta$ [rad/sec]" )
@@ -218,7 +302,8 @@ def oscillator_update(t, x, u, params):
218302 oscillator_update , states = 2 , inputs = 0 , name = 'nonlinear oscillator' )
219303
220304 plt .figure ()
221- ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 )
305+ ct .phase_plane_plot (oscillator , [- 1.5 , 1.5 , - 1.5 , 1.5 ], 0.9 ,
306+ plot_streamlines = True )
222307 pp .streamlines (
223308 oscillator , np .array ([[0 , 0 ]]), 1.5 ,
224309 gridtype = 'circlegrid' , gridspec = [0.5 , 6 ], dir = 'both' )
@@ -228,6 +313,18 @@ def oscillator_update(t, x, u, params):
228313 if savefigs :
229314 plt .savefig ('phaseplot-oscillator-helpers.png' )
230315
316+ plt .figure ()
317+ ct .phase_plane_plot (
318+ invpend , [- 2 * pi , 2 * pi , - 2 , 2 ],
319+ plot_streamplot = dict (vary_color = True , vary_density = True ),
320+ gridspec = [60 , 20 ], params = {'m' : 1 , 'l' : 1 , 'b' : 0.2 , 'g' : 1 }
321+ )
322+ plt .xlabel (r"$\theta$ [rad]" )
323+ plt .ylabel (r"$\dot\theta$ [rad/sec]" )
324+
325+ if savefigs :
326+ plt .savefig ('phaseplot-invpend-streamplot.png' )
327+
231328
232329if __name__ == "__main__" :
233330 #
0 commit comments