diff --git a/Cargo.toml b/Cargo.toml index 48fd4ca..08c22d3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["codegen", "examples", "performance_measurement", "performance_measur [package] name = "worktable" -version = "0.8.22" +version = "0.8.23" edition = "2024" authors = ["Handy-caT"] license = "MIT" @@ -16,7 +16,7 @@ perf_measurements = ["dep:performance_measurement", "dep:performance_measurement # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -worktable_codegen = { path = "codegen", version = "=0.8.22" } +worktable_codegen = { path = "codegen", version = "=0.8.23" } async-trait = "0.1.89" eyre = "0.6.12" diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index 5c05d55..ec2d99f 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "worktable_codegen" -version = "0.8.22" +version = "0.8.23" edition = "2024" license = "MIT" description = "WorkTable codegeneration crate" diff --git a/codegen/src/worktable/generator/queries/delete.rs b/codegen/src/worktable/generator/queries/delete.rs index b0a1c93..4b2f794 100644 --- a/codegen/src/worktable/generator/queries/delete.rs +++ b/codegen/src/worktable/generator/queries/delete.rs @@ -44,15 +44,15 @@ impl Generator { where #pk_ident: From { let pk: #pk_ident = pk.into(); - let lock = { - #full_row_lock - }; + let op_lock = { #full_row_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); #delete_logic - lock.unlock(); // Releases locks - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks - core::result::Result::Ok(()) } } @@ -113,8 +113,6 @@ impl Generator { .ok_or(WorkTableError::NotFound) { Ok(l) => l, Err(e) => { - lock.unlock(); // Releases locks - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks return Err(e); } }; diff --git a/codegen/src/worktable/generator/queries/in_place.rs b/codegen/src/worktable/generator/queries/in_place.rs index efe4ca3..a3aa6ea 100644 --- a/codegen/src/worktable/generator/queries/in_place.rs +++ b/codegen/src/worktable/generator/queries/in_place.rs @@ -109,9 +109,12 @@ impl Generator { where #pk_type: From { let pk: #pk_type = by.into(); - let lock = { - #custom_lock - }; + let op_lock = { #custom_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); let link = self .0 .primary_index.pk_map @@ -125,9 +128,6 @@ impl Generator { .map_err(WorkTableError::PagesError)? }; - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - Ok(()) } } diff --git a/codegen/src/worktable/generator/queries/update.rs b/codegen/src/worktable/generator/queries/update.rs index 952fd8f..fa4161f 100644 --- a/codegen/src/worktable/generator/queries/update.rs +++ b/codegen/src/worktable/generator/queries/update.rs @@ -64,21 +64,21 @@ impl Generator { } else { quote! { if true { - lock.unlock(); // Releases locks - let lock = { - #full_row_lock - }; + drop(_guard); + let op_lock = { #full_row_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); let row_old = self.0.data.select_non_ghosted(link)?; if let Err(e) = self.reinsert(row_old, row).await { self.0.update_state.remove(&pk); - lock.unlock(); return Err(e); } self.0.update_state.remove(&pk); - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks return core::result::Result::Ok(()); } @@ -88,25 +88,19 @@ impl Generator { quote! { pub async fn update(&self, row: #row_ident) -> core::result::Result<(), WorkTableError> { let pk = row.get_primary_key(); - let lock = { - #full_row_lock - }; - - let mut link: Link = match self.0 + let op_lock = { #full_row_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); + + let mut link: Link = self.0 .primary_index .pk_map .get(&pk) .map(|v| v.get().value.into()) - .ok_or(WorkTableError::NotFound) - { - Ok(l) => l, - Err(e) => { - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - - return Err(e); - } - }; + .ok_or(WorkTableError::NotFound)?; let row_old = self.0.data.select_non_ghosted(link)?; self.0.update_state.insert(pk.clone(), row_old); @@ -128,9 +122,6 @@ impl Generator { self.0.update_state.remove(&pk); - lock.unlock(); // Releases locks - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks - #persist_call core::result::Result::Ok(()) @@ -266,25 +257,23 @@ impl Generator { let mut need_to_reinsert = true; #(#fields_check)* if need_to_reinsert { - lock.unlock(); - let lock = { - #full_row_lock - }; + drop(_guard); + let op_lock = { #full_row_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); let row_old = self.0.select(pk.clone()).expect("should not be deleted by other thread"); let mut row_new = row_old.clone(); - let pk = row_old.get_primary_key().clone(); #(#row_updates)* if let Err(e) = self.reinsert(row_old, row_new).await { self.0.update_state.remove(&pk); - lock.unlock(); return Err(e); } - lock.unlock(); // Releases locks - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks - return core::result::Result::Ok(()); } } @@ -475,24 +464,19 @@ impl Generator { where #pk_ident: From { let pk = pk.into(); - let lock = { - #custom_lock - }; - - let mut link: Link = match self.0 + let op_lock = { #custom_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); + + let mut link: Link = self.0 .primary_index .pk_map .get(&pk) .map(|v| v.get().value.into()) - .ok_or(WorkTableError::NotFound) { - Ok(l) => l, - Err(e) => { - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - - return Err(e); - } - }; + .ok_or(WorkTableError::NotFound)?; let mut bytes = rkyv::to_bytes::(&row).map_err(|_| WorkTableError::SerializeError)?; let mut archived_row = unsafe { rkyv::access_unchecked_mut::<<#query_ident as rkyv::Archive>::Archived>(&mut bytes[..]).unseal_unchecked() }; @@ -508,9 +492,6 @@ impl Generator { #diff_process_remove - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - #persist_call core::result::Result::Ok(()) @@ -569,27 +550,24 @@ impl Generator { let mut need_to_reinsert = true; #(#fields_check)* if need_to_reinsert { - let op_lock = locks.remove(&pk).expect("should not be deleted as links are unique"); - op_lock.unlock(); - let lock = { - #full_row_lock - }; + let old_guard = guards.remove(&pk).expect("guard should exist for this pk"); + drop(old_guard); + + let op_lock = { #full_row_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); let row_old = self.0.select(pk.clone()).expect("should not be deleted by other thread"); let mut row_new = row_old.clone(); #(#row_updates)* if let Err(e) = self.reinsert(row_old, row_new).await { self.0.update_state.remove(&pk); - lock.unlock(); - return Err(e); } - lock.unlock(); // Releases locks - self.0.lock_manager.remove_with_lock_check(&pk); // Removes locks - continue; - } else { - pk_to_unlock.insert(pk.clone(), locks.remove(&pk).expect("should not be deleted as links are unique")); } } } else { @@ -614,17 +592,14 @@ impl Generator { pub async fn #method_ident(&self, row: #query_ident, by: #by_ident) -> core::result::Result<(), WorkTableError> { let links: Vec<_> = self.0.indexes.#index.get(#by).map(|(_, l)| l.0).collect(); - let mut locks = std::collections::HashMap::new(); + let mut guards: std::collections::HashMap<_, _> = std::collections::HashMap::new(); for link in links.iter() { let pk = self.0.data.select_non_ghosted(*link)?.get_primary_key().clone(); - let op_lock = { - #custom_lock - }; - locks.insert(pk, op_lock); + let op_lock = { #custom_lock }; + guards.insert(pk.clone(), LockGuard::new(op_lock, self.0.lock_manager.clone(), pk.clone())); } let links: Vec<_> = self.0.indexes.#index.get(#by).map(|(_, l)| l.0).collect(); - let mut pk_to_unlock: std::collections::HashMap<_, std::sync::Arc> = std::collections::HashMap::new(); let op_id = OperationId::Multi(uuid::Uuid::now_v7()); for link in links.into_iter() { let pk = self.0.data.select_non_ghosted(link)?.get_primary_key().clone(); @@ -649,10 +624,8 @@ impl Generator { #diff_process_remove #persist_call - } - for (pk, lock) in pk_to_unlock { - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); + + guards.remove(&pk); } core::result::Result::Ok(()) } @@ -719,23 +692,18 @@ impl Generator { let pk = self.0.data.select_non_ghosted(link)?.get_primary_key().clone(); - let lock = { - #custom_lock - }; + let op_lock = { #custom_lock }; + let _guard = LockGuard::new( + op_lock, + self.0.lock_manager.clone(), + pk.clone(), + ); let link = loop { - let link = match self.0.indexes.#index + let link = self.0.indexes.#index .get(#by) .map(|v| v.get().value.into()) - .ok_or(WorkTableError::NotFound) { - Ok(l) => l, - Err(e) => { - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - - return Err(e); - } - }; + .ok_or(WorkTableError::NotFound)?; if let Err(e) = self.0.data.select_non_vacuumed(link) { if e.is_vacuumed() { @@ -760,9 +728,6 @@ impl Generator { #diff_process_remove - lock.unlock(); - self.0.lock_manager.remove_with_lock_check(&pk); - #persist_call core::result::Result::Ok(()) diff --git a/src/lib.rs b/src/lib.rs index 9e2648b..230b259 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,8 +19,9 @@ pub mod prelude { pub use crate::in_memory::{ ArchivedRowWrapper, Data, DataPages, Query, RowWrapper, StorableRow, }; - pub use crate::lock::LockMap; + pub use crate::lock::FullRowLock; pub use crate::lock::{Lock, RowLock}; + pub use crate::lock::{LockGuard, LockMap}; pub use crate::mem_stat::MemStat; pub use crate::persistence::{ DeleteOperation, IndexTableOfContents, InsertOperation, Operation, OperationId, diff --git a/src/lock/mod.rs b/src/lock/mod.rs index 7dbe0d0..7cee60b 100644 --- a/src/lock/mod.rs +++ b/src/lock/mod.rs @@ -1,8 +1,11 @@ mod map; mod row_lock; +use std::cell::Cell; +use std::fmt::Debug; use std::future::Future; use std::hash::{Hash, Hasher}; +use std::marker::PhantomData; use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; @@ -14,6 +17,60 @@ use parking_lot::Mutex; pub use map::LockMap; pub use row_lock::{FullRowLock, RowLock}; +/// RAII guard that automatically unlocks a [`Lock`] when dropped. +/// +/// The [`Lock`] is automatically released when the [`LockGuard`] is +/// [`Drop`]ped, or can be explicitly released early using the `unlock()` +/// method. +/// +/// The guard will also attempt to remove the lock entry from the map on drop +/// (preventing memory leaks). +pub struct LockGuard { + lock: Arc, + lock_map: Arc>, + primary_key: PrimaryKey, + /// Marker to make this type ![`Sync`] (but still [`Send`]) + _not_sync: PhantomData>, +} + +impl LockGuard +where + LockType: RowLock, + PrimaryKey: Hash + Eq + Debug + Clone, +{ + /// Creates a new [`LockGuard`] that will clean up the [`Lock`] entry from + /// the [`LockMap`] on [`Drop`]. + pub fn new( + lock: Arc, + lock_map: Arc>, + primary_key: PrimaryKey, + ) -> Self { + Self { + lock, + lock_map, + primary_key, + _not_sync: PhantomData, + } + } + + /// Explicitly unlocks the [`Lock`] before the [`LockGuard`] is [`Drop`]ped. + pub fn unlock(self) { + self.lock.unlock(); + self.lock_map.remove_with_lock_check(&self.primary_key); + } +} + +impl Drop for LockGuard +where + LockType: RowLock, + PrimaryKey: Hash + Eq + Debug + Clone, +{ + fn drop(&mut self) { + self.lock.unlock(); + self.lock_map.remove_with_lock_check(&self.primary_key); + } +} + #[derive(Debug)] pub struct Lock { id: u16, @@ -104,3 +161,114 @@ impl Future for LockWait { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::panic::AssertUnwindSafe; + + #[test] + fn test_unlock_on_drop() { + let lock = Arc::new(Lock::new(1)); + let lock_map: Arc> = Arc::new(LockMap::default()); + let pk = 1u64; + assert!(lock.is_locked()); + + { + let _guard = LockGuard::::new(lock.clone(), lock_map.clone(), pk); + assert!(lock.is_locked()); + } + + assert!(!lock.is_locked()); + } + + #[test] + fn test_explicit_unlock() { + let lock = Arc::new(Lock::new(1)); + let lock_map: Arc> = Arc::new(LockMap::default()); + let pk = 1u64; + assert!(lock.is_locked()); + + let guard = LockGuard::::new(lock.clone(), lock_map.clone(), pk); + assert!(lock.is_locked()); + + guard.unlock(); + + assert!(!lock.is_locked()); + } + + #[test] + fn test_panic_releases_lock() { + let lock = Arc::new(Lock::new(1)); + let lock_map: Arc> = Arc::new(LockMap::default()); + let pk = 1u64; + assert!(lock.is_locked()); + + let result = std::panic::catch_unwind(AssertUnwindSafe(|| { + let _guard = LockGuard::::new(lock.clone(), lock_map.clone(), pk); + panic!("test panic"); + })); + + assert!(result.is_err()); + + assert!(!lock.is_locked()); + } + + #[test] + fn test_multiple_guards_can_be_held() { + let lock1 = Arc::new(Lock::new(1)); + let lock2 = Arc::new(Lock::new(2)); + let lock3 = Arc::new(Lock::new(3)); + let lock_map: Arc> = Arc::new(LockMap::default()); + + assert!(lock1.is_locked()); + assert!(lock2.is_locked()); + assert!(lock3.is_locked()); + + { + let _guard1 = LockGuard::::new(lock1.clone(), lock_map.clone(), 1u64); + let _guard2 = LockGuard::::new(lock2.clone(), lock_map.clone(), 2u64); + let _guard3 = LockGuard::::new(lock3.clone(), lock_map.clone(), 3u64); + + assert!(lock1.is_locked()); + assert!(lock2.is_locked()); + assert!(lock3.is_locked()); + } + + assert!(!lock1.is_locked()); + assert!(!lock2.is_locked()); + assert!(!lock3.is_locked()); + } + + #[test] + fn test_guard_is_send() { + fn assert_send() {} + // LockGuard is Send if LockType and PrimaryKey are Send + assert_send::>(); + } + + #[tokio::test] + async fn test_lock_cleanup_on_guard_drop() { + use crate::lock::FullRowLock; + use crate::lock::RowLock; + + let lock_map: Arc> = Arc::new(LockMap::default()); + let pk = 42u64; + + // Create and insert a lock + let (lock_type, lock) = FullRowLock::with_lock(lock_map.next_id()); + let rw_lock = Arc::new(tokio::sync::RwLock::new(lock_type)); + lock_map.insert(pk, rw_lock); + + // Verify the lock is in the map + assert!(lock_map.get(&pk).is_some()); + + // Create a guard and drop it + { + let _guard = LockGuard::new(lock, lock_map.clone(), pk); + } + + // Verify the lock entry was removed from the map + assert!(lock_map.get(&pk).is_none()); + } +} diff --git a/src/lock/row_lock.rs b/src/lock/row_lock.rs index 7b75ba4..e874373 100644 --- a/src/lock/row_lock.rs +++ b/src/lock/row_lock.rs @@ -1,7 +1,10 @@ -use crate::lock::{Lock, LockWait}; use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; use std::sync::Arc; +use crate::lock::{Lock, LockGuard, LockMap, LockWait}; + pub trait RowLock { /// Checks if any column of this row is locked. fn is_locked(&self) -> bool; @@ -32,6 +35,16 @@ impl FullRowLock { self.l.unlock(); } + /// Creates a [`LockGuard`] that will automatically unlock this lock when + /// dropped. + pub fn guard( + self, + lock_map: Arc>, + primary_key: PrimaryKey, + ) -> LockGuard { + LockGuard::new(self.l, lock_map, primary_key) + } + pub fn wait(&self) -> LockWait { self.l.wait() } diff --git a/src/table/vacuum/vacuum.rs b/src/table/vacuum/vacuum.rs index 1663885..e93291a 100644 --- a/src/table/vacuum/vacuum.rs +++ b/src/table/vacuum/vacuum.rs @@ -16,7 +16,8 @@ use rkyv::util::AlignedVec; use rkyv::{Archive, Deserialize, Serialize}; use crate::in_memory::{ArchivedRowWrapper, DataPages, RowWrapper, StorableRow}; -use crate::prelude::{Lock, LockMap, OffsetEqLink, RowLock, TablePrimaryKey}; +use crate::lock::{Lock, LockMap, RowLock}; +use crate::prelude::{OffsetEqLink, TablePrimaryKey}; use crate::vacuum::VacuumStats; use crate::vacuum::WorkTableVacuum; use crate::vacuum::fragmentation_info::FragmentationInfo; @@ -269,8 +270,9 @@ where .save_raw_row(&raw_data) .expect("page is not full as checked on links collection"); self.update_index_after_move(pk.clone(), from_link.0, new_link); - self.lock_manager.remove_with_lock_check(&pk); + lock.unlock(); + self.lock_manager.remove_with_lock_check(&pk); } (from_page_will_be_moved, to_page_will_be_filled)