diff --git a/chain/kusama/defaults.go b/chain/kusama/defaults.go index 26054fa83a..0ec9f86a14 100644 --- a/chain/kusama/defaults.go +++ b/chain/kusama/defaults.go @@ -25,7 +25,7 @@ func DefaultConfig() *cfg.Config { config.Core.GrandpaAuthority = false config.Core.Role = 1 config.Network.NoMDNS = false - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/chain/paseo/defaults.go b/chain/paseo/defaults.go index 30bdc77e3f..8b6f0e8687 100644 --- a/chain/paseo/defaults.go +++ b/chain/paseo/defaults.go @@ -24,7 +24,7 @@ func DefaultConfig() *cfg.Config { config.Core.GrandpaAuthority = false config.Core.Role = 1 config.Network.NoMDNS = false - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/chain/polkadot/defaults.go b/chain/polkadot/defaults.go index c0e405f69c..bc2bcd0f74 100644 --- a/chain/polkadot/defaults.go +++ b/chain/polkadot/defaults.go @@ -24,7 +24,7 @@ func DefaultConfig() *cfg.Config { config.Core.GrandpaAuthority = false config.Core.Role = 1 config.Network.NoMDNS = false - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/chain/westend-dev/defaults.go b/chain/westend-dev/defaults.go index 912322ba79..0345052a62 100644 --- a/chain/westend-dev/defaults.go +++ b/chain/westend-dev/defaults.go @@ -24,7 +24,7 @@ func DefaultConfig() *cfg.Config { config.RPC.UnsafeRPC = true config.RPC.WSExternal = true config.RPC.UnsafeWSExternal = true - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/chain/westend-local/default.go b/chain/westend-local/default.go index 9e43c5a648..24f6c6cc05 100644 --- a/chain/westend-local/default.go +++ b/chain/westend-local/default.go @@ -29,7 +29,7 @@ func DefaultConfig() *cfg.Config { config.RPC.UnsafeRPC = true config.RPC.WSExternal = true config.RPC.UnsafeWSExternal = true - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/chain/westend/defaults.go b/chain/westend/defaults.go index 1419207084..df867236ed 100644 --- a/chain/westend/defaults.go +++ b/chain/westend/defaults.go @@ -24,7 +24,7 @@ func DefaultConfig() *cfg.Config { config.Core.GrandpaAuthority = false config.Core.Role = 1 config.Network.NoMDNS = false - config.Core.Sync = "full" + config.Core.SyncMode = cfg.FullSync return config } diff --git a/cmd/gossamer/commands/root.go b/cmd/gossamer/commands/root.go index 6a8a33416b..6f563fd5a0 100644 --- a/cmd/gossamer/commands/root.go +++ b/cmd/gossamer/commands/root.go @@ -39,6 +39,8 @@ var ( role string // validator when set, the node will be an authority validator bool + // Sync mode [warp | full] + syncMode string // Account Config // key to use for the node @@ -102,6 +104,10 @@ Usage: return fmt.Errorf("failed to parse role: %s", err) } + if err := parseSyncMode(); err != nil { + return fmt.Errorf("failed to parse sync mode: %s", err) + } + if err := parseTelemetryURL(); err != nil { return fmt.Errorf("failed to parse telemetry-url: %s", err.Error()) } @@ -529,13 +535,10 @@ func addCoreFlags(cmd *cobra.Command) error { return fmt.Errorf("failed to add --grandpa-interval flag: %s", err) } - if err := addStringFlagBindViper(cmd, + cmd.Flags().StringVar(&syncMode, "sync", - config.Core.Sync, - "sync mode [warp | full]", - "core.sync"); err != nil { - return fmt.Errorf("failed to add --sync flag: %s", err) - } + cfg.FullSync.String(), + "Sync mode. One of 'full' or 'warp'.") return nil } diff --git a/cmd/gossamer/commands/utils.go b/cmd/gossamer/commands/utils.go index 2414bdeeb6..22cfcaab64 100644 --- a/cmd/gossamer/commands/utils.go +++ b/cmd/gossamer/commands/utils.go @@ -407,6 +407,23 @@ func parseRole() error { return nil } +// parseSyncMode parses the sync mode from the command line flags +func parseSyncMode() error { + var selectedSyncMode cfg.SyncMode + switch syncMode { + case cfg.FullSync.String(): + selectedSyncMode = cfg.FullSync + case cfg.WarpSync.String(): + selectedSyncMode = cfg.WarpSync + default: + return fmt.Errorf("invalid sync mode: %s", role) + } + + config.Core.SyncMode = selectedSyncMode + viper.Set("core.syncMode", config.Core.SyncMode) + return nil +} + // parseTelemetryURL parses the telemetry-url from the command line flag func parseTelemetryURL() error { if telemetryURLs == "" { diff --git a/config/config.go b/config/config.go index 9ff07eb042..93381792ac 100644 --- a/config/config.go +++ b/config/config.go @@ -63,7 +63,7 @@ const ( DefaultSystemVersion = "0.0.0" // DefaultSyncMode is the default block sync mode - DefaultSyncMode = "full" + DefaultSyncMode = FullSync ) // DefaultRPCModules the default RPC modules @@ -191,7 +191,7 @@ type CoreConfig struct { GrandpaAuthority bool `mapstructure:"grandpa-authority"` WasmInterpreter string `mapstructure:"wasm-interpreter,omitempty"` GrandpaInterval time.Duration `mapstructure:"grandpa-interval,omitempty"` - Sync string `mapstructure:"sync,omitempty"` + SyncMode SyncMode `mapstructure:"sync,omitempty"` } // StateConfig contains the configuration for the state. @@ -367,7 +367,7 @@ func DefaultConfig() *Config { GrandpaAuthority: true, WasmInterpreter: DefaultWasmInterpreter, GrandpaInterval: DefaultDiscoveryInterval, - Sync: DefaultSyncMode, + SyncMode: DefaultSyncMode, }, Network: &NetworkConfig{ Port: DefaultNetworkPort, @@ -449,7 +449,7 @@ func DefaultConfigFromSpec(nodeSpec *genesis.Genesis) *Config { GrandpaAuthority: true, WasmInterpreter: DefaultWasmInterpreter, GrandpaInterval: DefaultDiscoveryInterval, - Sync: DefaultSyncMode, + SyncMode: DefaultSyncMode, }, Network: &NetworkConfig{ Port: DefaultNetworkPort, @@ -531,7 +531,7 @@ func Copy(c *Config) Config { GrandpaAuthority: c.Core.GrandpaAuthority, WasmInterpreter: c.Core.WasmInterpreter, GrandpaInterval: c.Core.GrandpaInterval, - Sync: c.Core.Sync, + SyncMode: c.Core.SyncMode, }, Network: &NetworkConfig{ Port: c.Network.Port, @@ -611,6 +611,19 @@ func (c Chain) String() string { return string(c) } +// SyncMode is a string representing a sync mode +type SyncMode string + +const ( + FullSync SyncMode = "full" + WarpSync SyncMode = "warp" + StateSync SyncMode = "state" +) + +func (n SyncMode) String() string { + return string(n) +} + // NetworkRole is a string representing a network role type NetworkRole string diff --git a/dot/network/discovery.go b/dot/network/discovery.go index 5c81559fc2..641a50f1b3 100644 --- a/dot/network/discovery.go +++ b/dot/network/discovery.go @@ -171,7 +171,7 @@ func (d *discovery) advertise() { ttl, err = d.rd.Advertise(d.ctx, string(d.pid)) if err != nil { - logger.Warnf("failed to advertise in the DHT: %s", err) + logger.Debugf("failed to advertise in the DHT: %s", err) ttl = tryAdvertiseTimeout } } diff --git a/dot/network/messages/state.go b/dot/network/messages/state.go index baae4e1487..835f98479b 100644 --- a/dot/network/messages/state.go +++ b/dot/network/messages/state.go @@ -22,10 +22,18 @@ type StateRequest struct { NoProof bool } +func NewStateRequest(block common.Hash, start [][]byte, noProof bool) *StateRequest { + return &StateRequest{ + Block: block, + Start: start, + NoProof: noProof, + } +} + func (s *StateRequest) String() string { - return fmt.Sprintf("StateRequest Block=%s Start=[0x%x, 0x%x] NoProof=%v", + return fmt.Sprintf("StateRequest Block=%s Start=[%v] NoProof=%v", s.Block.String(), - s.Start[0], s.Start[1], + s.Start, s.NoProof, ) } @@ -98,3 +106,14 @@ func (s *StateResponse) Decode(in []byte) error { return nil } + +func (s *StateResponse) Encode() ([]byte, error) { + panic("not implemented") +} + +func (s *StateResponse) String() string { + return fmt.Sprintf("StateResponse Entries=[%v] Proof=[%v]", + s.Entries, + s.Proof, + ) +} diff --git a/dot/network/notifications.go b/dot/network/notifications.go index a938a64661..b8d478991b 100644 --- a/dot/network/notifications.go +++ b/dot/network/notifications.go @@ -282,7 +282,7 @@ func (s *Service) sendData(peer peer.ID, hs Handshake, info *notificationsProtoc // we've completed the handshake with the peer, send message directly logger.Tracef("sending message to peer %s using protocol %s: %s", peer, info.protocolID, msg) if err := s.host.writeToStream(stream, msg); err != nil { - logger.Errorf("failed to send message to peer %s: %s", peer, err) + logger.Debugf("failed to send message to peer %s: %s", peer, err) // the stream was closed or reset, close it on our end and delete it from our peer's data if errors.Is(err, io.EOF) || errors.Is(err, network.ErrReset) { diff --git a/dot/network/request_response.go b/dot/network/request_response.go index da99c54950..5efba87115 100644 --- a/dot/network/request_response.go +++ b/dot/network/request_response.go @@ -30,6 +30,19 @@ type RequestResponseProtocol struct { responseBuf []byte } +func NewRequestResponseProtocol(ctx context.Context, host *host, protocolID protocol.ID, + requestTimeout time.Duration, maxResponseSize uint64) *RequestResponseProtocol { + return &RequestResponseProtocol{ + ctx: ctx, + host: host, + requestTimeout: requestTimeout, + maxResponseSize: maxResponseSize, + protocolID: protocolID, + responseBuf: make([]byte, maxResponseSize), + responseBufMu: sync.Mutex{}, + } +} + func (rrp *RequestResponseProtocol) Do(to peer.ID, req, res messages.P2PMessage) error { rrp.host.p2pHost.ConnManager().Protect(to, "") defer rrp.host.p2pHost.ConnManager().Unprotect(to, "") diff --git a/dot/network/service.go b/dot/network/service.go index 41cd4fa3a8..6bc5435e15 100644 --- a/dot/network/service.go +++ b/dot/network/service.go @@ -34,6 +34,7 @@ const ( // the following are sub-protocols used by the node SyncID = "/sync/2" WarpSyncID = "/sync/warp" + StateSyncID = "/state/2" lightID = "/light/2" blockAnnounceID = "/block-announces/1" transactionsID = "/transactions/1" @@ -629,15 +630,7 @@ func (s *Service) GetRequestResponseProtocol(subprotocol string, requestTimeout genesisHash = strings.TrimPrefix(genesisHash, "0x") protocolId := fmt.Sprintf("/%s%s", genesisHash, subprotocol) - return &RequestResponseProtocol{ - ctx: s.ctx, - host: s.host, - requestTimeout: requestTimeout, - maxResponseSize: maxResponseSize, - protocolID: protocol.ID(protocolId), - responseBuf: make([]byte, maxResponseSize), - responseBufMu: sync.Mutex{}, - } + return NewRequestResponseProtocol(s.ctx, s.host, protocol.ID(protocolId), requestTimeout, maxResponseSize) } // Health returns information about host needed for the rpc server @@ -765,7 +758,7 @@ func (s *Service) processMessage(msg peerset.Message) { err := s.host.connect(addrInfo) if err != nil { // TODO: if error happens here outgoing (?) slot is occupied but no peer is really connected - logger.Warnf("failed to open connection for peer %s: %s", peerID, err) + logger.Debugf("failed to open connection for peer %s: %s", peerID, err) return } logger.Debugf("connection successful with peer %s", peerID) diff --git a/dot/peerset/constants.go b/dot/peerset/constants.go index 095d334ec4..c9851a94ca 100644 --- a/dot/peerset/constants.go +++ b/dot/peerset/constants.go @@ -86,4 +86,9 @@ const ( BadWarpProofValue Reputation = -(1 << 29) // BadWarpProofReason is used when peer send invalid warp sync proof. BadWarpProofReason = "Bad warp proof" + + // BadStateValue is used when peer send invalid state response. + BadStateValue Reputation = -(1 << 29) + // BadStateReason is used when peer send invalid state response. + BadStateReason = "Bad state" ) diff --git a/dot/services.go b/dot/services.go index c9dc92169d..c5068b5912 100644 --- a/dot/services.go +++ b/dot/services.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "strings" - "time" cfg "github.com/ChainSafe/gossamer/config" @@ -38,8 +37,6 @@ import ( wazero_runtime "github.com/ChainSafe/gossamer/lib/runtime/wazero" ) -const blockRequestTimeout = 20 * time.Second - // BlockProducer to produce blocks type BlockProducer interface { Pause() error @@ -524,51 +521,20 @@ func (nodeBuilder) newSyncService(config *cfg.Config, st *state.Service, fg sync return nil, fmt.Errorf("failed to parse sync log level: %w", err) } - // Should be shared between all sync strategies - peersView := sync.NewPeerViewSet() - - var warpSyncStrategy sync.Strategy - - if config.Core.Sync == "warp" { - warpSyncProvider := warpsync.NewWarpSyncProofProvider(st.Block, st.Grandpa) - - warpSyncCfg := &sync.WarpSyncConfig{ - Telemetry: telemetryMailer, - BadBlocks: genesisData.BadBlocks, - WarpSyncProvider: warpSyncProvider, - WarpSyncRequestMaker: net.GetRequestResponseProtocol(network.WarpSyncID, - blockRequestTimeout, network.MaxBlockResponseSize), - SyncRequestMaker: net.GetRequestResponseProtocol(network.SyncID, - blockRequestTimeout, network.MaxBlockResponseSize), - BlockState: st.Block, - Peers: peersView, - } - - warpSyncStrategy = sync.NewWarpSyncStrategy(warpSyncCfg) - } - - syncCfg := &sync.FullSyncConfig{ - BlockState: st.Block, - StorageState: st.Storage, - TransactionState: st.Transaction, - FinalityGadget: fg, - BabeVerifier: verifier, - BlockImportHandler: cs, - Telemetry: telemetryMailer, - BadBlocks: genesisData.BadBlocks, - RequestMaker: net.GetRequestResponseProtocol(network.SyncID, - blockRequestTimeout, network.MaxBlockResponseSize), - Peers: peersView, - } - fullSync := sync.NewFullSyncStrategy(syncCfg) - return sync.NewSyncService( syncLogLevel, sync.WithNetwork(net), sync.WithBlockState(st.Block), + sync.WithGrandpaState(st.Grandpa), + sync.WithStorageState(st.Storage), + sync.WithFinalityGadget(fg), + sync.WithBabeVerifier(verifier), + sync.WithBlockImportHandler(cs), + sync.WithTelemetry(telemetryMailer), + sync.WithBadBlocks(genesisData.BadBlocks), + sync.WithSyncMethod(config.Core.SyncMode), + sync.WithTransactionState(st.Transaction), sync.WithSlotDuration(slotDuration), - sync.WithWarpSyncStrategy(warpSyncStrategy), - sync.WithFullSyncStrategy(fullSync), sync.WithMinPeers(config.Network.MinPeers), ), nil } diff --git a/dot/sync/block_importer.go b/dot/sync/block_importer.go index 395113c8a4..f6ffe65e40 100644 --- a/dot/sync/block_importer.go +++ b/dot/sync/block_importer.go @@ -33,6 +33,7 @@ type ( // StorageState is the interface for the storage state StorageState interface { + StoreTrie(ts *rtstorage.TrieState, header *types.Header) error TrieState(root *common.Hash) (*rtstorage.TrieState, error) sync.Locker } diff --git a/dot/sync/configuration.go b/dot/sync/configuration.go index 9e01c3996d..a7d286a421 100644 --- a/dot/sync/configuration.go +++ b/dot/sync/configuration.go @@ -3,32 +3,78 @@ package sync -import "time" +import ( + "time" + + "github.com/ChainSafe/gossamer/config" +) type ServiceConfig func(svc *SyncService) -func WithWarpSyncStrategy(warpSyncStrategy Strategy) ServiceConfig { +func WithBlockState(bs BlockState) ServiceConfig { return func(svc *SyncService) { - svc.warpSyncStrategy = warpSyncStrategy + svc.blockState = bs } } -func WithFullSyncStrategy(fullSyncStrategy Strategy) ServiceConfig { +func WithGrandpaState(gs GrandpaState) ServiceConfig { return func(svc *SyncService) { - svc.fullSyncStrategy = fullSyncStrategy + svc.grandpaState = gs } } -func WithNetwork(net Network) ServiceConfig { +func WithStorageState(ss StorageState) ServiceConfig { return func(svc *SyncService) { - svc.network = net - svc.workerPool = newSyncWorkerPool(net) + svc.storageState = ss } } -func WithBlockState(bs BlockState) ServiceConfig { +func WithFinalityGadget(fg FinalityGadget) ServiceConfig { return func(svc *SyncService) { - svc.blockState = bs + svc.finalityGadget = fg + } +} + +func WithBabeVerifier(bv BabeVerifier) ServiceConfig { + return func(svc *SyncService) { + svc.babeVerifier = bv + } +} + +func WithBlockImportHandler(bih BlockImportHandler) ServiceConfig { + return func(svc *SyncService) { + svc.blockImportHandler = bih + } +} + +func WithTelemetry(t Telemetry) ServiceConfig { + return func(svc *SyncService) { + svc.telemetry = t + } +} + +func WithBadBlocks(badBlocks []string) ServiceConfig { + return func(svc *SyncService) { + svc.badBlocks = badBlocks + } +} + +func WithSyncMethod(method config.SyncMode) ServiceConfig { + return func(svc *SyncService) { + svc.syncStrategy = method + } +} + +func WithTransactionState(ts TransactionState) ServiceConfig { + return func(svc *SyncService) { + svc.transactionState = ts + } +} + +func WithNetwork(net Network) ServiceConfig { + return func(svc *SyncService) { + svc.network = net + svc.workerPool = newSyncWorkerPool(net) } } diff --git a/dot/sync/fullsync.go b/dot/sync/fullsync.go index 248f3bd67a..f5c30e457a 100644 --- a/dot/sync/fullsync.go +++ b/dot/sync/fullsync.go @@ -281,7 +281,7 @@ func (f *FullSyncStrategy) Process(results []*SyncTaskResult) ( return false, repChanges, peersToIgnore, nil } -func (f *FullSyncStrategy) ShowMetrics() { +func (f *FullSyncStrategy) ShowStatus() { totalSyncAndImportSeconds := time.Since(f.startedAt).Seconds() bps := float64(f.syncedBlocks) / totalSyncAndImportSeconds logger.Infof("⛓️ synced %d blocks, tasks on queue %d, disjoint fragments %d, incomplete blocks %d, "+ diff --git a/dot/sync/message.go b/dot/sync/message.go index 8ea81322f4..c937218f1e 100644 --- a/dot/sync/message.go +++ b/dot/sync/message.go @@ -208,13 +208,13 @@ func (s *SyncService) handleDescendingRequest(req *messages.BlockRequestMessage) } if startHash == nil || endHash == nil { - logger.Infof("handling block request message with direction %s "+ + logger.Debugf("handling block request message with direction %s "+ "from number %d to number %d\n", req.Direction.String(), startNumber, endNumber) return s.handleDescendingByNumber(startNumber, endNumber, req.RequestedData) } - logger.Infof("handling block request message with direction %s "+ + logger.Debugf("handling block request message with direction %s "+ "from hash %s to end block with hash %s", req.Direction.String(), *startHash, *endHash) return s.handleChainByHash(*endHash, *startHash, max, req.RequestedData, req.Direction) diff --git a/dot/sync/message_integration_test.go b/dot/sync/message_integration_test.go index f597c13789..9e911566ab 100644 --- a/dot/sync/message_integration_test.go +++ b/dot/sync/message_integration_test.go @@ -6,10 +6,13 @@ package sync import ( + "context" "path/filepath" "testing" "time" + cfg "github.com/ChainSafe/gossamer/config" + "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/network/messages" "github.com/ChainSafe/gossamer/dot/state" "github.com/ChainSafe/gossamer/dot/types" @@ -22,6 +25,7 @@ import ( "github.com/ChainSafe/gossamer/pkg/trie" "github.com/ChainSafe/gossamer/tests/utils/config" "github.com/libp2p/go-libp2p/core/peer" + "github.com/libp2p/go-libp2p/core/protocol" "go.uber.org/mock/gomock" rtstorage "github.com/ChainSafe/gossamer/lib/runtime/storage" @@ -135,25 +139,21 @@ func newFullSyncService(t *testing.T) *SyncService { AnyTimes() mockNetwork := NewMockNetwork(ctrl) - - fullSyncCfg := &FullSyncConfig{ - BlockState: stateSrvc.Block, - StorageState: stateSrvc.Storage, - BlockImportHandler: blockImportHandler, - TransactionState: stateSrvc.Transaction, - BabeVerifier: mockBabeVerifier, - FinalityGadget: mockFinalityGadget, - Telemetry: mockTelemetryClient, - RequestMaker: NewMockRequestMaker(ctrl), - } - - fullSync := NewFullSyncStrategy(fullSyncCfg) + mockNetwork.EXPECT().GetRequestResponseProtocol(gomock.Any(), gomock.Any(), gomock.Any()).Return( + network.NewRequestResponseProtocol( + context.Background(), + nil, + protocol.ID(network.SyncID), + 20*time.Second, + 1024*64, + ), + ).AnyTimes() serviceCfg := []ServiceConfig{ WithBlockState(stateSrvc.Block), WithNetwork(mockNetwork), WithSlotDuration(6 * time.Second), - WithFullSyncStrategy(fullSync), + WithSyncMethod(cfg.FullSync), } syncLogLvl := log.Info diff --git a/dot/sync/service.go b/dot/sync/service.go index 94e8970b65..9d33d937d6 100644 --- a/dot/sync/service.go +++ b/dot/sync/service.go @@ -9,11 +9,13 @@ import ( "sync" "time" + "github.com/ChainSafe/gossamer/config" "github.com/ChainSafe/gossamer/dot/network" "github.com/ChainSafe/gossamer/dot/peerset" "github.com/ChainSafe/gossamer/dot/types" "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/lib/grandpa/warpsync" "github.com/ChainSafe/gossamer/lib/runtime" lrucache "github.com/ChainSafe/gossamer/lib/utils/lru-cache" "github.com/libp2p/go-libp2p/core/peer" @@ -24,6 +26,7 @@ import ( const ( waitPeersDefaultTimeout = 10 * time.Second minPeersDefault = 1 + blockRequestTimeout = 20 * time.Second ) var ( @@ -43,6 +46,12 @@ const ( networkBroadcast ) +type GrandpaState interface { + GetCurrentSetID() (uint64, error) + GetAuthorities(uint64) ([]types.GrandpaVoter, error) + GetAuthoritiesChangesFromBlock(uint) ([]uint, error) +} + type Network interface { AllConnectedPeersIDs() []peer.ID ReportPeer(change peerset.ReputationChange, p peer.ID) @@ -90,20 +99,28 @@ type Strategy interface { OnBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error NextActions() ([]*SyncTask, error) Process(results []*SyncTaskResult) (done bool, repChanges []Change, blocks []peer.ID, err error) - ShowMetrics() + ShowStatus() IsSynced() bool Result() any } type SyncService struct { - mu sync.Mutex - wg sync.WaitGroup - network Network - blockState BlockState - - currentStrategy Strategy - fullSyncStrategy Strategy - warpSyncStrategy Strategy + mu sync.Mutex + wg sync.WaitGroup + network Network + blockState BlockState + grandpaState GrandpaState + storageState StorageState + transactionState TransactionState + finalityGadget FinalityGadget + babeVerifier BabeVerifier + blockImportHandler BlockImportHandler + telemetry Telemetry + badBlocks []string + peers *peerViewSet + + syncStrategy config.SyncMode + currentStrategy Strategy workerPool *syncWorkerPool waitPeersDuration time.Duration @@ -123,22 +140,61 @@ func NewSyncService(logLvl log.Level, cfgs ...ServiceConfig) *SyncService { waitPeersDuration: waitPeersDefaultTimeout, stopCh: make(chan struct{}), seenBlockSyncRequests: lrucache.NewLRUCache[common.Hash, uint](100), + peers: NewPeerViewSet(), } for _, cfg := range cfgs { cfg(svc) } - // Set initial strategy - if svc.warpSyncStrategy != nil { - svc.currentStrategy = svc.warpSyncStrategy - } else { - svc.currentStrategy = svc.fullSyncStrategy + switch svc.syncStrategy { + case config.FullSync: + svc.useFullSyncStrategy() + case config.WarpSync: + svc.useWarpSyncStrategy() } return svc } +func (s *SyncService) useWarpSyncStrategy() { + warpSyncProvider := warpsync.NewWarpSyncProofProvider(s.blockState, s.grandpaState) + + warpSyncCfg := &WarpSyncConfig{ + Telemetry: s.telemetry, + BadBlocks: s.badBlocks, + WarpSyncProvider: warpSyncProvider, + WarpSyncRequestMaker: s.network.GetRequestResponseProtocol(network.WarpSyncID, + blockRequestTimeout, network.MaxBlockResponseSize), + SyncRequestMaker: s.network.GetRequestResponseProtocol(network.SyncID, + blockRequestTimeout, network.MaxBlockResponseSize), + BlockState: s.blockState, + Peers: s.peers, + } + + s.currentStrategy = NewWarpSyncStrategy(warpSyncCfg) + s.syncStrategy = config.WarpSync +} + +func (s *SyncService) useFullSyncStrategy() { + syncCfg := &FullSyncConfig{ + BlockState: s.blockState, + StorageState: s.storageState, + TransactionState: s.transactionState, + FinalityGadget: s.finalityGadget, + BabeVerifier: s.babeVerifier, + BlockImportHandler: s.blockImportHandler, + Telemetry: s.telemetry, + BadBlocks: s.badBlocks, + RequestMaker: s.network.GetRequestResponseProtocol(network.SyncID, + blockRequestTimeout, network.MaxBlockResponseSize), + Peers: s.peers, + } + + s.currentStrategy = NewFullSyncStrategy(syncCfg) + s.syncStrategy = config.FullSync +} + func (s *SyncService) waitWorkers() { bestBlockHeader, err := s.blockState.BestBlockHeader() if err != nil { @@ -233,7 +289,7 @@ func (s *SyncService) runSyncEngine() { defer s.wg.Done() s.waitWorkers() - logger.Infof("starting sync engine with strategy: %T", s.currentStrategy) + logger.Infof("starting sync engine with strategy: %s", s.syncStrategy) for { select { @@ -256,28 +312,7 @@ func (s *SyncService) runStrategy() { s.mu.Lock() defer s.mu.Unlock() - logger.Tracef("running strategy: %T", s.currentStrategy) - - finalisedHeader, err := s.blockState.GetHighestFinalisedHeader() - if err != nil { - logger.Criticalf("getting highest finalized header: %w", err) - return - } - - bestBlockHeader, err := s.blockState.BestBlockHeader() - if err != nil { - logger.Criticalf("getting best block header: %w", err) - return - } - - logger.Infof( - "🚣 currently syncing, %d peers connected, finalized #%d (%s), best #%d (%s)", - len(s.network.AllConnectedPeersIDs()), - finalisedHeader.Number, - finalisedHeader.Hash().Short(), - bestBlockHeader.Number, - bestBlockHeader.Hash().Short(), - ) + logger.Tracef("running strategy: %s", s.syncStrategy) tasks, err := s.currentStrategy.NextActions() if err != nil { @@ -306,13 +341,34 @@ func (s *SyncService) runStrategy() { s.workerPool.ignorePeerAsWorker(block) } - s.currentStrategy.ShowMetrics() + s.currentStrategy.ShowStatus() // TODO: why not use s.currentStrategy.IsSynced()? if done { - // Switch to full sync when warp sync finishes - if s.warpSyncStrategy != nil { - s.currentStrategy = s.fullSyncStrategy + switch s.syncStrategy { + case config.WarpSync: + logger.Info("Switching sync strategy: warp sync -> state sync") + // Switch to state sync when warp sync finishes + stateSyncCfg := &StateSyncStrategyConfig{ + Telemetry: s.telemetry, + BadBlocks: s.badBlocks, + BlockState: s.blockState, + Peers: s.peers, + ReqMaker: s.network.GetRequestResponseProtocol(network.StateSyncID, + blockRequestTimeout, network.MaxBlockResponseSize), + StateStorage: s.storageState, + TargetBlock: s.currentStrategy.Result().(types.Header), + } + + s.currentStrategy = NewStateSyncStrategy(stateSyncCfg) + s.syncStrategy = config.StateSync + + case config.StateSync: + logger.Info("Switching sync strategy: state sync -> full sync") + // Switch to full sync when state sync finishes + s.useFullSyncStrategy() + case config.FullSync: + logger.Errorf("Full sync strategy should not finish") } } } diff --git a/dot/sync/state_request_provider.go b/dot/sync/state_request_provider.go new file mode 100644 index 0000000000..1c53983ea1 --- /dev/null +++ b/dot/sync/state_request_provider.go @@ -0,0 +1,100 @@ +package sync + +import ( + "errors" + + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/ChainSafe/gossamer/pkg/trie/inmemory" +) + +var ( + ErrEmptyStateEntries = errors.New("empty state entries") +) + +type StateRequestProvider struct { + lastKeys [][]byte + collectedResponses []*messages.StateResponse + targetHash common.Hash + completed bool + stateVersion trie.TrieLayout +} + +func NewStateRequestProvider(target common.Hash, stateVersion trie.TrieLayout) *StateRequestProvider { + return &StateRequestProvider{ + lastKeys: [][]byte{}, + targetHash: target, + collectedResponses: make([]*messages.StateResponse, 0), + stateVersion: stateVersion, + } +} + +func (s *StateRequestProvider) BuildRequest() *messages.StateRequest { + return &messages.StateRequest{ + Block: s.targetHash, + Start: s.lastKeys, + NoProof: true, + } +} + +func (s *StateRequestProvider) ProcessResponse(stateResponse *messages.StateResponse) (completed bool, err error) { + // TODO: handle merkle proofs to prevent accepting invalid state entries. Issue: #4481 + + if len(stateResponse.Entries) == 0 { + return false, ErrEmptyStateEntries + } + + logger.Debugf("retrieved %d entries", len(stateResponse.Entries)) + for idx, entry := range stateResponse.Entries { + logger.Debugf("#%d with %d entries (complete: %v, root: %s)", + idx, len(entry.StateEntries), entry.Complete, entry.StateRoot.String()) + } + + s.collectedResponses = append(s.collectedResponses, stateResponse) + + if len(s.lastKeys) == 2 && len(stateResponse.Entries[0].StateEntries) == 0 { + // pop last item and keep the first + // do not remove the parent trie position. + s.lastKeys = s.lastKeys[:len(s.lastKeys)-1] + } else { + s.lastKeys = [][]byte{} + } + + s.completed = true + + for _, state := range stateResponse.Entries { + if !state.Complete { + lastItemInResponse := state.StateEntries[len(state.StateEntries)-1] + s.lastKeys = append(s.lastKeys, lastItemInResponse.Key) + s.completed = false + } + } + + return s.completed, nil +} + +func (s *StateRequestProvider) BuildTrie() (trie.Trie, error) { + stateTrie := inmemory.NewEmptyTrie() + stateTrie.SetVersion(s.stateVersion) + + for _, stateResponse := range s.collectedResponses { + for _, stateEntry := range stateResponse.Entries { + for _, kv := range stateEntry.StateEntries { + if err := stateTrie.Put(kv.Key, kv.Value); err != nil { + return nil, err + } + } + } + } + + return stateTrie, nil +} + +func (s *StateRequestProvider) GetLastKeys() [][]byte { + return s.lastKeys +} + +func (s *StateRequestProvider) IsCompleted() bool { + return s.completed +} diff --git a/dot/sync/state_request_provider_test.go b/dot/sync/state_request_provider_test.go new file mode 100644 index 0000000000..b21e1fcde1 --- /dev/null +++ b/dot/sync/state_request_provider_test.go @@ -0,0 +1,125 @@ +package sync + +import ( + "testing" + + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/ChainSafe/gossamer/pkg/trie/inmemory" + "github.com/stretchr/testify/require" +) + +func TestNewStateRequestProvider(t *testing.T) { + targetHash := common.Hash{0x12, 0x23, 0x34} + provider := NewStateRequestProvider(targetHash, trie.V0) + + require.NotNil(t, provider) + + require.Equal(t, [][]byte{}, provider.GetLastKeys()) + require.Equal(t, false, provider.IsCompleted()) +} + +func TestBuildRequest(t *testing.T) { + targetHash := common.Hash{0x12, 0x23, 0x34} + provider := NewStateRequestProvider(targetHash, trie.V0) + + req := provider.BuildRequest() + + require.NotNil(t, req) + require.Equal(t, targetHash, req.Block) + require.Equal(t, [][]byte{}, req.Start) + require.True(t, req.NoProof) +} + +func TestProcessResponse(t *testing.T) { + tc := map[string]struct { + response *messages.StateResponse + completed bool + expectedError error + }{ + "empty_response_entries": { + response: &messages.StateResponse{}, + completed: false, + expectedError: ErrEmptyStateEntries, + }, + "uncomplete_successful_response": { + response: &messages.StateResponse{ + Entries: []messages.KeyValueStateEntry{ + { + StateRoot: common.Hash{0x34, 0x45, 0x56}, + StateEntries: trie.Entries{ + {Key: []byte{0x01}, Value: []byte{0x11}}, + {Key: []byte{0x02}, Value: []byte{0x22}}, + }, + Complete: false, + }, + }, + }, + completed: false, + expectedError: nil, + }, + "complete_successful_response": { + response: &messages.StateResponse{ + Entries: []messages.KeyValueStateEntry{ + { + StateRoot: common.Hash{0x34, 0x45, 0x56}, + StateEntries: trie.Entries{ + {Key: []byte{0x01}, Value: []byte{0x11}}, + {Key: []byte{0x02}, Value: []byte{0x22}}, + }, + Complete: true, + }, + }, + }, + completed: true, + expectedError: nil, + }, + } + + for name, test := range tc { + t.Run(name, func(t *testing.T) { + targetHash := common.Hash{0x12, 0x23, 0x34} + provider := NewStateRequestProvider(targetHash, trie.V0) + + completed, err := provider.ProcessResponse(test.response) + + require.Equal(t, test.completed, completed) + require.Equal(t, test.expectedError, err) + }) + } +} + +func TestBuildTrie(t *testing.T) { + targetHash := common.Hash{0x12, 0x23, 0x34} + provider := NewStateRequestProvider(targetHash, trie.V0) + + trieEntries := trie.Entries{ + {Key: []byte("triekey1"), Value: []byte("trievalue1")}, + {Key: []byte("triekey2"), Value: []byte("trievalue2")}, + {Key: []byte("triekey3"), Value: []byte("trievalue3")}, + {Key: []byte("triekey4"), Value: []byte("trievalue4")}, + } + + expectedTrie := inmemory.NewEmptyTrie() + for _, entry := range trieEntries { + expectedTrie.Put(entry.Key, entry.Value) + } + + response := &messages.StateResponse{ + Entries: []messages.KeyValueStateEntry{ + { + StateRoot: common.Hash{0x34, 0x45, 0x56}, + StateEntries: trieEntries, + Complete: true, + }, + }, + } + + _, err := provider.ProcessResponse(response) + require.Nil(t, err) + + resultTrie, err := provider.BuildTrie() + require.Nil(t, err) + require.Equal(t, expectedTrie.MustHash(), resultTrie.MustHash()) +} diff --git a/dot/sync/state_sync.go b/dot/sync/state_sync.go new file mode 100644 index 0000000000..56eca90d1d --- /dev/null +++ b/dot/sync/state_sync.go @@ -0,0 +1,209 @@ +package sync + +import ( + "fmt" + "slices" + "time" + + "github.com/ChainSafe/gossamer/dot/network" + "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/dot/peerset" + "github.com/ChainSafe/gossamer/dot/types" + "github.com/ChainSafe/gossamer/lib/runtime/storage" + "github.com/ChainSafe/gossamer/pkg/trie" + "github.com/libp2p/go-libp2p/core/peer" +) + +type StateStorage interface { + StoreTrie(ts *storage.TrieState, header *types.Header) error +} + +type StateSyncStrategy struct { + // Strategy dependencies and config + peers *peerViewSet + badBlocks []string + reqMaker network.RequestMaker + blockState BlockState + storage StateStorage + stateRequestProvider *StateRequestProvider + + // State sync state + startedAt time.Time + targetBlock types.Header + completed bool +} + +type StateSyncStrategyConfig struct { + Telemetry Telemetry + BadBlocks []string + BlockState BlockState + Peers *peerViewSet + ReqMaker network.RequestMaker + TargetBlock types.Header + StateStorage StateStorage +} + +func NewStateSyncStrategy( + cfg *StateSyncStrategyConfig, +) *StateSyncStrategy { + return &StateSyncStrategy{ + peers: cfg.Peers, + badBlocks: cfg.BadBlocks, + blockState: cfg.BlockState, + targetBlock: cfg.TargetBlock, + reqMaker: cfg.ReqMaker, + storage: cfg.StateStorage, + // TODO: we can assume that v1 is right for every chain but we need to find a way to set the right state version + stateRequestProvider: NewStateRequestProvider(cfg.TargetBlock.Hash(), trie.V1), + } +} + +// OnBlockAnnounce on every new block announce received. +// we are going to only update the peerset reputation and peers target block. +func (s *StateSyncStrategy) OnBlockAnnounce(from peer.ID, msg *network.BlockAnnounceMessage) ( + repChange *Change, err error) { + + blockAnnounceHeaderHash, err := msg.Hash() + if err != nil { + return nil, err + } + + logger.Debugf("received block announce from %s: #%d (%s) best block: %v", + from, + msg.Number, + blockAnnounceHeaderHash, + msg.BestBlock, + ) + + if slices.Contains(s.badBlocks, blockAnnounceHeaderHash.String()) { + logger.Debugf("bad block received from %s: #%d (%s) is a bad block", + from, msg.Number, blockAnnounceHeaderHash) + + return &Change{ + who: from, + rep: peerset.ReputationChange{ + Value: peerset.BadBlockAnnouncementValue, + Reason: peerset.BadBlockAnnouncementReason, + }, + }, errBadBlockReceived + } + + if msg.BestBlock { + s.peers.update(from, blockAnnounceHeaderHash, uint32(msg.Number)) + } + + return &Change{ + who: from, + rep: peerset.ReputationChange{ + Value: peerset.GossipSuccessValue, + Reason: peerset.GossipSuccessReason, + }, + }, nil +} + +func (s *StateSyncStrategy) OnBlockAnnounceHandshake(from peer.ID, msg *network.BlockAnnounceHandshake) error { + s.peers.update(from, msg.BestBlockHash, msg.BestBlockNumber) + return nil +} + +func (s *StateSyncStrategy) Process(results []*SyncTaskResult) ( + done bool, repChanges []Change, peersToBlock []peer.ID, err error) { + repChanges = make([]Change, 0) + peersToBlock = make([]peer.ID, 0) + + for _, result := range results { + switch response := result.response.(type) { + case *messages.StateResponse: + logger.Debugf("Retrieving state data from %s with %d keys", + result.who, len(response.Entries)) + + s.completed, err = s.stateRequestProvider.ProcessResponse(response) + if err != nil { + switch err { + case ErrEmptyStateEntries: + logger.Infof("Bad state response") + peersToBlock = append(peersToBlock, result.who) + repChanges = append(repChanges, Change{ + who: result.who, + rep: peerset.ReputationChange{ + Value: peerset.BadStateValue, + Reason: peerset.BadStateReason, + }, + }) + continue + } + } + + if s.completed { + return true, repChanges, peersToBlock, s.importState() + } + + return false, repChanges, peersToBlock, nil + + default: + logger.Warnf("unexpected response type %T for state request, banning peer: %s", result.who, response) + repChanges = append(repChanges, Change{ + who: result.who, + rep: peerset.ReputationChange{ + Value: peerset.UnexpectedResponseValue, + Reason: peerset.UnexpectedResponseReason, + }}) + peersToBlock = append(peersToBlock, result.who) + continue + } + } + + return s.IsSynced(), repChanges, peersToBlock, nil +} + +// importState imports the retreived state into our state storage +func (s *StateSyncStrategy) importState() error { + // Store state in our state storage + trieState, err := s.stateRequestProvider.BuildTrie() + if err != nil { + return err + } + + if trieState.MustHash() != s.targetBlock.StateRoot { + return fmt.Errorf("state root mismatch: got %s expected %s", trieState.MustHash(), s.targetBlock.StateRoot) + } + + storageTrie := storage.NewTrieState(trieState) + return s.storage.StoreTrie(storageTrie, &s.targetBlock) +} + +// NextActions returns the next actions to be taken by the sync service +func (s *StateSyncStrategy) NextActions() ([]*SyncTask, error) { + s.startedAt = time.Now() + + task := &SyncTask{ + request: s.stateRequestProvider.BuildRequest(), + response: &messages.StateResponse{}, + requestMaker: s.reqMaker, + } + + return []*SyncTask{task}, nil +} + +func (s *StateSyncStrategy) ShowStatus() { + if len(s.stateRequestProvider.GetLastKeys()) == 0 { + return + } + + lastKey := s.stateRequestProvider.GetLastKeys()[0] + cursor := float32(lastKey[0]) + percentDone := cursor / 256 * 100 + + logger.Infof("⚙️ State Sync, downloading state %.2f%%, last key: 0x%x", percentDone, lastKey) +} + +func (w *StateSyncStrategy) Result() any { + logger.Debug("unexpected call to Result() in StateSyncStrategy") + return nil +} + +func (s *StateSyncStrategy) IsSynced() bool { + return s.completed +} + +var _ Strategy = (*StateSyncStrategy)(nil) diff --git a/dot/sync/warp_sync.go b/dot/sync/warp_sync.go index b886b5e817..93078a9f47 100644 --- a/dot/sync/warp_sync.go +++ b/dot/sync/warp_sync.go @@ -47,7 +47,7 @@ type WarpSyncStrategy struct { setId primitives.SetID authorities primitives.AuthorityList lastBlock *types.Header - result types.BlockData + result types.Header } type WarpSyncConfig struct { @@ -201,7 +201,7 @@ func (w *WarpSyncStrategy) Process(results []*SyncTaskResult) ( repChanges, bans, validRes = validateResults(results, w.badBlocks) if len(validRes) > 0 && validRes[0].responseData != nil && len(validRes[0].responseData) > 0 { - w.result = *validRes[0].responseData[0] + w.result = *validRes[0].responseData[0].Header w.phase = Completed } } @@ -234,7 +234,7 @@ func (w *WarpSyncStrategy) validateWarpSyncResults(results []*SyncTaskResult) ( res, err := w.warpSyncProvider.Verify(encodedProof, w.setId, w.authorities) if err != nil { - logger.Warnf("bad warp proof response: %s", err) + logger.Debugf("bad warp proof response: %s", err) repChanges = append(repChanges, Change{ who: result.who, @@ -267,7 +267,7 @@ func (w *WarpSyncStrategy) validateWarpSyncResults(results []*SyncTaskResult) ( return repChanges, peersToBlock, bestResult } -func (w *WarpSyncStrategy) ShowMetrics() { +func (w *WarpSyncStrategy) ShowStatus() { switch w.phase { case WarpProof: totalSyncSeconds := time.Since(w.startedAt).Seconds() @@ -279,7 +279,6 @@ func (w *WarpSyncStrategy) ShowMetrics() { logger.Infof("⏩ Warping, downloading target block #%d (%s)", w.lastBlock.Number, w.lastBlock.Hash().String()) } - } func (w *WarpSyncStrategy) IsSynced() bool { diff --git a/dot/sync/worker_pool.go b/dot/sync/worker_pool.go index b11b726db7..b776693f24 100644 --- a/dot/sync/worker_pool.go +++ b/dot/sync/worker_pool.go @@ -149,14 +149,14 @@ func (s *syncWorkerPool) submitRequests(tasks []*SyncTask) []*SyncTaskResult { func executeTask(task *SyncTask, workerPool chan peer.ID, failedTasks chan *SyncTask, results chan *SyncTaskResult) { worker := <-workerPool - logger.Infof("[EXECUTING] worker %s", worker) + logger.Debugf("[EXECUTING] worker %s", worker) err := task.requestMaker.Do(worker, task.request, task.response) if err != nil { - logger.Infof("[ERR] worker %s, request: %s, err: %s", worker, task.request.String(), err.Error()) + logger.Debugf("[ERR] worker %s, request: %s, err: %s", worker, task.request.String(), err.Error()) failedTasks <- task } else { - logger.Infof("[FINISHED] worker %s, request: %s", worker, task.request.String()) + logger.Debugf("[FINISHED] worker %s, request: %s", worker, task.request.String()) workerPool <- worker results <- &SyncTaskResult{ who: worker, diff --git a/internal/primitives/trie/proof.go b/internal/primitives/trie/proof.go new file mode 100644 index 0000000000..216bf37be6 --- /dev/null +++ b/internal/primitives/trie/proof.go @@ -0,0 +1,5 @@ +package trie + +type CompactProof struct { + EncodedNodes [][]byte +} diff --git a/scripts/retrieve_state/retrieve_state.go b/scripts/retrieve_state/retrieve_state.go index 3ac70afdd3..14f89fc9c3 100644 --- a/scripts/retrieve_state/retrieve_state.go +++ b/scripts/retrieve_state/retrieve_state.go @@ -14,10 +14,9 @@ import ( "os" "github.com/ChainSafe/gossamer/dot/network/messages" + "github.com/ChainSafe/gossamer/dot/sync" "github.com/ChainSafe/gossamer/lib/common" - "github.com/ChainSafe/gossamer/pkg/scale" "github.com/ChainSafe/gossamer/pkg/trie" - "github.com/ChainSafe/gossamer/pkg/trie/inmemory" "github.com/ChainSafe/gossamer/scripts/p2p" lip2pnetwork "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" @@ -26,7 +25,6 @@ import ( var ( errZeroLengthResponse = errors.New("zero length response") - errEmptyStateEntries = errors.New("empty state entries") supportedVersions = map[string]trie.TrieLayout{ "v0": trie.V0, @@ -34,104 +32,6 @@ var ( } ) -type StateRequestProvider struct { - lastKeys [][]byte - collectedResponses []*messages.StateResponse - targetHash common.Hash - completed bool -} - -func NewStateRequestProvider(target common.Hash) *StateRequestProvider { - return &StateRequestProvider{ - lastKeys: [][]byte{}, - targetHash: target, - collectedResponses: make([]*messages.StateResponse, 0), - } -} - -func (s *StateRequestProvider) buildRequest() *messages.StateRequest { - return &messages.StateRequest{ - Block: s.targetHash, - Start: s.lastKeys, - NoProof: true, - } -} - -func (s *StateRequestProvider) processResponse(stateResponse *messages.StateResponse) (err error) { - if len(stateResponse.Entries) == 0 { - return errEmptyStateEntries - } - - log.Printf("retrieved %d entries\n", len(stateResponse.Entries)) - for idx, entry := range stateResponse.Entries { - log.Printf("\t#%d with %d entries (complete: %v, root: %s)\n", - idx, len(entry.StateEntries), entry.Complete, entry.StateRoot.String()) - } - - s.collectedResponses = append(s.collectedResponses, stateResponse) - - if len(s.lastKeys) == 2 && len(stateResponse.Entries[0].StateEntries) == 0 { - // pop last item and keep the first - // do not remove the parent trie position. - s.lastKeys = s.lastKeys[:len(s.lastKeys)-1] - } else { - s.lastKeys = [][]byte{} - } - - for _, state := range stateResponse.Entries { - if !state.Complete { - lastItemInResponse := state.StateEntries[len(state.StateEntries)-1] - s.lastKeys = append(s.lastKeys, lastItemInResponse.Key) - s.completed = false - } else { - s.completed = true - } - } - - return nil -} - -func (s *StateRequestProvider) buildTrie(expectedStorageRootHash common.Hash, - destination string, v trie.TrieLayout) error { - tt := inmemory.NewEmptyTrie() - tt.SetVersion(v) - - entries := make([]string, 0) - - for _, stateResponse := range s.collectedResponses { - for _, stateEntry := range stateResponse.Entries { - for _, kv := range stateEntry.StateEntries { - - trieEntry := trie.Entry{Key: kv.Key, Value: kv.Value} - encodedTrieEntry, err := scale.Marshal(trieEntry) - if err != nil { - return err - } - entries = append(entries, common.BytesToHex(encodedTrieEntry)) - - if err := tt.Put(kv.Key, kv.Value); err != nil { - return err - } - } - } - } - - rootHash := tt.MustHash() - if expectedStorageRootHash != rootHash { - log.Printf("\n\texpected root hash: %s\ngot root hash: %s\n", - expectedStorageRootHash.String(), rootHash.String()) - } - - fmt.Printf("=> trie root hash: %s\n", tt.MustHash().String()) - encodedEntries, err := json.Marshal(entries) - if err != nil { - return err - } - - err = os.WriteFile(destination, encodedEntries, 0o600) - return err -} - func main() { if len(os.Args) != 6 { log.Fatalf(`script usage: @@ -154,14 +54,14 @@ func main() { p2pHost := p2p.SetupP2PClient() bootnodes := p2p.ParseBootnodes(chain.Bootnodes) - provider := NewStateRequestProvider(targetBlockHash) + provider := sync.NewStateRequestProvider(targetBlockHash, version) var ( pid peer.AddrInfo refreshPeerID bool = true ) - for !provider.completed { + for !provider.IsCompleted() { if refreshPeerID { rng, err := rand.Int(rand.Reader, big.NewInt(int64(len(bootnodes)))) if err != nil { @@ -196,15 +96,33 @@ func main() { refreshPeerID = false } - if err := provider.buildTrie(expectedStorageRootHash, os.Args[5], version); err != nil { + trie, err := provider.BuildTrie() + if err != nil { + panic(err) + } + + if trie.MustHash() != expectedStorageRootHash { + panic(fmt.Errorf("ERR: state root mismatch: got %s expected %s", trie.MustHash(), expectedStorageRootHash)) + } + + encodedEntries, err := json.Marshal(trie.Entries()) + if err != nil { + panic(err) + } + + err = os.WriteFile(os.Args[5], encodedEntries, 0o600) + if err != nil { panic(err) } } -func sendAndProcessResponse(provider *StateRequestProvider, stream lip2pnetwork.Stream) error { +func sendAndProcessResponse(provider *sync.StateRequestProvider, stream lip2pnetwork.Stream) error { defer stream.Close() //nolint:errcheck - err := p2p.WriteStream(provider.buildRequest(), stream) + request := provider.BuildRequest() + + fmt.Printf("Sending request to peer %s\n", request.String()) + err := p2p.WriteStream(provider.BuildRequest(), stream) if err != nil { return err } @@ -224,7 +142,7 @@ func sendAndProcessResponse(provider *StateRequestProvider, stream lip2pnetwork. return err } - err = provider.processResponse(stateResponse) + _, err = provider.ProcessResponse(stateResponse) if err != nil { return err }