Skip to content

Commit ea08a69

Browse files
committed
WIP
1 parent 3445e5f commit ea08a69

File tree

3 files changed

+126
-19
lines changed

3 files changed

+126
-19
lines changed

async_substrate_interface/async_substrate.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,7 +1091,10 @@ async def __aenter__(self):
10911091
await self.initialize()
10921092
return self
10931093

1094-
async def initialize(self):
1094+
async def initialize(self) -> None:
1095+
await self._initialize()
1096+
1097+
async def _initialize(self) -> None:
10951098
"""
10961099
Initialize the connection to the chain.
10971100
"""
@@ -1116,7 +1119,7 @@ async def initialize(self):
11161119
self._initializing = False
11171120

11181121
async def __aexit__(self, exc_type, exc_val, exc_tb):
1119-
await self.ws.shutdown()
1122+
await self.close()
11201123

11211124
@property
11221125
def metadata(self):
@@ -4260,11 +4263,16 @@ class DiskCachedAsyncSubstrateInterface(AsyncSubstrateInterface):
42604263
Experimental new class that uses disk-caching in addition to memory-caching for the cached methods
42614264
"""
42624265

4266+
async def initialize(self) -> None:
4267+
await self.runtime_cache.load_from_disk(self.url)
4268+
await self._initialize()
4269+
42634270
async def close(self):
42644271
"""
42654272
Closes the substrate connection, and the websocket connection.
42664273
"""
42674274
try:
4275+
await self.runtime_cache.dump_to_disk(self.url)
42684276
await self.ws.shutdown()
42694277
except AttributeError:
42704278
pass

async_substrate_interface/types.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from .const import SS58_FORMAT
1818
from .utils import json
19+
from .utils.cache import AsyncSqliteDB
1920

2021
logger = logging.getLogger("async_substrate_interface")
2122

@@ -101,6 +102,21 @@ def retrieve(
101102
return runtime
102103
return None
103104

105+
async def load_from_disk(self, chain_endpoint: str):
106+
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
107+
block_mapping, block_hash_mapping, runtime_version_mapping = await db.load_runtime_cache(chain_endpoint)
108+
if not any([block_mapping, block_hash_mapping, runtime_version_mapping]):
109+
logger.debug("No runtime mappings in disk cache")
110+
else:
111+
logger.debug("Found runtime mappings in disk cache")
112+
self.blocks = {x: Runtime.deserialize(y) for x, y in block_mapping.items()}
113+
self.block_hashes = {x: Runtime.deserialize(y) for x, y in block_hash_mapping.items()}
114+
self.versions = {x: Runtime.deserialize(y) for x, y in runtime_version_mapping.items()}
115+
116+
async def dump_to_disk(self, chain_endpoint: str):
117+
db = AsyncSqliteDB(chain_endpoint=chain_endpoint)
118+
await db.dump_runtime_cache(chain_endpoint, self.blocks, self.block_hashes, self.versions)
119+
104120

105121
class Runtime:
106122
"""
@@ -149,6 +165,33 @@ def __init__(
149165
if registry is not None:
150166
self.load_registry_type_map()
151167

168+
def serialize(self):
169+
return {
170+
"chain": self.chain,
171+
"type_registry": self.type_registry,
172+
"metadata": self.metadata,
173+
"metadata_v15": self.metadata_v15.to_json() if self.metadata_v15 is not None else None,
174+
"runtime_info": self.runtime_info,
175+
"registry": None, # gets loaded from metadata_v15
176+
"ss58_format": self.ss58_format,
177+
"runtime_config": self.runtime_config,
178+
}
179+
180+
@classmethod
181+
def deserialize(cls, serialized: dict) -> "Runtime":
182+
mdv15 = MetadataV15
183+
registry = PortableRegistry.from_metadata_v15(mdv15) if (mdv15 := serialized["metadata_v15"]) else None
184+
return cls(
185+
chain=serialized["chain"],
186+
metadata=serialized["metadata"],
187+
type_registry=serialized["type_registry"],
188+
runtime_config=serialized["runtime_config"],
189+
metadata_v15=mdv15 if mdv15 is not None else None,
190+
registry=registry,
191+
ss58_format=serialized["ss58_format"],
192+
runtime_info=serialized["runtime_info"]
193+
)
194+
152195
def load_runtime(self):
153196
"""
154197
Initial loading of the runtime's type registry information.

async_substrate_interface/utils/cache.py

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import functools
55
import logging
66
import os
7-
import pickle
87
import sqlite3
98
from pathlib import Path
109
from typing import Callable, Any, Awaitable, Hashable, Optional
1110

1211
import aiosqlite
12+
import dill as pickle
1313

1414

1515
USE_CACHE = True if os.getenv("NO_CACHE") != "1" else False
@@ -38,13 +38,11 @@ def __new__(cls, chain_endpoint: str):
3838
cls._instances[chain_endpoint] = instance
3939
return instance
4040

41-
async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]:
41+
async def _create_if_not_exists(self, chain: str, table_name: str):
4242
async with self._lock:
4343
if not self._db:
4444
_ensure_dir()
4545
self._db = await aiosqlite.connect(CACHE_LOCATION)
46-
table_name = _get_table_name(func)
47-
key = None
4846
if not (local_chain := _check_if_local(chain)) or not USE_CACHE:
4947
await self._db.execute(
5048
f"""
@@ -72,19 +70,24 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
7270
"""
7371
)
7472
await self._db.commit()
75-
key = pickle.dumps((args, kwargs or None))
76-
try:
77-
cursor: aiosqlite.Cursor = await self._db.execute(
78-
f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
79-
(key, chain),
80-
)
81-
result = await cursor.fetchone()
82-
await cursor.close()
83-
if result is not None:
84-
return pickle.loads(result[0])
85-
except (pickle.PickleError, sqlite3.Error) as e:
86-
logger.exception("Cache error", exc_info=e)
87-
pass
73+
return local_chain
74+
75+
async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]:
76+
table_name = _get_table_name(func)
77+
local_chain = await self._create_if_not_exists(chain, table_name)
78+
key = pickle.dumps((args, kwargs or None))
79+
try:
80+
cursor: aiosqlite.Cursor = await self._db.execute(
81+
f"SELECT value FROM {table_name} WHERE key=? AND chain=?",
82+
(key, chain),
83+
)
84+
result = await cursor.fetchone()
85+
await cursor.close()
86+
if result is not None:
87+
return pickle.loads(result[0])
88+
except (pickle.PickleError, sqlite3.Error) as e:
89+
logger.exception("Cache error", exc_info=e)
90+
pass
8891
result = await func(other_self, *args, **kwargs)
8992
if not local_chain or not USE_CACHE:
9093
# TODO use a task here
@@ -95,6 +98,59 @@ async def __call__(self, chain, other_self, func, args, kwargs) -> Optional[Any]
9598
await self._db.commit()
9699
return result
97100

101+
async def load_runtime_cache(self, chain: str) -> tuple[dict, dict, dict]:
102+
block_mapping = {}
103+
block_hash_mapping = {}
104+
version_mapping = {}
105+
tables = {
106+
"rt_cache_block": block_mapping,
107+
"rt_cache_block_hash": block_hash_mapping,
108+
"rt_cache_version": version_mapping
109+
}
110+
for table in tables.keys():
111+
local_chain = await self._create_if_not_exists(chain, table)
112+
if local_chain:
113+
return {}, {}, {}
114+
for table_name, mapping in tables.items():
115+
try:
116+
cursor: aiosqlite.Cursor = await self._db.execute(
117+
f"SELECT key, value FROM {table_name} WHERE chain=?",
118+
(chain,),
119+
)
120+
results = await cursor.fetchall()
121+
await cursor.close()
122+
if results is None:
123+
continue
124+
for row in results:
125+
key, value = row
126+
runtime = pickle.loads(value)
127+
mapping[key] = runtime
128+
except (pickle.PickleError, sqlite3.Error) as e:
129+
logger.exception("Cache error", exc_info=e)
130+
return {}, {}, {}
131+
return block_mapping, block_hash_mapping, version_mapping
132+
133+
async def dump_runtime_cache(self, chain: str, block_mapping: dict, block_hash_mapping: dict, version_mapping: dict) -> None:
134+
async with self._lock:
135+
if not self._db:
136+
_ensure_dir()
137+
self._db = await aiosqlite.connect(CACHE_LOCATION)
138+
tables = {
139+
"rt_cache_block": block_mapping,
140+
"rt_cache_block_hash": block_hash_mapping,
141+
"rt_cache_version": version_mapping
142+
}
143+
for table, mapping in tables.items():
144+
local_chain = await self._create_if_not_exists(chain, table)
145+
if local_chain:
146+
return None
147+
await self._db.executemany(
148+
f"INSERT OR REPLACE INTO {table} (key, value, chain) VALUES (?,?,?)",
149+
[(key, pickle.dumps(runtime.serialize()), chain) for key, runtime in mapping.items()],
150+
)
151+
await self._db.commit()
152+
return None
153+
98154

99155
def _ensure_dir():
100156
path = Path(CACHE_LOCATION).parent

0 commit comments

Comments
 (0)