Skip to content

Commit 61b389a

Browse files
committed
BUGFIX: Add last step into the logs. Render only when visual keys are present
1 parent 78e2ae3 commit 61b389a

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

robohive/envs/env_base.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -781,9 +781,20 @@ def examine_policy_new(self,
781781
t = t+1
782782
ep_rwd += rwd
783783

784+
# record last step and finalize the rollout --------------------------------
785+
act = np.nan*np.ones(self.action_space.shape)
786+
datum_dict = dict(
787+
time=t,
788+
observations=obs,
789+
actions=act.copy(),
790+
rewards=rwd,
791+
env_infos=env_info,
792+
done=done,
793+
)
794+
trace.append_datums(group_key=group_key, dataset_key_val=datum_dict)
784795
prompt(f"Episode {ep}:> Finished in {(timer.time()-ep_t0):0.4} sec. Total rewards {ep_rwd}", type=Prompt.INFO)
785796

786-
# save offscreen buffers as video
797+
# save offscreen buffers as video --------------------------------
787798
if render =='offscreen':
788799
file_name = output_dir + filename + str(ep) + ".mp4"
789800
# check if the platform is OS -- make it compatible with quicktime
@@ -797,15 +808,12 @@ def examine_policy_new(self,
797808
prompt("Total time taken = %f"% (timer.time()-exp_t0), type=Prompt.INFO)
798809
# print(trace)
799810
# trace.save(self.id+"_trace.pickle", verify_length=True)
800-
trace.save(output_dir+self.id+"_trace.h5", verify_length=True)
801-
# print(trace)
802-
render_keys = ['env_infos/visual_dict/rgb:top_cam:256x256:2d',
803-
'env_infos/visual_dict/rgb:left_cam:256x256:2d',
804-
'env_infos/visual_dict/rgb:right_cam:256x256:2d',
805-
'env_infos/visual_dict/rgb:Franka_wrist_cam:256x256:2d'
806-
]
807-
trace.close()
808-
trace.render(output_dir=output_dir, output_format="mp4", groups="Trial0", datasets=render_keys, input_fps=1/self.dt)
811+
# trace.save(output_dir+self.id+"_trace.h5", verify_length=True)
812+
print(trace)
813+
if self.visual_keys:
814+
trace.close()
815+
render_keys = ['env_infos/visual_dict/'+ key for key in self.visual_keys]
816+
trace.render(output_dir=output_dir, output_format="mp4", groups=":", datasets=render_keys, input_fps=1/self.dt)
809817

810818
# Does this belong here? Rendering should be a post processing step once the logs has been saved.
811819
# Note that the saved logs are read as pickle/h5, we can't call trace.render on them.

0 commit comments

Comments
 (0)