Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 91 additions & 7 deletions src/ape/managers/_contractscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Collection
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING, Generic, TypeVar
Expand Down Expand Up @@ -30,6 +31,12 @@
_BASE_MODEL = TypeVar("_BASE_MODEL", bound=BaseModel)


@dataclass(frozen=True)
class ProxyInfoCacheEntry:
exists: bool
value: ProxyInfoAPI | None = None


class ApeDataCache(CacheDirectory, Generic[_BASE_MODEL]):
"""
A wrapper around some cached models in the data directory,
Expand All @@ -47,7 +54,7 @@ def __init__(
data_folder = base_data_folder / ecosystem_key
base_path = data_folder / network_key
self._model_type = model_type
self.memory: dict[str, _BASE_MODEL] = {}
self.memory: dict[str, _BASE_MODEL | None] = {}

# Only write if we are not testing!
self._write_to_disk = not network_key.endswith("-fork") and network_key != "local"
Expand Down Expand Up @@ -92,6 +99,60 @@ def get_type(self, key: str, fetch_from_disk: bool = True) -> _BASE_MODEL | None
return None


class ProxyInfoCache(ApeDataCache[ProxyInfoAPI]):
"""
Cache of proxy detection results.

A missing file means unchecked, `null` means checked and not a proxy, and a
JSON object means checked and proxy info found.
"""

def __init__(
self,
base_data_folder: Path,
ecosystem_key: str,
network_key: str,
key: str,
model_type: type[ProxyInfoAPI],
):
super().__init__(base_data_folder, ecosystem_key, network_key, key, model_type)
self.memory: dict[str, ProxyInfoAPI | None] = {}

def __setitem__(self, key: str, value: ProxyInfoAPI | None): # type: ignore
self.memory[key] = value
if self._write_to_disk:
self.cache_data(key, value.model_dump(mode="json") if value is not None else None)

def __delitem__(self, key: str):
super().__delitem__(key)

def get_type(self, key: str, fetch_from_disk: bool = True) -> ProxyInfoAPI | None:
return self.get_entry(key, fetch_from_disk=fetch_from_disk).value

def get_entry(self, key: str, fetch_from_disk: bool = True) -> ProxyInfoCacheEntry:
if key in self.memory:
return ProxyInfoCacheEntry(exists=True, value=self.memory[key])

elif fetch_from_disk and self._read_from_disk:
file = self.get_file(key)
if file.is_file():
data = self.get_data(key)
if data is None:
self.memory[key] = None
return ProxyInfoCacheEntry(exists=True)

# Found proxy info on disk.
model = self._model_type.model_validate(data)
# Cache locally for next time.
self.memory[key] = model
return ProxyInfoCacheEntry(exists=True, value=model)

return ProxyInfoCacheEntry(exists=False)

def clear_memory(self):
self.memory = {}


class ContractCache(BaseManager):
"""
A collection of cached contracts. Contracts can be cached in two ways:
Expand All @@ -115,8 +176,8 @@ def contract_types(self) -> ApeDataCache[ContractType]:
return self._get_data_cache("contract_types", ContractType)

@property
def proxy_infos(self) -> ApeDataCache[ProxyInfoAPI]:
return self._get_data_cache("proxy_info", ProxyInfoAPI)
def proxy_infos(self) -> ProxyInfoCache:
return self._get_data_cache("proxy_info", ProxyInfoAPI, cache_type=ProxyInfoCache)

