Skip to content

Commit c9372d9

Browse files
declare_contract cheatcode (#354)
* Added tests * Added cheatcode * Lint
1 parent 7da25fc commit c9372d9

5 files changed

Lines changed: 158 additions & 0 deletions

File tree

protostar/commands/test/test_execution_environment.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from starkware.starknet.services.api.contract_class import ContractClass
99
from starkware.starknet.testing.contract import StarknetContract
1010
from starkware.starkware_utils.error_handling import StarkException
11+
from starkware.starknet.testing.contract import DeclaredClass
1112

1213
from protostar.commands.test.cheatcodes import (
1314
Cheatcode,
@@ -21,6 +22,8 @@
2122
)
2223
from protostar.commands.test.starkware.forkable_starknet import ForkableStarknet
2324
from protostar.commands.test.test_context import TestContext
25+
26+
2427
from protostar.commands.test.test_environment_exceptions import (
2528
CheatcodeException,
2629
ExpectedEventMissingException,
@@ -44,6 +47,15 @@ def contract_address(self):
4447
return self._starknet_contract.contract_address
4548

4649

50+
class ProtostarDeclaredClass:
51+
def __init__(self, declared_class: DeclaredClass):
52+
self._declared_class = declared_class
53+
54+
@property
55+
def class_hash(self):
56+
return self._declared_class.class_hash
57+
58+
4759
class TestExecutionEnvironment:
4860
def __init__(
4961
self,
@@ -98,6 +110,17 @@ def deploy_in_env(
98110
)
99111
return contract
100112

113+
def declare_in_env(self, contract_path: str):
114+
contract = ProtostarDeclaredClass(
115+
asyncio.run(
116+
self.starknet.declare(
117+
source=contract_path,
118+
cairo_path=self._include_paths,
119+
)
120+
)
121+
)
122+
return contract
123+
101124
async def invoke_setup_hook(self, fn_name: str) -> None:
102125
await self.invoke_test_case(fn_name)
103126

@@ -291,6 +314,10 @@ def deploy_contract(
291314
):
292315
return self.deploy_in_env(contract_path, constructor_calldata)
293316

317+
@register_cheatcode
318+
def declare_contract(contract_path: str):
319+
return self.declare_in_env(contract_path)
320+
294321
cheatcodes: List[Cheatcode] = [
295322
ExpectRevertCheatcode(self),
296323
RollCheatcode(cheatable_syscall_handler),
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
%lang starknet
2+
3+
from starkware.cairo.common.cairo_builtins import HashBuiltin
4+
5+
@storage_var
6+
func balance() -> (res : felt):
7+
end
8+
9+
@external
10+
func increase_balance{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
11+
amount : felt):
12+
let (res) = balance.read()
13+
balance.write(res + amount)
14+
return ()
15+
end
16+
17+
@view
18+
func get_balance{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}() -> (
19+
res : felt):
20+
let (res) = balance.read()
21+
return (res)
22+
end
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
%lang starknet
2+
3+
from starkware.cairo.common.alloc import alloc
4+
from starkware.starknet.common.syscalls import deploy
5+
6+
@contract_interface
7+
namespace BasicContract:
8+
func increase_balance(amount : felt):
9+
end
10+
11+
func get_balance() -> (res : felt):
12+
end
13+
end
14+
15+
@contract_interface
16+
namespace ProxyContract:
17+
func deploy_contract_from_proxy(class_hash_d : felt) -> (address : felt):
18+
end
19+
end
20+
21+
22+
@external
23+
func test_deploy_declared_contract{syscall_ptr : felt*, range_check_ptr}():
24+
alloc_locals
25+
26+
local class_hash : felt
27+
%{
28+
ids.class_hash = declare_contract("./tests/integration/cheatcodes/declare_contract/basic_contract.cairo").class_hash
29+
%}
30+
31+
let (local calldata: felt*) = alloc()
32+
let (contract_address) = deploy(class_hash, 42, 0, calldata)
33+
34+
BasicContract.increase_balance(contract_address, 12)
35+
36+
let (balance) = BasicContract.get_balance(contract_address)
37+
assert balance = 12
38+
return ()
39+
end
40+
41+
@external
42+
func test_deploy_declared_contract_in_proxy{syscall_ptr : felt*, range_check_ptr}():
43+
alloc_locals
44+
45+
local proxy_address : felt
46+
local class_hash : felt
47+
%{
48+
ids.proxy_address = deploy_contract("./tests/integration/cheatcodes/declare_contract/proxy_contract.cairo").contract_address
49+
ids.class_hash = declare_contract("./tests/integration/cheatcodes/declare_contract/basic_contract.cairo").class_hash
50+
%}
51+
52+
let (contract_address) = ProxyContract.deploy_contract_from_proxy(proxy_address, class_hash)
53+
54+
BasicContract.increase_balance(contract_address, 12)
55+
56+
let (balance) = BasicContract.get_balance(contract_address)
57+
assert balance = 12
58+
return ()
59+
end
60+
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from pathlib import Path
2+
3+
import pytest
4+
5+
from protostar.commands.test.test_command import TestCommand
6+
from tests.integration.conftest import assert_cairo_test_cases
7+
8+
9+
@pytest.mark.asyncio
10+
async def test_declare_contract(mocker):
11+
testing_summary = await TestCommand(
12+
project=mocker.MagicMock(),
13+
protostar_directory=mocker.MagicMock(),
14+
).test(targets=[str(Path(__file__).parent / "declare_contract_test.cairo")])
15+
16+
assert_cairo_test_cases(
17+
testing_summary,
18+
expected_passed_test_cases_names=[
19+
"test_deploy_declared_contract",
20+
"test_deploy_declared_contract_in_proxy",
21+
],
22+
expected_failed_test_cases_names=[],
23+
)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Declare this file as a StarkNet contract and set the required
2+
# builtins.
3+
%lang starknet
4+
5+
from starkware.cairo.common.cairo_builtins import HashBuiltin
6+
from starkware.cairo.common.alloc import alloc
7+
from starkware.starknet.common.syscalls import deploy
8+
9+
@contract_interface
10+
namespace BasicContract:
11+
func increase_balance(amount : felt):
12+
end
13+
14+
func get_balance() -> (res : felt):
15+
end
16+
end
17+
18+
19+
@external
20+
func deploy_contract_from_proxy{syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr}(
21+
class_hash_d : felt) -> (address : felt):
22+
alloc_locals
23+
let (local calldata: felt*) = alloc()
24+
let (contract_address) = deploy(class_hash_d, 42, 0, calldata)
25+
return (contract_address)
26+
end

0 commit comments

Comments
 (0)