darshjme/agent-checkpoint

agent-checkpoint

Zero-dependency PyTorch-style state checkpointing for AI agents — save, restore, rollback.


Inspired by AgentScope's StateModule and PyTorch's state_dict() / load_state_dict() pattern, agent-checkpoint lets you serialize and deserialize the complete state of any AI agent — memory, config, counters, tool state — to/from a plain dict, file, or in-memory store.

Zero runtime dependencies. Standard library only. Works on Python 3.9+.


Features

  • state_dict() / load_state_dict() on any class (mixin-based)
  • File backend — JSON (human-readable) or pickle (arbitrary objects)
  • In-memory backend — fast, in-process dict store
  • Rollback — restore agent state on exception
  • Versioning — schema version tracking + migration hooks
  • Auto-checkpoint decorator: @checkpoint(manager, key="x", every=5)
  • Diff — compare agent states or two checkpoints
  • History — retrieve last N checkpoints per key
  • Clone — deep-copy agent via state round-trip
  • ✅ Zero dependencies

Installation

pip install agent-checkpoint

Or from source:

git clone https://github.com/darshjme/agent-checkpoint
cd agent-checkpoint
pip install -e .

Quick Start

from agent_checkpoint import Checkpointable, CheckpointManager, checkpoint

# 1. Make any class checkpointable
class MyAgent(Checkpointable):
    def __init__(self):
        self.memory = []
        self.config = {}
        self.call_count = 0

    def reply(self, msg):
        self.memory.append(msg)
        self.call_count += 1
        return f"Reply #{self.call_count}"

agent = MyAgent()
agent.reply("hello")
agent.reply("world")

# 2. Save state
state = agent.state_dict()
# → {"memory": ["hello", "world"], "config": {}, "call_count": 2, "__version__": "1.0", "__class__": "MyAgent"}

# 3. Restore to new instance
new_agent = MyAgent()
new_agent.load_state_dict(state)
assert new_agent.call_count == 2
assert new_agent.memory == ["hello", "world"]
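
The save/restore pattern above can be sketched without the library. This is a minimal illustration, not agent-checkpoint's actual implementation: the `SketchCheckpointable` and `EchoAgent` classes are hypothetical, and the behaviour (deep-copying public instance attributes and adding `__version__` and `__class__` keys) is an assumption modelled on the output shown in the Quick Start.

```python
import copy


class SketchCheckpointable:
    """Hypothetical stand-in for a state_dict mixin (not the library's code)."""

    def state_dict(self):
        # Assumption: capture every instance attribute that does not start with "_".
        state = {
            name: copy.deepcopy(value)
            for name, value in vars(self).items()
            if not name.startswith("_")
        }
        # Metadata keys mirror the ones shown in the Quick Start output.
        state["__version__"] = "1.0"
        state["__class__"] = type(self).__name__
        return state

    def load_state_dict(self, state):
        # Skip metadata keys; copy values so the caller's dict stays independent.
        for name, value in state.items():
            if name.startswith("__") and name.endswith("__"):
                continue
            setattr(self, name, copy.deepcopy(value))


class EchoAgent(SketchCheckpointable):
    def __init__(self):
        self.memory = []
        self.call_count = 0

    def reply(self, msg):
        self.memory.append(msg)
        self.call_count += 1
        return f"Reply #{self.call_count}"


agent = EchoAgent()
agent.reply("hello")
snapshot = agent.state_dict()

restored = EchoAgent()
restored.load_state_dict(snapshot)
```

Deep-copying on both save and load keeps the snapshot independent of later mutations to the live agent, which is what makes a snapshot usable for rollback.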

CheckpointManager

# In-memory backend (default)
manager = CheckpointManager(backend="memory")

# File backend — JSON (human-readable, default)
manager = CheckpointManager(backend="file", path="/tmp/checkpoints")

# File backend — pickle (supports arbitrary Python objects)
manager = CheckpointManager(backend="file", path="/tmp/checkpoints", fmt="pickle")

# Save
manager.save("session_123", agent)

# Load
manager.load("session_123", new_agent)

# List all checkpoint keys
print(manager.list_keys())

# Get history (last N checkpoints for a key)
history = manager.history("session_123", n=5)
# [{"id": "session_123_1709...", "ts": 1709...}, ...]

Rollback

# Method 1: manual save + restore
manager.save("before_risky_op", agent)
try:
    risky_operation(agent)
except Exception:
    manager.load("before_risky_op", agent)   # instant rollback ✅

# Method 2: context manager (auto-saves before block, restores on exception)
with manager.rollback_context("before_risky_op", agent):
    risky_operation(agent)
# If risky_operation raises, the agent is restored automatically and the exception still propagates.

Concrete example:

from agent_checkpoint import Checkpointable, CheckpointManager

class ConversationAgent(Checkpointable):
    def __init__(self):
        self.history = []
        self.tokens_used = 0

    def send(self, msg, tokens=10):
        self.history.append(msg)
        self.tokens_used += tokens

agent = ConversationAgent()
agent.send("Hello", tokens=5)
agent.send("How are you?", tokens=8)

mgr = CheckpointManager(backend="memory")
mgr.save("checkpoint_1", agent)       # snapshot before risky call

try:
    # Simulate an operation that corrupts state
    agent.send("CORRUPTED", tokens=9999)
    raise RuntimeError("API call failed mid-way")
except RuntimeError:
    mgr.load("checkpoint_1", agent)   # rollback → state restored

print(agent.tokens_used)  # → 13 (not 10012)
print(agent.history)      # → ["Hello", "How are you?"]
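
For readers curious how a restore-on-exception context manager can work, here is a rough standard-library sketch. The `rollback_on_error` helper and the `Counter` class are hypothetical and are not part of agent-checkpoint. The sketch assumes the object keeps its state in ordinary instance attributes.

```python
import copy
from contextlib import contextmanager


@contextmanager
def rollback_on_error(obj):
    """Hypothetical helper: snapshot obj's attributes, restore them if the block raises."""
    snapshot = copy.deepcopy(vars(obj))
    try:
        yield obj
    except Exception:
        # Restore the pre-block attributes, then re-raise so the caller still sees the error.
        vars(obj).clear()
        vars(obj).update(snapshot)
        raise


class Counter:
    def __init__(self):
        self.total = 0
        self.events = []


counter = Counter()
counter.total = 13
counter.events.append("ok")

try:
    with rollback_on_error(counter):
        counter.total += 9999
        counter.events.append("CORRUPTED")
        raise RuntimeError("API call failed mid-way")
except RuntimeError as exc:
    caught = str(exc)

print(counter.total)   # 13
print(counter.events)  # ['ok']
```

Re-raising after the restore matches the behaviour described for `rollback_context`: state is rolled back, but the caller still has to handle the error.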

Auto-Checkpoint Decorator

manager = CheckpointManager(backend="file", path="/tmp/ckpts")

@checkpoint(manager, key="my_agent", every=5)   # checkpoint every 5 calls
def process(agent, msg):
    return agent.reply(msg)

for i in range(20):
    process(agent, f"message {i}")
# Checkpoints saved at call 5, 10, 15, 20
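
The every-N behaviour can be illustrated with a small closure-based decorator. This is only a sketch: `save_every` and its `save_fn` callback are hypothetical names, and whether the real `@checkpoint` saves before or after the wrapped call is not stated in this README. The sketch assumes it saves after a successful call.

```python
import functools


def save_every(save_fn, every=5):
    """Hypothetical decorator factory: call save_fn(call_number) on every `every`-th call."""
    def decorator(func):
        calls = 0

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal calls
            result = func(*args, **kwargs)
            calls += 1
            # Save only after the wrapped call succeeded and the counter hits a multiple.
            if calls % every == 0:
                save_fn(calls)
            return result

        return wrapper
    return decorator


saved_at = []


@save_every(saved_at.append, every=5)
def process(msg):
    return msg.upper()


for i in range(20):
    process(f"message {i}")

print(saved_at)  # [5, 10, 15, 20]
```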

Diff

# Compare current agent state vs. a saved checkpoint
diff = manager.diff("my_key", agent=agent)
print(diff)
# {"added": {}, "removed": {}, "changed": {"call_count": (5, 10)}}

# Compare two stored checkpoints
diff = manager.diff("my_key", checkpoint_id_a=id1, checkpoint_id_b=id2)
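
The `added` / `removed` / `changed` structure shown above can be produced by a key-by-key comparison of two state dicts. The sketch below is an assumption about how such a diff might be computed, not the library's code. `diff_states` is a hypothetical helper that compares top-level keys only.

```python
def diff_states(old, new):
    """Hypothetical helper: compare two flat state dicts key by key."""
    added = {k: new[k] for k in new if k not in old}
    removed = {k: old[k] for k in old if k not in new}
    # For keys present on both sides, record (old value, new value) when they differ.
    changed = {
        k: (old[k], new[k])
        for k in old
        if k in new and old[k] != new[k]
    }
    return {"added": added, "removed": removed, "changed": changed}


before = {"memory": ["hello"], "call_count": 5}
after = {"memory": ["hello"], "call_count": 10, "config": {"temp": 0.2}}

result = diff_states(before, after)
print(result)
```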

Versioning & Migrations

from agent_checkpoint import Checkpointable, register_migration

class MyAgent(Checkpointable):
    def __init__(self):
        self.score = 0

    def _checkpoint_version(self):
        return "2.0"

# Migrate state saved by v1.0 → v2.0 schema
def migrate_v1_to_v2(state):
    state["score"] = state.pop("points", 0) * 10  # rename + scale
    return state

register_migration("MyAgent", from_version="1.0", fn=migrate_v1_to_v2)
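
To make the migration step concrete, the sketch below applies `migrate_v1_to_v2` to a v1.0 state dict through a tiny stand-in registry. The `MIGRATIONS` dict and the `register` and `upgrade` functions are hypothetical and only approximate what `register_migration` might do internally. In particular, the explicit `to_version` argument is an assumption that the library's signature above does not include.

```python
# Hypothetical registry keyed by (class name, version the state was saved with).
MIGRATIONS = {}


def register(class_name, from_version, to_version, fn):
    MIGRATIONS[(class_name, from_version)] = (to_version, fn)


def upgrade(state, target_version):
    """Apply registered migrations until the state reaches target_version."""
    state = dict(state)  # work on a shallow copy so the input is left untouched
    while state["__version__"] != target_version:
        key = (state["__class__"], state["__version__"])
        if key not in MIGRATIONS:
            raise ValueError(f"no migration registered for {key}")
        to_version, fn = MIGRATIONS[key]
        state = fn(state)
        state["__version__"] = to_version
    return state


def migrate_v1_to_v2(state):
    state["score"] = state.pop("points", 0) * 10  # rename + scale
    return state


register("MyAgent", "1.0", "2.0", migrate_v1_to_v2)

old_state = {"points": 3, "__version__": "1.0", "__class__": "MyAgent"}
new_state = upgrade(old_state, "2.0")
print(new_state)
```

Here `points: 3` becomes `score: 30` and the version tag moves to `2.0`, so a v2.0 agent can load a checkpoint written by the v1.0 schema.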

Customising Captured Fields

class MyAgent(Checkpointable):
    def __init__(self):
        self.memory = []           # ✅ captured
        self.call_count = 0        # ✅ captured
        self._cache = {}           # excluded by default (starts with _)
        self._runtime_state = None # excluded by default (starts with _)

    def _checkpoint_fields(self):
        # Explicitly pin which fields to capture
        return ["memory", "call_count"]

Clone

agent = MyAgent()
agent.reply("hello")

cloned = agent.clone()          # deep copy via state round-trip
cloned.reply("world")           # mutating clone doesn't affect original

assert agent.call_count == 1
assert cloned.call_count == 2

Supported Types

Type                           Serialization
None, bool, int, float, str    Native JSON
list, dict (str keys)          Native JSON
tuple                          JSON with __type__: tuple tag
set                            JSON with __type__: set tag
dict (non-str keys)            pickle + base64
Arbitrary Python objects       pickle + base64
Nested Checkpointable          Recursive state dict
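
The tagging scheme in the table can be sketched with the standard library alone. The exact wire format used by agent-checkpoint is not documented here, so the `items` and `data` keys, the `"pickle"` tag name, and the `encode` / `decode` helpers are assumptions made for illustration. Nested Checkpointable objects are left out of the sketch.

```python
import base64
import json
import pickle


def encode(value):
    """Hypothetical encoder: turn a Python value into JSON-safe data."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, list):
        return [encode(v) for v in value]
    if isinstance(value, tuple):
        return {"__type__": "tuple", "items": [encode(v) for v in value]}
    if isinstance(value, set):
        # Sets are unordered; items are emitted in iteration order.
        return {"__type__": "set", "items": [encode(v) for v in value]}
    if isinstance(value, dict) and all(isinstance(k, str) for k in value):
        return {k: encode(v) for k, v in value.items()}
    # Fallback for non-str-keyed dicts and arbitrary objects: pickle + base64.
    payload = base64.b64encode(pickle.dumps(value)).decode("ascii")
    return {"__type__": "pickle", "data": payload}


def decode(value):
    """Inverse of encode: rebuild tuples, sets and pickled values from their tags."""
    if isinstance(value, list):
        return [decode(v) for v in value]
    if isinstance(value, dict):
        tag = value.get("__type__")
        if tag == "tuple":
            return tuple(decode(v) for v in value["items"])
        if tag == "set":
            return {decode(v) for v in value["items"]}
        if tag == "pickle":
            return pickle.loads(base64.b64decode(value["data"]))
        return {k: decode(v) for k, v in value.items()}
    return value


state = {"pair": (1, 2), "tags": {"a", "b"}, "by_id": {1: "one"}, "n": None}
text = json.dumps(encode(state))
restored = decode(json.loads(text))
```

Two caveats apply to any scheme of this shape: a user dict that itself contains a `__type__` key would be misread by this sketch, and `pickle.loads` should only be run on checkpoint data from a trusted source.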

Architecture

agent_checkpoint/
├── core.py        Checkpointable mixin — state_dict / load_state_dict / clone / diff
├── backends.py    MemoryBackend + FileBackend (JSON & pickle)
├── manager.py     CheckpointManager — save / load / rollback / history / diff
└── decorator.py   @checkpoint — auto-checkpoint decorator

Running Tests

pip install pytest
pytest tests/ -v

License

MIT © 2026 Darshankumar Joshi
