
Commit 503e19f
fix

1 parent f283346 commit 503e19f

1 file changed: verl/workers/sharding_manager/fsdp_sglang.py
Lines changed: 10 additions & 12 deletions
@@ -103,7 +103,7 @@ def __enter__(self):
 
         loop = asyncio.get_event_loop()
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
             if self.multi_stage_wake_up:
                 loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
                 log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
@@ -130,7 +130,7 @@ def __enter__(self):
             get_torch_device().empty_cache()
             log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.multi_stage_wake_up:
             loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
             log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)

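The two hunks above keep the two-stage wake-up ordering: on infer-TP local rank 0, the engine's weight buffers are resumed before the FSDP state dict is synced in, and the KV cache is only resumed afterwards, once the staging copies have been freed. Below is a minimal, self-contained sketch of that ordering; FakeEngine and enter() are hypothetical stand-ins for the SGLang engine and the sharding manager's __enter__, and asyncio.new_event_loop() is used so the snippet runs standalone (the diff reuses the ambient loop via asyncio.get_event_loop()):

    import asyncio


    class FakeEngine:
        """Hypothetical stand-in for the async SGLang inference engine."""

        async def resume_memory_occupation(self, tags):
            # A real engine would re-pin the tagged GPU allocations here.
            print(f"resumed {tags}")


    def enter(engine, is_tp_rank_zero=True, multi_stage_wake_up=True):
        # __enter__ is synchronous, so async engine calls are driven to
        # completion on an event loop.
        loop = asyncio.new_event_loop()
        try:
            if is_tp_rank_zero and multi_stage_wake_up:
                # Stage 1: weights only, before the state dict is copied in.
                loop.run_until_complete(
                    engine.resume_memory_occupation(tags=["weights"])
                )
            # ... sync weights into the engine, free the state dict ...
            if is_tp_rank_zero and multi_stage_wake_up:
                # Stage 2: KV cache, once the staging copies are gone.
                loop.run_until_complete(
                    engine.resume_memory_occupation(tags=["kv_cache"])
                )
        finally:
            loop.close()


    enter(FakeEngine())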
@@ -141,11 +141,10 @@ def __enter__(self):
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     def __exit__(self, exc_type, exc_value, traceback):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            loop = asyncio.get_event_loop()
-            loop.run_until_complete(self.release_memory())
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
+        log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(self.release_memory())
+        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
 
         self.module.train()
@@ -188,7 +187,7 @@ async def update_weights(self, params):
         )
 
     async def release_memory(self):
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+        if self.device_mesh["infer_tp"].get_local_rank() == 0:
            await self.inference_engine.release_memory_occupation()
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
@@ -218,10 +217,9 @@ async def wake_up(self):
 
     @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
     async def sleep(self):
-        if self.rollout_config.free_cache_engine:
-            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
-            await self.release_memory()
-            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
+        log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
+        await self.release_memory()
+        log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)
 
         self.module.train()

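Across __exit__, release_memory, and sleep, the fix drops the rollout_config.free_cache_engine checks, so offload on exit is now unconditional; the only remaining guard is that the actual engine call happens on infer-TP local rank 0, inside release_memory. A rough sketch of the resulting sleep path, again with hypothetical stand-ins (FakeEngine, SleeperSketch) rather than the real verl classes:

    import asyncio


    class FakeEngine:
        """Hypothetical stand-in for the async SGLang inference engine."""

        async def release_memory_occupation(self):
            # A real engine would unpin weights and KV cache here.
            print("released engine memory")


    class SleeperSketch:
        """Post-fix sleep path: release is unconditional, and only the
        TP-rank-0 guard remains inside release_memory()."""

        def __init__(self, engine, tp_local_rank):
            self.inference_engine = engine
            self.tp_local_rank = tp_local_rank

        async def release_memory(self):
            # Only one rank per inference TP group talks to the engine.
            if self.tp_local_rank == 0:
                await self.inference_engine.release_memory_occupation()

        async def sleep(self):
            # No free_cache_engine check anymore: always offload on exit.
            await self.release_memory()


    asyncio.run(SleeperSketch(FakeEngine(), tp_local_rank=0).sleep())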