diff --git a/src/query/storages/common/table_meta/src/meta/v4/snapshot.rs b/src/query/storages/common/table_meta/src/meta/v4/snapshot.rs index d4f71de01e3d0..4f6ea9c3306c6 100644 --- a/src/query/storages/common/table_meta/src/meta/v4/snapshot.rs +++ b/src/query/storages/common/table_meta/src/meta/v4/snapshot.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashSet; use std::io::Cursor; use std::io::Read; use std::sync::Arc; @@ -143,6 +144,8 @@ impl TableSnapshot { return Err(ErrorCode::TransactionTimeout(err_msg)); } + ensure_segments_unique(&segments)?; + Ok(Self { format_version: TableSnapshot::VERSION, snapshot_id: uuid_from_date_time(snapshot_timestamp_adjusted), @@ -244,8 +247,11 @@ impl TableSnapshot { let compression = MetaCompression::try_from(r.read_scalar::()?)?; let snapshot_size: u64 = r.read_scalar::()?; - read_and_deserialize(&mut r, snapshot_size, &encoding, &compression) - .map_err(|x| x.add_message("fail to deserialize table snapshot")) + let snapshot: TableSnapshot = + read_and_deserialize(&mut r, snapshot_size, &encoding, &compression) + .map_err(|x| x.add_message("fail to deserialize table snapshot"))?; + snapshot.ensure_segments_unique()?; + Ok(snapshot) } #[inline] @@ -257,11 +263,36 @@ impl TableSnapshot { pub fn table_statistics_location(&self) -> Option { self.table_statistics_location.clone() } + + #[inline] + pub fn ensure_segments_unique(&self) -> Result<()> { + ensure_segments_unique(&self.segments) + } +} + +fn ensure_segments_unique(segments: &[Location]) -> Result<()> { + if segments.len() < 2 { + return Ok(()); + } + + let mut seen = HashSet::with_capacity(segments.len()); + for loc in segments { + let key = loc.0.as_str(); + if !seen.insert(key) { + log::warn!( + "duplicate segment location {} detected while constructing snapshot", + key + ); + } + } + Ok(()) } // use the chain of converters, for versions before v3 impl From for TableSnapshot { fn from(s: v2::TableSnapshot) -> Self { + ensure_segments_unique(&s.segments) + .expect("duplicate segment location found while converting snapshot from v2"); Self { // NOTE: it is important to let the format_version return from here // carries the format_version of snapshot being converted. @@ -284,6 +315,8 @@ where T: Into { fn from(s: T) -> Self { let s: v3::TableSnapshot = s.into(); + ensure_segments_unique(&s.segments) + .expect("duplicate segment location found while converting snapshot from v3"); Self { // NOTE: it is important to let the format_version return from here // carries the format_version of snapshot being converted. diff --git a/src/query/storages/fuse/src/operations/common/processors/multi_table_insert_commit.rs b/src/query/storages/fuse/src/operations/common/processors/multi_table_insert_commit.rs index 424c2d97151f7..36b4991897c1c 100644 --- a/src/query/storages/fuse/src/operations/common/processors/multi_table_insert_commit.rs +++ b/src/query/storages/fuse/src/operations/common/processors/multi_table_insert_commit.rs @@ -280,6 +280,7 @@ async fn build_update_table_meta_req( table_meta_timestamps, table_stats_gen, )?; + snapshot.ensure_segments_unique()?; // write snapshot let dal = fuse_table.get_operator(); diff --git a/src/query/storages/fuse/src/operations/common/processors/sink_commit.rs b/src/query/storages/fuse/src/operations/common/processors/sink_commit.rs index b005f561a08c5..12322bf93c46e 100644 --- a/src/query/storages/fuse/src/operations/common/processors/sink_commit.rs +++ b/src/query/storages/fuse/src/operations/common/processors/sink_commit.rs @@ -527,6 +527,7 @@ where F: SnapshotGenerator + Send + Sync + 'static snapshot, table_info, } => { + snapshot.ensure_segments_unique()?; let location = self .location_gen .snapshot_location_from_uuid(&snapshot.snapshot_id, TableSnapshot::VERSION)?; diff --git a/src/query/storages/fuse/src/retry/commit.rs b/src/query/storages/fuse/src/retry/commit.rs index e61e774b45ae6..1833d1f1a1283 100644 --- a/src/query/storages/fuse/src/retry/commit.rs +++ b/src/query/storages/fuse/src/retry/commit.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; @@ -39,6 +40,8 @@ use crate::statistics::merge_statistics; use crate::statistics::reducers::deduct_statistics; use crate::FuseTable; +const FUSE_ENGINE: &str = "FUSE"; + pub async fn commit_with_backoff( ctx: Arc, mut req: UpdateMultiTableMetaReq, @@ -47,6 +50,13 @@ pub async fn commit_with_backoff( let mut backoff = set_backoff(None, None, None); let mut retries = 0; + // Compute segments diff for all tables before entering the retry loop. + // This diff represents the actual changes made by the transaction (base -> txn_generated), + // and remains constant across all retries. + // Also cache the original snapshots for statistics merging. + let (table_segments_diffs, table_original_snapshots) = + compute_table_segments_diffs(ctx.clone(), &req).await?; + loop { let ret = catalog .retryable_update_multi_table_meta(req.clone()) @@ -63,14 +73,88 @@ pub async fn commit_with_backoff( }; sleep(duration).await; retries += 1; - try_rebuild_req(ctx.clone(), &mut req, update_failed_tbls).await?; + try_rebuild_req( + ctx.clone(), + &mut req, + update_failed_tbls, + &table_segments_diffs, + &table_original_snapshots, + ) + .await?; } } +async fn compute_table_segments_diffs( + ctx: Arc, + req: &UpdateMultiTableMetaReq, +) -> Result<( + HashMap, + HashMap>>, +)> { + let txn_mgr = ctx.txn_mgr(); + let storage_class = ctx.get_settings().get_s3_storage_class()?; + let mut table_segments_diffs = HashMap::new(); + let mut table_original_snapshots = HashMap::new(); + + for (update_table_meta_req, _) in &req.update_table_metas { + let tid = update_table_meta_req.table_id; + let engine = update_table_meta_req.new_table_meta.engine.as_str(); + + if engine != FUSE_ENGINE { + log::info!( + "Skipping segments diff pre-compute for table {} with engine {}", + tid, + engine + ); + continue; + } + + // Read the base snapshot (snapshot at transaction begin) + let base_snapshot_location = txn_mgr.lock().get_base_snapshot_location(tid); + + // Read the transaction-generated snapshot (original snapshot before any merge) + let new_table = FuseTable::from_table_meta( + update_table_meta_req.table_id, + 0, + update_table_meta_req.new_table_meta.clone(), + storage_class, + )?; + + let base_snapshot = new_table + .read_table_snapshot_with_location(base_snapshot_location) + .await?; + let new_snapshot = new_table.read_table_snapshot().await?; + + let base_segments = base_snapshot + .as_ref() + .map(|s| s.segments.as_slice()) + .unwrap_or(&[]); + let new_segments = new_snapshot + .as_ref() + .map(|s| s.segments.as_slice()) + .unwrap_or(&[]); + + info!( + "Computing segments diff for table {} (base: {} segments, txn: {} segments)", + tid, + base_segments.len(), + new_segments.len() + ); + + let diff = SegmentsDiff::new(base_segments, new_segments); + table_segments_diffs.insert(tid, diff); + table_original_snapshots.insert(tid, new_snapshot); + } + + Ok((table_segments_diffs, table_original_snapshots)) +} + async fn try_rebuild_req( ctx: Arc, req: &mut UpdateMultiTableMetaReq, update_failed_tbls: Vec<(u64, u64, TableMeta)>, + table_segments_diffs: &HashMap, + table_original_snapshots: &HashMap>>, ) -> Result<()> { info!( "try_rebuild_req: update_failed_tbls={:?}", @@ -98,26 +182,35 @@ async fn try_rebuild_req( .iter_mut() .find(|(meta, _)| meta.table_id == tid) .unwrap(); - let new_table = FuseTable::from_table_meta( - update_table_meta_req.table_id, - 0, - update_table_meta_req.new_table_meta.clone(), - storage_class, - )?; - let new_snapshot = new_table.read_table_snapshot().await?; + let base_snapshot_location = txn_mgr.lock().get_base_snapshot_location(tid); - let base_snapshot = new_table - .read_table_snapshot_with_location(base_snapshot_location) + let base_snapshot = latest_table + .read_table_snapshot_with_location(base_snapshot_location.clone()) .await?; - let segments_diff = SegmentsDiff::new(base_snapshot.segments(), new_snapshot.segments()); - let Some(merged_segments) = segments_diff.apply(latest_snapshot.segments().to_vec()) else { + // Get the pre-computed segments diff for this table (computed before retry loop) + let segments_diff = table_segments_diffs.get(&tid).ok_or_else(|| { + ErrorCode::Internal(format!("Missing segments diff for table {}", tid)) + })?; + + let Some(merged_segments) = segments_diff + .clone() + .apply(latest_snapshot.segments().to_vec()) + else { return Err(ErrorCode::UnresolvableConflict(format!( "Unresolvable conflict detected for table {}", tid ))); }; + // Read the original transaction-generated snapshot from cache for statistics merging + let new_snapshot = table_original_snapshots + .get(&tid) + .ok_or_else(|| { + ErrorCode::Internal(format!("Missing original snapshot for table {}", tid)) + })? + .clone(); + let s = merge_statistics( new_snapshot.summary(), &latest_snapshot.summary(), @@ -214,6 +307,7 @@ async fn try_rebuild_req( latest_snapshot.table_statistics_location(), table_meta_timestamps, )?; + merged_snapshot.ensure_segments_unique()?; // write snapshot let dal = latest_table.get_operator(); diff --git a/src/query/storages/fuse/src/retry/diff.rs b/src/query/storages/fuse/src/retry/diff.rs index 352bc6dc50426..b68877022f7b6 100644 --- a/src/query/storages/fuse/src/retry/diff.rs +++ b/src/query/storages/fuse/src/retry/diff.rs @@ -17,6 +17,7 @@ use std::collections::HashSet; use databend_storages_common_table_meta::meta::Location; +#[derive(Clone)] pub struct SegmentsDiff { appended: Vec, replaced: HashMap>, diff --git a/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.py b/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.py new file mode 100755 index 0000000000000..4422e3235ec9d --- /dev/null +++ b/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python3 + +import os +import time +from concurrent.futures import ThreadPoolExecutor, as_completed +from random import Random + +import mysql.connector +from mysql.connector import errors as mysql_errors + + +TABLE_NAME = "txn_snapshot_retry_concurrency" +NUM_THREADS = 16 +TRANSACTIONS_PER_THREAD = 4 +ROWS_PER_TRANSACTION = 2 +MAX_RETRIES = 8 +RETRY_SLEEP_RANGE = (0.01, 0.05) +VALUE_GAP = 1_000_000 + +HOST = os.getenv("QUERY_MYSQL_HANDLER_HOST", "127.0.0.1") +PORT = int(os.getenv("QUERY_MYSQL_HANDLER_PORT", "3307")) +USER = os.getenv("QUERY_MYSQL_HANDLER_USER", "root") +PASSWORD = os.getenv("QUERY_MYSQL_HANDLER_PASSWORD", "root") + + +def create_connection(): + conn = mysql.connector.connect( + host=HOST, port=PORT, user=USER, passwd=PASSWORD, autocommit=False + ) + cursor = conn.cursor() + return conn, cursor + + +def drain(cursor): + try: + cursor.fetchall() + except mysql_errors.InterfaceError: + pass + + +def run_transaction_batch(thread_id: int) -> None: + conn, cursor = create_connection() + rng = Random(thread_id) + + try: + for tx_index in range(TRANSACTIONS_PER_THREAD): + base_value = ( + thread_id * VALUE_GAP + tx_index * ROWS_PER_TRANSACTION + ) + values_clause = ", ".join( + f"({base_value + offset})" for offset in range(ROWS_PER_TRANSACTION) + ) + + attempts = 0 + while True: + attempts += 1 + try: + cursor.execute("BEGIN") + drain(cursor) + + cursor.execute( + f"INSERT INTO {TABLE_NAME} VALUES {values_clause}" + ) + drain(cursor) + + cursor.execute("COMMIT") + drain(cursor) + break + except Exception: + try: + cursor.execute("ROLLBACK") + drain(cursor) + except Exception: + pass + + cursor.close() + conn.close() + + if attempts >= MAX_RETRIES: + raise + + time.sleep(rng.uniform(*RETRY_SLEEP_RANGE)) + conn, cursor = create_connection() + finally: + cursor.close() + conn.close() + + +def main() -> None: + setup_conn, setup_cursor = create_connection() + try: + setup_cursor.execute(f"DROP TABLE IF EXISTS {TABLE_NAME}") + drain(setup_cursor) + setup_cursor.execute(f"CREATE TABLE {TABLE_NAME} (id BIGINT)") + drain(setup_cursor) + finally: + setup_cursor.close() + setup_conn.close() + + with ThreadPoolExecutor(max_workers=NUM_THREADS) as executor: + futures = [ + executor.submit(run_transaction_batch, thread_id) + for thread_id in range(NUM_THREADS) + ] + for future in as_completed(futures): + future.result() + + verify_conn, verify_cursor = create_connection() + try: + expected_rows = NUM_THREADS * TRANSACTIONS_PER_THREAD * ROWS_PER_TRANSACTION + + verify_cursor.execute( + f"SELECT COUNT(*) AS cnt, COUNT(DISTINCT id) AS uniq FROM {TABLE_NAME}" + ) + counts = verify_cursor.fetchall()[0] + total_count, distinct_count = counts[0], counts[1] + + if total_count != expected_rows or distinct_count != expected_rows: + raise AssertionError( + f"Expected {expected_rows} rows, got total={total_count}, distinct={distinct_count}" + ) + + verify_cursor.execute( + f"SELECT id FROM {TABLE_NAME} GROUP BY id HAVING COUNT(*) > 1 LIMIT 1" + ) + duplicates = verify_cursor.fetchall() + if duplicates: + raise AssertionError(f"found duplicated segments: {duplicates}") + + verify_cursor.execute(f"DROP TABLE IF EXISTS {TABLE_NAME}") + drain(verify_cursor) + finally: + verify_cursor.close() + verify_conn.close() + + print("Transaction snapshot retry looks good") + + +if __name__ == "__main__": + main() diff --git a/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.result b/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.result new file mode 100644 index 0000000000000..b3032ce4108a7 --- /dev/null +++ b/tests/suites/0_stateless/01_transaction/01_04_txn_snapshot_retry.result @@ -0,0 +1 @@ +Transaction snapshot retry looks good