@property
def blueprints(self) -> ApeDataCache[ContractType]:
Expand All @@ -132,6 +193,7 @@ def _get_data_cache(
model_type: type,
ecosystem_key: str | None = None,
network_key: str | None = None,
cache_type: type[ApeDataCache] = ApeDataCache,
):
ecosystem_name = ecosystem_key or self.provider.network.ecosystem.name
network_name = network_key or self.provider.network.name.replace("-fork", "")
Expand All @@ -141,7 +203,7 @@ def _get_data_cache(
if cache := self._caches[ecosystem_name][network_name].get(key):
return cache

self._caches[ecosystem_name][network_name][key] = ApeDataCache(
self._caches[ecosystem_name][network_name][key] = cache_type(
self.config_manager.DATA_FOLDER, ecosystem_name, network_name, key, model_type
)
return self._caches[ecosystem_name][network_name][key]
Expand Down Expand Up @@ -178,6 +240,15 @@ def __setitem__(
else:
raise TypeError(item)

def cache_proxy_info_no_hit(self, address: AddressType):
"""
Cache that proxy detection found no proxy information for this address.

Args:
address (AddressType): The address that is not a proxy.
"""
self.proxy_infos[address] = None

def cache_contract_type(
self,
address: AddressType,
Expand Down Expand Up @@ -267,6 +338,8 @@ def _delete_proxy(self, address: AddressType):
target = info.target
del self.proxy_infos[target]
del self.contract_types[target]
else:
del self.proxy_infos[address]

def __contains__(self, address: AddressType) -> bool:
return self.get(address) is not None
Expand Down Expand Up @@ -307,6 +380,7 @@ def cache_deployment(

else:
# Cache as normal.
self.cache_proxy_info_no_hit(address)
self.contract_types[address] = contract_type

else:
Expand Down Expand Up @@ -584,9 +658,16 @@ def get(
# Check broader sources, such as an explorer.
if not proxy_info and detect_proxy:
# Proxy info not provided. Attempt to detect.
if not (proxy_info := self.proxy_infos[address_key]):
cached_proxy_info = self.proxy_infos.get_entry(
address_key, fetch_from_disk=fetch_from_disk
)
if cached_proxy_info.exists:
proxy_info = cached_proxy_info.value
else:
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
self.proxy_infos[address_key] = proxy_info
self.cache_proxy_info(address_key, proxy_info)
else:
self.cache_proxy_info_no_hit(address_key)

if proxy_info:
if proxy_contract_type := self._get_proxy_contract_type(
Expand Down Expand Up @@ -917,7 +998,10 @@ def clear_local_caches(self):
self.contract_creations,
self.blueprints,
):
cache.memory = {}
if isinstance(cache, ProxyInfoCache):
cache.clear_memory()
else:
cache.memory = {}

self.deployments.clear_local()

Expand Down
8 changes: 4 additions & 4 deletions src/ape/utils/os.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def __init__(self, path: Path):

self._path = path

def __getitem__(self, key: str) -> dict:
def __getitem__(self, key: str) -> dict | None:
"""
Get the data from ``base_path / <key>.json``.

Expand All @@ -405,7 +405,7 @@ def __getitem__(self, key: str) -> dict:
"""
return self.get_data(key)

def __setitem__(self, key: str, value: dict):
def __setitem__(self, key: str, value: dict | None):
"""
Cache the given data to ``base_path / <key>.json``.

Expand All @@ -427,14 +427,14 @@ def __delitem__(self, key: str):
def get_file(self, key: str) -> Path:
return self._path / f"{key}.json"

def cache_data(self, key: str, data: dict):
def cache_data(self, key: str, data: dict | None):
json_str = json.dumps(data)
file = self.get_file(key)
file.unlink(missing_ok=True)
file.parent.mkdir(parents=True, exist_ok=True)
file.write_text(json_str)

def get_data(self, key: str) -> dict:
def get_data(self, key: str) -> dict | None:
file = self.get_file(key)
if not file.is_file():
return {}
Expand Down
67 changes: 67 additions & 0 deletions tests/functional/test_contracts_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,70 @@ def test_instance_at_skip_proxy(mocker, chain, vyper_contract_instance, owner):
assert address != arg


def test_get_caches_proxy_info_no_hit(mocker, chain, vyper_contract_instance, ethereum):
address = vyper_contract_instance.address
with chain.contracts.use_temporary_caches():
ecosystem_type = type(ethereum)
get_proxy_info = ecosystem_type.get_proxy_info
proxy_detection_spy = mocker.patch.object(
ecosystem_type,
"get_proxy_info",
autospec=True,
side_effect=lambda ecosystem, address: get_proxy_info(ecosystem, address),
)

assert chain.contracts.get(address, fetch_from_explorer=False) is None
assert proxy_detection_spy.call_count == 1
cached_proxy_info = chain.contracts.proxy_infos.get_entry(address)
assert cached_proxy_info.exists is True
assert cached_proxy_info.value is None

assert chain.contracts.get(address, fetch_from_explorer=False) is None
assert proxy_detection_spy.call_count == 1


def test_cache_proxy_info_no_hit_live_network(chain, clean_contract_caches, dummy_live_network):
address = "0x4a986a6dca6dbF99Bc3D17F8d71aFB0D60E740F9"
cache = chain.contracts.proxy_infos

try:
chain.contracts.cache_proxy_info_no_hit(address)
assert cache.get_file(address).is_file()
assert cache.get_data(address) is None

cache.clear_memory()
cached_proxy_info = cache.get_entry(address)
assert cached_proxy_info.exists is True
assert cached_proxy_info.value is None
assert cache[address] is None

finally:
del cache[address]


def test_cache_proxy_info_loads_existing_disk_model(
chain, clean_contract_caches, dummy_live_network
):
address = "0x4a986a6dca6dbF99Bc3D17F8d71aFB0D60E740F9"
target = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe"
cache = chain.contracts.proxy_infos
proxy_info = ProxyInfo(type=ProxyType.Minimal, target=target)

try:
cache.cache_data(address, proxy_info.model_dump(mode="json"))
cache.clear_memory()

cached_proxy_info = cache.get_entry(address)
assert cached_proxy_info.exists is True
assert cached_proxy_info.value is not None
assert cached_proxy_info.value.target == target
assert cached_proxy_info.value.type_name == "Minimal"
assert cache[address] == cached_proxy_info.value

finally:
del cache[address]


def test_cache_deployment_live_network(
chain,
project,
Expand Down Expand Up @@ -661,12 +725,15 @@ def test_clear_local_caches(chain, vyper_contract_instance, project, owner):
chain.contracts.blueprints[address] = vyper_contract_instance.contract_type
# Ensure proxy exists.
proxy = project.SimpleProxy.deploy(address, sender=owner)
# Ensure proxy no-hit exists.
chain.contracts.cache_proxy_info_no_hit(address)
# Ensure creation exists.
_ = chain.contracts.get_creation_metadata(address)

# Test setup verification.
assert address in chain.contracts.contract_types, "Setup failed - no contract type(s) cached"
assert proxy.address in chain.contracts.proxy_infos, "Setup failed - no proxy cached"
assert chain.contracts.proxy_infos.get_entry(address).exists, "Setup failed - no proxy no-hit"
assert address in chain.contracts.contract_creations, "Setup failed - no creation(s) cached"

# This is the method we are testing.
Expand Down
Loading