@@ -103,7 +103,7 @@ def __enter__(self):
 
         loop = asyncio.get_event_loop()
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             if self.multi_stage_wake_up:
                 loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
                 log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
@@ -130,7 +130,7 @@ def __enter__(self):
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.multi_stage_wake_up:
             loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
             log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
 
@@ -141,11 +141,10 @@ def __enter__(self):
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.release_memory())
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
+        log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.release_memory())
+        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
 
         self.module.train()
 
@@ -188,7 +187,7 @@ async def update_weights(self, params):
         )
 
     async def release_memory(self):
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             await self.inference_engine.release_memory_occupation()
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
@@ -218,10 +217,9 @@ async def wake_up(self):
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     async def sleep(self):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            await self.release_memory()
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
+        log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
+        await self.release_memory()
+        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
 
         self.module.train()
 
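For context, the enter/exit flow these hunks leave behind reduces to roughly the sketch below. It is a minimal, self-contained stand-in with stub engine and rank objects, not the actual verl/SGLang classes; only the method names (resume_memory_occupation, release_memory_occupation, release_memory) and the removal of the free_cache_engine gate come from the diff above.

# Minimal, runnable sketch of the post-change control flow: GPU memory
# resume/release is no longer gated on rollout_config.free_cache_engine.
# StubEngine and StubShardingManager are illustrative stand-ins, not real classes.
import asyncio


class StubEngine:
    async def resume_memory_occupation(self, tags=None):
        print(f"resume_memory_occupation(tags={tags})")

    async def release_memory_occupation(self):
        print("release_memory_occupation()")


class StubShardingManager:
    def __init__(self, multi_stage_wake_up=True, local_rank=0):
        self.inference_engine = StubEngine()
        self.multi_stage_wake_up = multi_stage_wake_up
        # Stand-in for self.device_mesh["infer_tp"].get_local_rank().
        self.local_rank = local_rank
        self.loop = asyncio.new_event_loop()

    def __enter__(self):
        # Resume runs unconditionally on TP rank 0: weights first, then the
        # KV cache after the (omitted) weight sync when multi-stage wake-up is on.
        if self.local_rank == 0:
            if self.multi_stage_wake_up:
                self.loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
            else:
                self.loop.run_until_complete(self.inference_engine.resume_memory_occupation())
        # ... weight sync and cache cleanup would go here ...
        if self.local_rank == 0 and self.multi_stage_wake_up:
            self.loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Release is likewise unconditional now.
        self.loop.run_until_complete(self.release_memory())

    async def release_memory(self):
        if self.local_rank == 0:
            await self.inference_engine.release_memory_occupation()


if __name__ == "__main__":
    with StubShardingManager():
        pass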