# agent-checkpoint

Zero-dependency PyTorch-style state checkpointing for AI agents — save, restore, rollback.

Inspired by AgentScope's `StateModule` and PyTorch's `state_dict()` / `load_state_dict()` pattern, **agent-checkpoint** lets you serialize and deserialize the complete state of any AI agent — memory, config, counters, tool state — to and from a plain dict, a file, or an in-memory store.

Zero runtime dependencies: standard library only. Works on Python 3.9+.
## Features

- ✅ `state_dict()` / `load_state_dict()` on any class (mixin-based)
- ✅ File backend — JSON (human-readable) or pickle (arbitrary objects)
- ✅ In-memory backend — fast, in-process dict store
- ✅ Rollback — restore agent state on exception
- ✅ Versioning — schema version tracking + migration hooks
- ✅ Auto-checkpoint decorator — `@checkpoint(manager, key="x", every=5)`
- ✅ Diff — compare agent states or two checkpoints
- ✅ History — retrieve the last N checkpoints per key
- ✅ Clone — deep-copy an agent via a state round-trip
- ✅ Zero dependencies
## Installation

```bash
pip install agent-checkpoint
```

Or from source:

```bash
git clone https://github.com/darshjme/agent-checkpoint
cd agent-checkpoint
pip install -e .
```

## Quick start

```python
from agent_checkpoint import Checkpointable, CheckpointManager, checkpoint

# 1. Make any class checkpointable
class MyAgent(Checkpointable):
    def __init__(self):
        self.memory = []
        self.config = {}
        self.call_count = 0

    def reply(self, msg):
        self.memory.append(msg)
        self.call_count += 1
        return f"Reply #{self.call_count}"

agent = MyAgent()
agent.reply("hello")
agent.reply("world")

# 2. Save state
state = agent.state_dict()
# → {"memory": ["hello", "world"], "config": {}, "call_count": 2, "__version__": "1.0", "__class__": "MyAgent"}

# 3. Restore to a new instance
new_agent = MyAgent()
new_agent.load_state_dict(state)
assert new_agent.call_count == 2
assert new_agent.memory == ["hello", "world"]
```

## Storage backends

In-memory backend (default):
```python
manager = CheckpointManager(backend="memory")
```

File backend — JSON (human-readable; the default format):

```python
manager = CheckpointManager(backend="file", path="/tmp/checkpoints")
```

File backend — pickle (supports arbitrary Python objects):

```python
manager = CheckpointManager(backend="file", path="/tmp/checkpoints", fmt="pickle")
```

> ⚠️ Unpickling can execute arbitrary code. Only load pickle checkpoints from sources you trust.

Basic operations:

```python
# Save
manager.save("session_123", agent)

# Load
manager.load("session_123", new_agent)

# List all checkpoint keys
print(manager.list_keys())

# Get history (last N checkpoints for a key)
history = manager.history("session_123", n=5)
# [{"id": "session_123_1709...", "ts": 1709...}, ...]
```

## Rollback

**Method 1 — manual save + restore:**
```python
manager.save("before_risky_op", agent)
try:
    risky_operation(agent)
except Exception:
    manager.load("before_risky_op", agent)  # instant rollback ✅
```

**Method 2 — context manager** (auto-saves before the block, restores on exception):

```python
with manager.rollback_context("before_risky_op", agent):
    risky_operation(agent)
# If risky_operation raises → the agent is automatically restored.
# The exception is not suppressed; it still propagates to the caller.
```

Concrete example:
```python
from agent_checkpoint import Checkpointable, CheckpointManager

class ConversationAgent(Checkpointable):
    def __init__(self):
        self.history = []
        self.tokens_used = 0

    def send(self, msg, tokens=10):
        self.history.append(msg)
        self.tokens_used += tokens

agent = ConversationAgent()
agent.send("Hello", tokens=5)
agent.send("How are you?", tokens=8)

mgr = CheckpointManager(backend="memory")
mgr.save("checkpoint_1", agent)  # snapshot before risky call

try:
    # Simulate an operation that corrupts state
    agent.send("CORRUPTED", tokens=9999)
    raise RuntimeError("API call failed mid-way")
except RuntimeError:
    mgr.load("checkpoint_1", agent)  # rollback → state restored

print(agent.tokens_used)  # → 13 (not 10012)
print(agent.history)      # → ["Hello", "How are you?"]
```

## Auto-checkpoint decorator

```python
manager = CheckpointManager(backend="file", path="/tmp/ckpts")
agent = MyAgent()  # the MyAgent from the Quick start — process() calls agent.reply()

@checkpoint(manager, key="my_agent", every=5)  # checkpoint every 5 calls
def process(agent, msg):
    return agent.reply(msg)

for i in range(20):
    process(agent, f"message {i}")
# Checkpoints saved at calls 5, 10, 15, and 20
```

## Diff

Compare the current agent state against a saved checkpoint:
```python
diff = manager.diff("my_key", agent=agent)
print(diff)
# {"added": {}, "removed": {}, "changed": {"call_count": (5, 10)}}
```

Compare two stored checkpoints. Checkpoint IDs are, for example, the `"id"` values returned by `manager.history()`:

```python
diff = manager.diff("my_key", checkpoint_id_a=id1, checkpoint_id_b=id2)
```

## Versioning and migrations

```python
from agent_checkpoint import Checkpointable, register_migration

class MyAgent(Checkpointable):
    def __init__(self):
        self.score = 0

    def _checkpoint_version(self):
        return "2.0"

# Migrate state saved by the v1.0 schema → v2.0 schema
def migrate_v1_to_v2(state):
    state["score"] = state.pop("points", 0) * 10  # rename + scale
    return state

register_migration("MyAgent", from_version="1.0", fn=migrate_v1_to_v2)
```

## Controlling which fields are captured

```python
class MyAgent(Checkpointable):
    def __init__(self):
        self.memory = []             # ✅ captured
        self.call_count = 0          # ✅ captured
        self._cache = {}             # ✅ excluded by default (starts with _)
        self._runtime_state = None   # ✅ excluded by default

    def _checkpoint_fields(self):
        # Explicitly pin which fields to capture
        return ["memory", "call_count"]
```

## Clone

```python
agent = MyAgent()  # the MyAgent from the Quick start, which defines reply()
agent.reply("hello")
cloned = agent.clone()   # deep copy via state round-trip
cloned.reply("world")    # mutating the clone doesn't affect the original
assert agent.call_count == 1
assert cloned.call_count == 2
```

## Supported types

| Type | Serialization |
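|---|---|
| `None`, `bool`, `int`, `float`, `str` | Native JSON |
| `list`, `dict` (str keys) | Native JSON |
| `tuple` | JSON with `__type__: tuple` tag |
| `set` | JSON with `__type__: set` tag |
| `dict` (non-str keys) | pickle + base64 |
| Arbitrary Python objects | pickle + base64 |
| Nested `Checkpointable` | Recursive state dict |

> ⚠️ Types marked *pickle + base64* are stored as pickled data. Unpickling untrusted data can execute arbitrary code, so only load checkpoints from trusted sources.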
## Project structure

```
agent_checkpoint/
├── core.py        Checkpointable mixin — state_dict / load_state_dict / clone / diff
├── backends.py    MemoryBackend + FileBackend (JSON & pickle)
├── manager.py     CheckpointManager — save / load / rollback / history / diff
└── decorator.py   @checkpoint — auto-checkpoint decorator
```
## Running tests

```bash
pip install pytest
pytest tests/ -v
```

## License

MIT © 2026 Darshankumar Joshi