Skip to content

Commit

Permalink
Avoid reindexing on a single level rollback (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
droserasprout authored Jul 14, 2021
1 parent 2f6404e commit 7a38125
Show file tree
Hide file tree
Showing 12 changed files with 437 additions and 99 deletions.
2 changes: 1 addition & 1 deletion src/dipdup/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,10 +782,10 @@ def _pre_initialize_index(self, index_name: str, index_config: IndexConfigT) ->
if isinstance(contract, str):
index_config.contracts[i] = self.get_contract(contract)

transaction_id = 0
for handler_config in index_config.handlers:
self._callback_patterns[handler_config.callback].append(handler_config.pattern)
for pattern_config in handler_config.pattern:
transaction_id = 0
if isinstance(pattern_config, OperationHandlerTransactionPatternConfig):
if isinstance(pattern_config.destination, str):
pattern_config.destination = self.get_contract(pattern_config.destination)
Expand Down
20 changes: 16 additions & 4 deletions src/dipdup/datasources/datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@

from dipdup.config import HTTPConfig
from dipdup.http import HTTPGateway
from dipdup.models import BigMapData, OperationData
from dipdup.models import BigMapData, HeadBlockData, OperationData


class EventType(Enum):
operations = 'operatitions'
big_maps = 'big_maps'
rollback = 'rollback'
head = 'head'


class OperationsCallback(Protocol):
def __call__(self, datasource: 'IndexDatasource', operations: List[OperationData]) -> Awaitable[None]:
def __call__(self, datasource: 'IndexDatasource', operations: List[OperationData], block: HeadBlockData) -> Awaitable[None]:
...


Expand All @@ -29,6 +30,11 @@ def __call__(self, datasource: 'IndexDatasource', from_level: int, to_level: int
...


class HeadCallback(Protocol):
def __call__(self, datasource: 'IndexDatasource', block: HeadBlockData) -> Awaitable[None]:
...


class IndexDatasource(HTTPGateway, AsyncIOEventEmitter):
def __init__(self, url: str, http_config: Optional[HTTPConfig] = None) -> None:
HTTPGateway.__init__(self, url, http_config)
Expand All @@ -51,11 +57,17 @@ def on_big_maps(self, fn: BigMapsCallback) -> None:
def on_rollback(self, fn: RollbackCallback) -> None:
super().on(EventType.rollback, fn)

def emit_operations(self, operations: List[OperationData]) -> None:
super().emit(EventType.operations, datasource=self, operations=operations)
def on_head(self, fn: HeadCallback) -> None:
super().on(EventType.head, fn)

def emit_operations(self, operations: List[OperationData], block: HeadBlockData) -> None:
super().emit(EventType.operations, datasource=self, operations=operations, block=block)

def emit_big_maps(self, big_maps: List[BigMapData]) -> None:
super().emit(EventType.big_maps, datasource=self, big_maps=big_maps)

def emit_rollback(self, from_level: int, to_level: int) -> None:
super().emit(EventType.rollback, datasource=self, from_level=from_level, to_level=to_level)

def emit_head(self, block: HeadBlockData) -> None:
super().emit(EventType.head, datasource=self, block=block)
150 changes: 118 additions & 32 deletions src/dipdup/datasources/tzkt/datasource.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import Any, AsyncGenerator, Dict, List, NoReturn, Optional, Set, Tuple, cast

Expand All @@ -18,7 +20,7 @@
)
from dipdup.datasources.datasource import IndexDatasource
from dipdup.datasources.tzkt.enums import TzktMessageType
from dipdup.models import BigMapAction, BigMapData, OperationData
from dipdup.models import BigMapAction, BigMapData, BlockData, HeadBlockData, OperationData
from dipdup.utils import split_by_chunks

OperationID = int
Expand Down Expand Up @@ -283,6 +285,7 @@ def __init__(

self._client: Optional[BaseHubConnection] = None

self._block: Optional[HeadBlockData] = None
self._level: Optional[int] = None
self._sync_level: Optional[int] = None

Expand All @@ -298,6 +301,12 @@ def level(self) -> Optional[int]:
def sync_level(self) -> Optional[int]:
return self._sync_level

@property
def block(self) -> HeadBlockData:
if self._block is None:
raise RuntimeError('No message from `head` channel received')
return self._block

async def get_similar_contracts(self, address: str, strict: bool = False) -> List[str]:
"""Get list of contracts sharing the same code hash or type hash"""
entrypoint = 'same' if strict else 'similar'
Expand Down Expand Up @@ -352,15 +361,23 @@ async def get_jsonschemas(self, address: str) -> Dict[str, Any]:
self._logger.debug(jsonschemas)
return jsonschemas

async def get_latest_block(self) -> Dict[str, Any]:
async def get_head_block(self) -> HeadBlockData:
"""Get latest block (head)"""
self._logger.info('Fetching latest block')
block = await self._http.request(
head_block_json = await self._http.request(
'get',
url='v1/head',
)
self._logger.debug(block)
return block
return self.convert_head_block(head_block_json)

async def get_block(self, level: int) -> BlockData:
"""Get block by level"""
self._logger.info('Fetching block %s', level)
block_json = await self._http.request(
'get',
url=f'v1/blocks/{level}',
)
return self.convert_block(block_json)

async def get_originations(
self, addresses: Set[str], offset: int, first_level: int, last_level: int, cache: bool = False
Expand Down Expand Up @@ -486,6 +503,7 @@ def _get_client(self) -> BaseHubConnection:
self._client.on_error(self._on_error)
self._client.on('operations', self._on_operation_message)
self._client.on('bigmaps', self._on_big_map_message)
self._client.on('head', self._on_head_message)

return self._client

Expand All @@ -502,6 +520,7 @@ async def _on_connect(self) -> None:
return

self._logger.info('Connected to server')
await self.subscribe_to_head()
for address in self._transaction_subscriptions:
await self.subscribe_to_transactions(address)
# NOTE: All originations are passed to matcher
Expand All @@ -517,11 +536,7 @@ def _on_error(self, message: CompletionMessage) -> NoReturn:
async def subscribe_to_transactions(self, address: str) -> None:
"""Subscribe to contract's operations on established WS connection"""
self._logger.info('Subscribing to %s transactions', address)

while self._get_client().transport.state != ConnectionState.connected:
await asyncio.sleep(0.1)

await self._get_client().send(
await self._send(
'SubscribeToOperations',
[
{
Expand All @@ -534,11 +549,7 @@ async def subscribe_to_transactions(self, address: str) -> None:
async def subscribe_to_originations(self) -> None:
"""Subscribe to all originations on established WS connection"""
self._logger.info('Subscribing to originations')

while self._get_client().transport.state != ConnectionState.connected:
await asyncio.sleep(0.1)

await self._get_client().send(
await self._send(
'SubscribeToOperations',
[
{
Expand All @@ -550,12 +561,8 @@ async def subscribe_to_originations(self) -> None:
async def subscribe_to_big_maps(self, address: str, paths: List[str]) -> None:
"""Subscribe to contract's big map diffs on established WS connection"""
self._logger.info('Subscribing to big map updates of %s, %s', address, paths)

while self._get_client().transport.state != ConnectionState.connected:
await asyncio.sleep(0.1)

for path in paths:
await self._get_client().send(
await self._send(
'SubscribeToBigMaps',
[
{
Expand All @@ -565,6 +572,14 @@ async def subscribe_to_big_maps(self, address: str, paths: List[str]) -> None:
],
)

async def subscribe_to_head(self) -> None:
"""Subscribe to head on established WS connection"""
self._logger.info('Subscribing to head')
await self._send(
'SubscribeToHead',
[],
)

def _default_http_config(self) -> HTTPConfig:
return HTTPConfig(
cache=True,
Expand All @@ -574,12 +589,8 @@ def _default_http_config(self) -> HTTPConfig:
ratelimit_period=30,
)

async def _on_operation_message(
self,
message: List[Dict[str, Any]],
) -> None:
async def _on_operation_message(self, message: List[Dict[str, Any]]) -> None:
"""Parse and emit raw operations from WS"""

for item in message:
current_level = item['state']
message_type = TzktMessageType(item['type'])
Expand All @@ -597,7 +608,8 @@ async def _on_operation_message(
if operation.status != 'applied':
continue
operations.append(operation)
self.emit_operations(operations)
if operations:
self.emit_operations(operations, self.block)

elif message_type == TzktMessageType.REORG:
if self.level is None:
Expand All @@ -607,10 +619,7 @@ async def _on_operation_message(
else:
raise NotImplementedError

async def _on_big_map_message(
self,
message: List[Dict[str, Any]],
) -> None:
async def _on_big_map_message(self, message: List[Dict[str, Any]]) -> None:
"""Parse and emit raw big map diffs from WS"""
for item in message:
current_level = item['state']
Expand All @@ -637,6 +646,31 @@ async def _on_big_map_message(
else:
raise NotImplementedError

async def _on_head_message(self, message: List[Dict[str, Any]]) -> None:
for item in message:
current_level = item['state']
message_type = TzktMessageType(item['type'])
self._logger.info('Got block message, %s, level %s', message_type, current_level)

if message_type == TzktMessageType.STATE:
self._sync_level = current_level
self._level = current_level

elif message_type == TzktMessageType.DATA:
self._level = current_level
block_json = item['data']
block = self.convert_head_block(block_json)
self._block = block
self.emit_head(block)

elif message_type == TzktMessageType.REORG:
if self.level is None:
raise RuntimeError
self.emit_rollback(self.level, current_level)

else:
raise NotImplementedError

@classmethod
def convert_operation(cls, operation_json: Dict[str, Any]) -> OperationData:
"""Convert raw operation message from WS/REST into dataclass"""
Expand All @@ -649,7 +683,7 @@ def convert_operation(cls, operation_json: Dict[str, Any]) -> OperationData:
type=operation_json['type'],
id=operation_json['id'],
level=operation_json['level'],
timestamp=operation_json['timestamp'],
timestamp=cls._parse_timestamp(operation_json['timestamp']),
block=operation_json.get('block'),
hash=operation_json['hash'],
counter=operation_json['counter'],
Expand Down Expand Up @@ -686,11 +720,63 @@ def convert_big_map(cls, big_map_json: Dict[str, Any]) -> BigMapData:
level=big_map_json['level'],
# FIXME: operation_id field in API
operation_id=big_map_json['level'],
timestamp=big_map_json['timestamp'],
timestamp=cls._parse_timestamp(big_map_json['timestamp']),
bigmap=big_map_json['bigmap'],
contract_address=big_map_json['contract']['address'],
path=big_map_json['path'],
action=BigMapAction(big_map_json['action']),
key=big_map_json.get('content', {}).get('key'),
value=big_map_json.get('content', {}).get('value'),
)

@classmethod
def convert_block(cls, block_json: Dict[str, Any]) -> BlockData:
"""Convert raw block message from REST into dataclass"""
return BlockData(
level=block_json['level'],
hash=block_json['hash'],
timestamp=cls._parse_timestamp(block_json['timestamp']),
proto=block_json['proto'],
priority=block_json['priority'],
validations=block_json['validations'],
deposit=block_json['deposit'],
reward=block_json['reward'],
fees=block_json['fees'],
nonce_revealed=block_json['nonceRevealed'],
baker_address=block_json.get('baker', {}).get('address'),
baker_alias=block_json.get('baker', {}).get('alias'),
)

@classmethod
def convert_head_block(cls, head_block_json: Dict[str, Any]) -> HeadBlockData:
"""Convert raw head block message from WS/REST into dataclass"""
return HeadBlockData(
cycle=head_block_json['cycle'],
level=head_block_json['level'],
hash=head_block_json['hash'],
protocol=head_block_json['protocol'],
timestamp=cls._parse_timestamp(head_block_json['timestamp']),
voting_epoch=head_block_json['votingEpoch'],
voting_period=head_block_json['votingPeriod'],
known_level=head_block_json['knownLevel'],
last_sync=head_block_json['lastSync'],
synced=head_block_json['synced'],
quote_level=head_block_json['quoteLevel'],
quote_btc=Decimal(head_block_json['quoteBtc']),
quote_eur=Decimal(head_block_json['quoteEur']),
quote_usd=Decimal(head_block_json['quoteUsd']),
quote_cny=Decimal(head_block_json['quoteCny']),
quote_jpy=Decimal(head_block_json['quoteJpy']),
quote_krw=Decimal(head_block_json['quoteKrw']),
quote_eth=Decimal(head_block_json['quoteEth']),
)

async def _send(self, method: str, arguments: List[Dict[str, Any]], on_invocation=None) -> None:
client = self._get_client()
while client.transport.state != ConnectionState.connected:
await asyncio.sleep(0.1)
await client.send(method, arguments, on_invocation)

@classmethod
def _parse_timestamp(cls, timestamp: str) -> datetime:
return datetime.fromisoformat(timestamp[:-1]).replace(tzinfo=timezone.utc)
Loading

0 comments on commit 7a38125

Please sign in to comment.