diff --git a/src/rpc/blockchain.cpp b/src/rpc/blockchain.cpp index bd5deedf6ee..db8d922a2a7 100644 --- a/src/rpc/blockchain.cpp +++ b/src/rpc/blockchain.cpp @@ -3001,13 +3001,64 @@ class TemporaryRollback { ChainstateManager& m_chainman; const CBlockIndex& m_invalidate_index; + std::vector m_invalidated_fork_blocks; + public: TemporaryRollback(ChainstateManager& chainman, const CBlockIndex& index) : m_chainman(chainman), m_invalidate_index(index) { + // First, invalidate any competing fork blocks to prevent reorg during main chain invalidation + InvalidateCompetingForks(); + + // Then invalidate the main chain block for rollback InvalidateBlock(m_chainman, m_invalidate_index.GetBlockHash()); }; + ~TemporaryRollback() { + // Restore main chain block first ReconsiderBlock(m_chainman, m_invalidate_index.GetBlockHash()); + + // Then restore all fork blocks + for (const uint256& fork_hash : m_invalidated_fork_blocks) { + ReconsiderBlock(m_chainman, fork_hash); + } }; + +private: + void InvalidateCompetingForks() { + LOCK(m_chainman.GetMutex()); + + // Find the target height (the height we want to roll back to) + const CBlockIndex* target_index = m_invalidate_index.pprev; + if (!target_index) return; // Genesis block case + + const int target_height = target_index->nHeight; + + // Iterate through all known block indices to find competing forks + for (const auto& [hash, block_index] : m_chainman.m_blockman.m_block_index) { + // Skip if this block is on the active chain + if (m_chainman.ActiveChain().Contains(&block_index)) continue; + + // Skip if this block is at or below the target height + if (block_index.nHeight <= target_height) continue; + + // Skip if this block doesn't have valid data + if (!(block_index.nStatus & BLOCK_HAVE_DATA)) continue; + + // Check if this fork block could interfere with rollback + // by tracing back to see if it forks at or after the target height + const CBlockIndex* fork_ancestor = &block_index; + while (fork_ancestor && fork_ancestor->nHeight > target_height) { + fork_ancestor = fork_ancestor->pprev; + } + + // If we can trace this fork back to the target height or below, + // and it's not on the active chain, it's a competing fork + if (fork_ancestor && fork_ancestor->nHeight <= target_height) { + // Invalidate this fork block to prevent reorg + InvalidateBlock(m_chainman, hash); + m_invalidated_fork_blocks.push_back(hash); + } + } + } }; /** diff --git a/test/functional/rpc_dumptxoutset_forks.py b/test/functional/rpc_dumptxoutset_forks.py index ed6ea47cb52..ef6b30cc6c9 100644 --- a/test/functional/rpc_dumptxoutset_forks.py +++ b/test/functional/rpc_dumptxoutset_forks.py @@ -74,21 +74,21 @@ class DumptxoutsetForksTest(BitcoinTestFramework): assert_equal(active_tip['height'], 18) return active_tip, fork_tips - def test_rollback_with_forks(self, target_height): - """Test that dumptxoutset rollback fails when competing forks are present.""" + def test_rollback_with_forks(self, target_height, target_hash): + """Test that dumptxoutset rollback works correctly even when competing forks are present.""" self.log.info("Testing dumptxoutset rollback with competing forks present") original_tip = self.nodes[0].getbestblockhash() original_height = self.nodes[0].getblockcount() - assert_raises_rpc_error( - -1, - "Could not roll back to requested height", - self.nodes[0].dumptxoutset, - "fork_test_utxo.dat", - rollback=target_height - ) + # This should now work correctly with our fix + result = self.nodes[0].dumptxoutset("fork_test_utxo.dat", rollback=target_height) + # Verify the snapshot was created from the correct block on the main chain + assert_equal(result['base_height'], target_height) + assert_equal(result['base_hash'], target_hash) + + # Verify node state is restored after successful rollback current_tip = self.nodes[0].getbestblockhash() current_height = self.nodes[0].getblockcount() assert_equal(current_tip, original_tip) @@ -108,7 +108,7 @@ class DumptxoutsetForksTest(BitcoinTestFramework): self.verify_fork_visibility() # Test the main functionality - self.test_rollback_with_forks(target_height) + self.test_rollback_with_forks(target_height, target_hash) if __name__ == '__main__':