diff --git a/crates/sqlez b/crates/sqlez deleted file mode 160000 index 10a78dbe53..0000000000 --- a/crates/sqlez +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 10a78dbe535a0c270b6b4bc469fbbffe9fc8c36f diff --git a/crates/sqlez/.gitignore b/crates/sqlez/.gitignore new file mode 100644 index 0000000000..8130c3ab47 --- /dev/null +++ b/crates/sqlez/.gitignore @@ -0,0 +1,2 @@ +debug/ +target/ diff --git a/crates/sqlez/Cargo.lock b/crates/sqlez/Cargo.lock new file mode 100644 index 0000000000..33348baed9 --- /dev/null +++ b/crates/sqlez/Cargo.lock @@ -0,0 +1,150 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "addr2line" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ecd88a8c8378ca913a680cd98f0f13ac67383d35993f86c90a70e3f137816b" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "anyhow" +version = "1.0.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "216261ddc8289130e551ddcd5ce8a064710c0d064a4d2895c67151c92b5443f6" +dependencies = [ + "backtrace", +] + +[[package]] +name = "backtrace" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab84319d616cfb654d03394f38ab7e6f0919e181b1b57e1fd15e7fb4077d9a7" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + +[[package]] +name = "cc" +version = "1.0.73" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "gimli" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22030e2c5a68ec659fde1e949a745124b48e6fa8b045b7ed5bd1fe4ccc5c4e5d" + +[[package]] +name = "indoc" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adab1eaa3408fb7f0c777a73e7465fd5656136fc93b670eb6df3c88c2c1344e3" + +[[package]] +name = "libc" +version = "0.2.137" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7fcc620a3bff7cdd7a365be3376c97191aeaccc2a603e600951e452615bf89" + +[[package]] +name = "libsqlite3-sys" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "memchr" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" + +[[package]] +name = "miniz_oxide" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96590ba8f175222643a85693f33d26e9c8a015f599c216509b1a6894af675d34" +dependencies = [ + "adler", +] + +[[package]] +name = "object" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21158b2c33aa6d4561f1c0a6ea283ca92bc54802a93b263e910746d679a7eb53" +dependencies = [ + "memchr", +] + +[[package]] +name = "once_cell" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1" + +[[package]] +name = "pkg-config" +version = "0.3.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" + +[[package]] +name = "rustc-demangle" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ef03e0a2b150c7a90d01faf6254c9c48a41e95fb2a8c2ac1c6f0d2b9aefc342" + +[[package]] +name = "sqlez" +version = "0.1.0" +dependencies = [ + "anyhow", + "indoc", + "libsqlite3-sys", + "thread_local", +] + +[[package]] +name = "thread_local" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5516c27b78311c50bf42c071425c560ac799b11c30b31f87e3081965fe5e0180" +dependencies = [ + "once_cell", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" diff --git a/crates/sqlez/Cargo.toml b/crates/sqlez/Cargo.toml new file mode 100644 index 0000000000..cbb4504a04 --- /dev/null +++ b/crates/sqlez/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "sqlez" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +anyhow = { version = "1.0.38", features = ["backtrace"] } +indoc = "1.0.7" +libsqlite3-sys = { version = "0.25.2", features = ["bundled"] } +thread_local = "1.1.4" \ No newline at end of file diff --git a/crates/sqlez/src/bindable.rs b/crates/sqlez/src/bindable.rs new file mode 100644 index 0000000000..ca3ba401cf --- /dev/null +++ b/crates/sqlez/src/bindable.rs @@ -0,0 +1,209 @@ +use anyhow::Result; + +use crate::statement::{SqlType, Statement}; + +pub trait Bind { + fn bind(&self, statement: &Statement, start_index: i32) -> Result; +} + +pub trait Column: Sized { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)>; +} + +impl Bind for &[u8] { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_blob(start_index, self)?; + Ok(start_index + 1) + } +} + +impl Bind for Vec { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_blob(start_index, self)?; + Ok(start_index + 1) + } +} + +impl Column for Vec { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_blob(start_index)?; + Ok((Vec::from(result), start_index + 1)) + } +} + +impl Bind for f64 { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_double(start_index, *self)?; + Ok(start_index + 1) + } +} + +impl Column for f64 { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_double(start_index)?; + Ok((result, start_index + 1)) + } +} + +impl Bind for i32 { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_int(start_index, *self)?; + Ok(start_index + 1) + } +} + +impl Column for i32 { + fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_int(start_index)?; + Ok((result, start_index + 1)) + } +} + +impl Bind for i64 { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_int64(start_index, *self)?; + Ok(start_index + 1) + } +} + +impl Column for i64 { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_int64(start_index)?; + Ok((result, start_index + 1)) + } +} + +impl Bind for usize { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + (*self as i64).bind(statement, start_index) + } +} + +impl Column for usize { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_int64(start_index)?; + Ok((result as usize, start_index + 1)) + } +} + +impl Bind for () { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_null(start_index)?; + Ok(start_index + 1) + } +} + +impl Bind for &str { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_text(start_index, self)?; + Ok(start_index + 1) + } +} + +impl Bind for String { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + statement.bind_text(start_index, self)?; + Ok(start_index + 1) + } +} + +impl Column for String { + fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let result = statement.column_text(start_index)?; + Ok((result.to_owned(), start_index + 1)) + } +} + +impl Bind for (T1, T2) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + self.1.bind(statement, next_index) + } +} + +impl Column for (T1, T2) { + fn column<'a>(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + Ok(((first, second), next_index)) + } +} + +impl Bind for (T1, T2, T3) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + let next_index = self.1.bind(statement, next_index)?; + self.2.bind(statement, next_index) + } +} + +impl Column for (T1, T2, T3) { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + let (third, next_index) = T3::column(statement, next_index)?; + Ok(((first, second, third), next_index)) + } +} + +impl Bind for (T1, T2, T3, T4) { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let next_index = self.0.bind(statement, start_index)?; + let next_index = self.1.bind(statement, next_index)?; + let next_index = self.2.bind(statement, next_index)?; + self.3.bind(statement, next_index) + } +} + +impl Column for (T1, T2, T3, T4) { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let (first, next_index) = T1::column(statement, start_index)?; + let (second, next_index) = T2::column(statement, next_index)?; + let (third, next_index) = T3::column(statement, next_index)?; + let (forth, next_index) = T4::column(statement, next_index)?; + Ok(((first, second, third, forth), next_index)) + } +} + +impl Bind for Option { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + if let Some(this) = self { + this.bind(statement, start_index) + } else { + statement.bind_null(start_index)?; + Ok(start_index + 1) + } + } +} + +impl Column for Option { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + if let SqlType::Null = statement.column_type(start_index)? { + Ok((None, start_index + 1)) + } else { + T::column(statement, start_index).map(|(result, next_index)| (Some(result), next_index)) + } + } +} + +impl Bind for [T; COUNT] { + fn bind(&self, statement: &Statement, start_index: i32) -> Result { + let mut current_index = start_index; + for binding in self { + current_index = binding.bind(statement, current_index)? + } + + Ok(current_index) + } +} + +impl Column for [T; COUNT] { + fn column(statement: &mut Statement, start_index: i32) -> Result<(Self, i32)> { + let mut array = [Default::default(); COUNT]; + let mut current_index = start_index; + for i in 0..COUNT { + (array[i], current_index) = T::column(statement, current_index)?; + } + Ok((array, current_index)) + } +} diff --git a/crates/sqlez/src/connection.rs b/crates/sqlez/src/connection.rs new file mode 100644 index 0000000000..81bb9dfe78 --- /dev/null +++ b/crates/sqlez/src/connection.rs @@ -0,0 +1,220 @@ +use std::{ + ffi::{CStr, CString}, + marker::PhantomData, +}; + +use anyhow::{anyhow, Result}; +use libsqlite3_sys::*; + +use crate::statement::Statement; + +pub struct Connection { + pub(crate) sqlite3: *mut sqlite3, + persistent: bool, + phantom: PhantomData, +} +unsafe impl Send for Connection {} + +impl Connection { + fn open(uri: &str, persistent: bool) -> Result { + let mut connection = Self { + sqlite3: 0 as *mut _, + persistent, + phantom: PhantomData, + }; + + let flags = SQLITE_OPEN_CREATE | SQLITE_OPEN_NOMUTEX | SQLITE_OPEN_READWRITE; + unsafe { + sqlite3_open_v2( + CString::new(uri)?.as_ptr(), + &mut connection.sqlite3, + flags, + 0 as *const _, + ); + + connection.last_error()?; + } + + Ok(connection) + } + + /// Attempts to open the database at uri. If it fails, a shared memory db will be opened + /// instead. + pub fn open_file(uri: &str) -> Self { + Self::open(uri, true).unwrap_or_else(|_| Self::open_memory(uri)) + } + + pub fn open_memory(uri: &str) -> Self { + let in_memory_path = format!("file:{}?mode=memory&cache=shared", uri); + Self::open(&in_memory_path, false).expect("Could not create fallback in memory db") + } + + pub fn persistent(&self) -> bool { + self.persistent + } + + pub fn exec(&self, query: impl AsRef) -> Result<()> { + unsafe { + sqlite3_exec( + self.sqlite3, + CString::new(query.as_ref())?.as_ptr(), + None, + 0 as *mut _, + 0 as *mut _, + ); + self.last_error()?; + } + Ok(()) + } + + pub fn prepare>(&self, query: T) -> Result { + Statement::prepare(&self, query) + } + + pub fn backup_main(&self, destination: &Connection) -> Result<()> { + unsafe { + let backup = sqlite3_backup_init( + destination.sqlite3, + CString::new("main")?.as_ptr(), + self.sqlite3, + CString::new("main")?.as_ptr(), + ); + sqlite3_backup_step(backup, -1); + sqlite3_backup_finish(backup); + destination.last_error() + } + } + + pub(crate) fn last_error(&self) -> Result<()> { + const NON_ERROR_CODES: &[i32] = &[SQLITE_OK, SQLITE_ROW]; + unsafe { + let code = sqlite3_errcode(self.sqlite3); + if NON_ERROR_CODES.contains(&code) { + return Ok(()); + } + + let message = sqlite3_errmsg(self.sqlite3); + let message = if message.is_null() { + None + } else { + Some( + String::from_utf8_lossy(CStr::from_ptr(message as *const _).to_bytes()) + .into_owned(), + ) + }; + + Err(anyhow!( + "Sqlite call failed with code {} and message: {:?}", + code as isize, + message + )) + } + } +} + +impl Drop for Connection { + fn drop(&mut self) { + unsafe { sqlite3_close(self.sqlite3) }; + } +} + +#[cfg(test)] +mod test { + use anyhow::Result; + use indoc::indoc; + + use crate::connection::Connection; + + #[test] + fn string_round_trips() -> Result<()> { + let connection = Connection::open_memory("string_round_trips"); + connection + .exec(indoc! {" + CREATE TABLE text ( + text TEXT + );"}) + .unwrap(); + + let text = "Some test text"; + + connection + .prepare("INSERT INTO text (text) VALUES (?);") + .unwrap() + .bound(text) + .unwrap() + .run() + .unwrap(); + + assert_eq!( + &connection + .prepare("SELECT text FROM text;") + .unwrap() + .row::() + .unwrap(), + text + ); + + Ok(()) + } + + #[test] + fn tuple_round_trips() { + let connection = Connection::open_memory("tuple_round_trips"); + connection + .exec(indoc! {" + CREATE TABLE test ( + text TEXT, + integer INTEGER, + blob BLOB + );"}) + .unwrap(); + + let tuple1 = ("test".to_string(), 64, vec![0, 1, 2, 4, 8, 16, 32, 64]); + let tuple2 = ("test2".to_string(), 32, vec![64, 32, 16, 8, 4, 2, 1, 0]); + + let mut insert = connection + .prepare("INSERT INTO test (text, integer, blob) VALUES (?, ?, ?)") + .unwrap(); + + insert.bound(tuple1.clone()).unwrap().run().unwrap(); + insert.bound(tuple2.clone()).unwrap().run().unwrap(); + + assert_eq!( + connection + .prepare("SELECT * FROM test") + .unwrap() + .rows::<(String, usize, Vec)>() + .unwrap(), + vec![tuple1, tuple2] + ); + } + + #[test] + fn backup_works() { + let connection1 = Connection::open_memory("backup_works"); + connection1 + .exec(indoc! {" + CREATE TABLE blobs ( + data BLOB + );"}) + .unwrap(); + let blob = &[0, 1, 2, 4, 8, 16, 32, 64]; + let mut write = connection1 + .prepare("INSERT INTO blobs (data) VALUES (?);") + .unwrap(); + write.bind_blob(1, blob).unwrap(); + write.run().unwrap(); + + // Backup connection1 to connection2 + let connection2 = Connection::open_memory("backup_works_other"); + connection1.backup_main(&connection2).unwrap(); + + // Delete the added blob and verify its deleted on the other side + let read_blobs = connection1 + .prepare("SELECT * FROM blobs;") + .unwrap() + .rows::>() + .unwrap(); + assert_eq!(read_blobs, vec![blob]); + } +} diff --git a/crates/sqlez/src/lib.rs b/crates/sqlez/src/lib.rs new file mode 100644 index 0000000000..3bed7a06cb --- /dev/null +++ b/crates/sqlez/src/lib.rs @@ -0,0 +1,6 @@ +pub mod bindable; +pub mod connection; +pub mod migrations; +pub mod savepoint; +pub mod statement; +pub mod thread_safe_connection; diff --git a/crates/sqlez/src/migrations.rs b/crates/sqlez/src/migrations.rs new file mode 100644 index 0000000000..4721b353c6 --- /dev/null +++ b/crates/sqlez/src/migrations.rs @@ -0,0 +1,261 @@ +// Migrations are constructed by domain, and stored in a table in the connection db with domain name, +// effected tables, actual query text, and order. +// If a migration is run and any of the query texts don't match, the app panics on startup (maybe fallback +// to creating a new db?) +// Otherwise any missing migrations are run on the connection + +use anyhow::{anyhow, Result}; +use indoc::{formatdoc, indoc}; + +use crate::connection::Connection; + +const MIGRATIONS_MIGRATION: Migration = Migration::new( + "migrations", + // The migrations migration must be infallable because it runs to completion + // with every call to migration run and is run unchecked. + &[indoc! {" + CREATE TABLE IF NOT EXISTS migrations ( + domain TEXT, + step INTEGER, + migration TEXT + ); + "}], +); + +pub struct Migration { + domain: &'static str, + migrations: &'static [&'static str], +} + +impl Migration { + pub const fn new(domain: &'static str, migrations: &'static [&'static str]) -> Self { + Self { domain, migrations } + } + + fn run_unchecked(&self, connection: &Connection) -> Result<()> { + connection.exec(self.migrations.join(";\n")) + } + + pub fn run(&self, connection: &Connection) -> Result<()> { + // Setup the migrations table unconditionally + MIGRATIONS_MIGRATION.run_unchecked(connection)?; + + let completed_migrations = connection + .prepare(indoc! {" + SELECT domain, step, migration FROM migrations + WHERE domain = ? + ORDER BY step + "})? + .bound(self.domain)? + .rows::<(String, usize, String)>()?; + + let mut store_completed_migration = connection + .prepare("INSERT INTO migrations (domain, step, migration) VALUES (?, ?, ?)")?; + + for (index, migration) in self.migrations.iter().enumerate() { + if let Some((_, _, completed_migration)) = completed_migrations.get(index) { + if completed_migration != migration { + return Err(anyhow!(formatdoc! {" + Migration changed for {} at step {} + + Stored migration: + {} + + Proposed migration: + {}", self.domain, index, completed_migration, migration})); + } else { + // Migration already run. Continue + continue; + } + } + + connection.exec(migration)?; + store_completed_migration + .bound((self.domain, index, *migration))? + .run()?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod test { + use indoc::indoc; + + use crate::{connection::Connection, migrations::Migration}; + + #[test] + fn test_migrations_are_added_to_table() { + let connection = Connection::open_memory("migrations_are_added_to_table"); + + // Create first migration with a single step and run it + let mut migration = Migration::new( + "test", + &[indoc! {" + CREATE TABLE test1 ( + a TEXT, + b TEXT + );"}], + ); + migration.run(&connection).unwrap(); + + // Verify it got added to the migrations table + assert_eq!( + &connection + .prepare("SELECT (migration) FROM migrations") + .unwrap() + .rows::() + .unwrap()[..], + migration.migrations + ); + + // Add another step to the migration and run it again + migration.migrations = &[ + indoc! {" + CREATE TABLE test1 ( + a TEXT, + b TEXT + );"}, + indoc! {" + CREATE TABLE test2 ( + c TEXT, + d TEXT + );"}, + ]; + migration.run(&connection).unwrap(); + + // Verify it is also added to the migrations table + assert_eq!( + &connection + .prepare("SELECT (migration) FROM migrations") + .unwrap() + .rows::() + .unwrap()[..], + migration.migrations + ); + } + + #[test] + fn test_migration_setup_works() { + let connection = Connection::open_memory("migration_setup_works"); + + connection + .exec(indoc! {"CREATE TABLE IF NOT EXISTS migrations ( + domain TEXT, + step INTEGER, + migration TEXT + );"}) + .unwrap(); + + let mut store_completed_migration = connection + .prepare(indoc! {" + INSERT INTO migrations (domain, step, migration) + VALUES (?, ?, ?)"}) + .unwrap(); + + let domain = "test_domain"; + for i in 0..5 { + // Create a table forcing a schema change + connection + .exec(format!("CREATE TABLE table{} ( test TEXT );", i)) + .unwrap(); + + store_completed_migration + .bound((domain, i, i.to_string())) + .unwrap() + .run() + .unwrap(); + } + } + + #[test] + fn migrations_dont_rerun() { + let connection = Connection::open_memory("migrations_dont_rerun"); + + // Create migration which clears a table + let migration = Migration::new("test", &["DELETE FROM test_table"]); + + // Manually create the table for that migration with a row + connection + .exec(indoc! {" + CREATE TABLE test_table ( + test_column INTEGER + ); + INSERT INTO test_table (test_column) VALUES (1)"}) + .unwrap(); + + assert_eq!( + connection + .prepare("SELECT * FROM test_table") + .unwrap() + .row::() + .unwrap(), + 1 + ); + + // Run the migration verifying that the row got dropped + migration.run(&connection).unwrap(); + assert_eq!( + connection + .prepare("SELECT * FROM test_table") + .unwrap() + .rows::() + .unwrap(), + Vec::new() + ); + + // Recreate the dropped row + connection + .exec("INSERT INTO test_table (test_column) VALUES (2)") + .unwrap(); + + // Run the same migration again and verify that the table was left unchanged + migration.run(&connection).unwrap(); + assert_eq!( + connection + .prepare("SELECT * FROM test_table") + .unwrap() + .row::() + .unwrap(), + 2 + ); + } + + #[test] + fn changed_migration_fails() { + let connection = Connection::open_memory("changed_migration_fails"); + + // Create a migration with two steps and run it + Migration::new( + "test migration", + &[ + indoc! {" + CREATE TABLE test ( + col INTEGER + )"}, + indoc! {" + INSERT INTO test (col) VALUES (1)"}, + ], + ) + .run(&connection) + .unwrap(); + + // Create another migration with the same domain but different steps + let second_migration_result = Migration::new( + "test migration", + &[ + indoc! {" + CREATE TABLE test ( + color INTEGER + )"}, + indoc! {" + INSERT INTO test (color) VALUES (1)"}, + ], + ) + .run(&connection); + + // Verify new migration returns error when run + assert!(second_migration_result.is_err()) + } +} diff --git a/crates/sqlez/src/savepoint.rs b/crates/sqlez/src/savepoint.rs new file mode 100644 index 0000000000..749c0dc948 --- /dev/null +++ b/crates/sqlez/src/savepoint.rs @@ -0,0 +1,110 @@ +use anyhow::Result; + +use crate::connection::Connection; + +impl Connection { + // Run a set of commands within the context of a `SAVEPOINT name`. If the callback + // returns Ok(None) or Err(_), the savepoint will be rolled back. Otherwise, the save + // point is released. + pub fn with_savepoint(&mut self, name: impl AsRef, f: F) -> Result> + where + F: FnOnce(&mut Connection) -> Result>, + { + let name = name.as_ref().to_owned(); + self.exec(format!("SAVEPOINT {}", &name))?; + let result = f(self); + match result { + Ok(Some(_)) => { + self.exec(format!("RELEASE {}", name))?; + } + Ok(None) | Err(_) => { + self.exec(format!("ROLLBACK TO {}", name))?; + self.exec(format!("RELEASE {}", name))?; + } + } + result + } +} + +#[cfg(test)] +mod tests { + use crate::connection::Connection; + use anyhow::Result; + use indoc::indoc; + + #[test] + fn test_nested_savepoints() -> Result<()> { + let mut connection = Connection::open_memory("nested_savepoints"); + + connection + .exec(indoc! {" + CREATE TABLE text ( + text TEXT, + idx INTEGER + );"}) + .unwrap(); + + let save1_text = "test save1"; + let save2_text = "test save2"; + + connection.with_savepoint("first", |save1| { + save1 + .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? + .bound((save1_text, 1))? + .run()?; + + assert!(save1 + .with_savepoint("second", |save2| -> Result, anyhow::Error> { + save2 + .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? + .bound((save2_text, 2))? + .run()?; + + assert_eq!( + save2 + .prepare("SELECT text FROM text ORDER BY text.idx ASC")? + .rows::()?, + vec![save1_text, save2_text], + ); + + anyhow::bail!("Failed second save point :(") + }) + .err() + .is_some()); + + assert_eq!( + save1 + .prepare("SELECT text FROM text ORDER BY text.idx ASC")? + .rows::()?, + vec![save1_text], + ); + + save1.with_savepoint("second", |save2| { + save2 + .prepare("INSERT INTO text(text, idx) VALUES (?, ?)")? + .bound((save2_text, 2))? + .run()?; + + assert_eq!( + save2 + .prepare("SELECT text FROM text ORDER BY text.idx ASC")? + .rows::()?, + vec![save1_text, save2_text], + ); + + Ok(Some(())) + })?; + + assert_eq!( + save1 + .prepare("SELECT text FROM text ORDER BY text.idx ASC")? + .rows::()?, + vec![save1_text, save2_text], + ); + + Ok(Some(())) + })?; + + Ok(()) + } +} diff --git a/crates/sqlez/src/statement.rs b/crates/sqlez/src/statement.rs new file mode 100644 index 0000000000..774cda0e34 --- /dev/null +++ b/crates/sqlez/src/statement.rs @@ -0,0 +1,342 @@ +use std::ffi::{c_int, CString}; +use std::marker::PhantomData; +use std::{slice, str}; + +use anyhow::{anyhow, Context, Result}; +use libsqlite3_sys::*; + +use crate::bindable::{Bind, Column}; +use crate::connection::Connection; + +pub struct Statement<'a> { + raw_statement: *mut sqlite3_stmt, + connection: &'a Connection, + phantom: PhantomData, +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum StepResult { + Row, + Done, + Misuse, + Other(i32), +} + +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub enum SqlType { + Text, + Integer, + Blob, + Float, + Null, +} + +impl<'a> Statement<'a> { + pub fn prepare>(connection: &'a Connection, query: T) -> Result { + let mut statement = Self { + raw_statement: 0 as *mut _, + connection, + phantom: PhantomData, + }; + + unsafe { + sqlite3_prepare_v2( + connection.sqlite3, + CString::new(query.as_ref())?.as_ptr(), + -1, + &mut statement.raw_statement, + 0 as *mut _, + ); + + connection.last_error().context("Prepare call failed.")?; + } + + Ok(statement) + } + + pub fn reset(&mut self) { + unsafe { + sqlite3_reset(self.raw_statement); + } + } + + pub fn bind_blob(&self, index: i32, blob: &[u8]) -> Result<()> { + let index = index as c_int; + let blob_pointer = blob.as_ptr() as *const _; + let len = blob.len() as c_int; + unsafe { + sqlite3_bind_blob( + self.raw_statement, + index, + blob_pointer, + len, + SQLITE_TRANSIENT(), + ); + } + self.connection.last_error() + } + + pub fn column_blob<'b>(&'b mut self, index: i32) -> Result<&'b [u8]> { + let index = index as c_int; + let pointer = unsafe { sqlite3_column_blob(self.raw_statement, index) }; + + self.connection.last_error()?; + if pointer.is_null() { + return Ok(&[]); + } + let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; + self.connection.last_error()?; + unsafe { Ok(slice::from_raw_parts(pointer as *const u8, len)) } + } + + pub fn bind_double(&self, index: i32, double: f64) -> Result<()> { + let index = index as c_int; + + unsafe { + sqlite3_bind_double(self.raw_statement, index, double); + } + self.connection.last_error() + } + + pub fn column_double(&self, index: i32) -> Result { + let index = index as c_int; + let result = unsafe { sqlite3_column_double(self.raw_statement, index) }; + self.connection.last_error()?; + Ok(result) + } + + pub fn bind_int(&self, index: i32, int: i32) -> Result<()> { + let index = index as c_int; + + unsafe { + sqlite3_bind_int(self.raw_statement, index, int); + } + self.connection.last_error() + } + + pub fn column_int(&self, index: i32) -> Result { + let index = index as c_int; + let result = unsafe { sqlite3_column_int(self.raw_statement, index) }; + self.connection.last_error()?; + Ok(result) + } + + pub fn bind_int64(&self, index: i32, int: i64) -> Result<()> { + let index = index as c_int; + unsafe { + sqlite3_bind_int64(self.raw_statement, index, int); + } + self.connection.last_error() + } + + pub fn column_int64(&self, index: i32) -> Result { + let index = index as c_int; + let result = unsafe { sqlite3_column_int64(self.raw_statement, index) }; + self.connection.last_error()?; + Ok(result) + } + + pub fn bind_null(&self, index: i32) -> Result<()> { + let index = index as c_int; + unsafe { + sqlite3_bind_null(self.raw_statement, index); + } + self.connection.last_error() + } + + pub fn bind_text(&self, index: i32, text: &str) -> Result<()> { + let index = index as c_int; + let text_pointer = text.as_ptr() as *const _; + let len = text.len() as c_int; + unsafe { + sqlite3_bind_blob( + self.raw_statement, + index, + text_pointer, + len, + SQLITE_TRANSIENT(), + ); + } + self.connection.last_error() + } + + pub fn column_text<'b>(&'b mut self, index: i32) -> Result<&'b str> { + let index = index as c_int; + let pointer = unsafe { sqlite3_column_text(self.raw_statement, index) }; + + self.connection.last_error()?; + if pointer.is_null() { + return Ok(""); + } + let len = unsafe { sqlite3_column_bytes(self.raw_statement, index) as usize }; + self.connection.last_error()?; + + let slice = unsafe { slice::from_raw_parts(pointer as *const u8, len) }; + Ok(str::from_utf8(slice)?) + } + + pub fn bind(&self, value: T) -> Result<()> { + value.bind(self, 1)?; + Ok(()) + } + + pub fn column(&mut self) -> Result { + let (result, _) = T::column(self, 0)?; + Ok(result) + } + + pub fn column_type(&mut self, index: i32) -> Result { + let result = unsafe { sqlite3_column_type(self.raw_statement, index) }; // SELECT FROM TABLE + self.connection.last_error()?; + match result { + SQLITE_INTEGER => Ok(SqlType::Integer), + SQLITE_FLOAT => Ok(SqlType::Float), + SQLITE_TEXT => Ok(SqlType::Text), + SQLITE_BLOB => Ok(SqlType::Blob), + SQLITE_NULL => Ok(SqlType::Null), + _ => Err(anyhow!("Column type returned was incorrect ")), + } + } + + pub fn bound(&mut self, bindings: impl Bind) -> Result<&mut Self> { + self.bind(bindings)?; + Ok(self) + } + + fn step(&mut self) -> Result { + unsafe { + match sqlite3_step(self.raw_statement) { + SQLITE_ROW => Ok(StepResult::Row), + SQLITE_DONE => Ok(StepResult::Done), + SQLITE_MISUSE => Ok(StepResult::Misuse), + other => self + .connection + .last_error() + .map(|_| StepResult::Other(other)), + } + } + } + + pub fn run(&mut self) -> Result<()> { + fn logic(this: &mut Statement) -> Result<()> { + while this.step()? == StepResult::Row {} + Ok(()) + } + let result = logic(self); + self.reset(); + result + } + + pub fn map(&mut self, callback: impl FnMut(&mut Statement) -> Result) -> Result> { + fn logic( + this: &mut Statement, + mut callback: impl FnMut(&mut Statement) -> Result, + ) -> Result> { + let mut mapped_rows = Vec::new(); + while this.step()? == StepResult::Row { + mapped_rows.push(callback(this)?); + } + Ok(mapped_rows) + } + + let result = logic(self, callback); + self.reset(); + result + } + + pub fn rows(&mut self) -> Result> { + self.map(|s| s.column::()) + } + + pub fn single(&mut self, callback: impl FnOnce(&mut Statement) -> Result) -> Result { + fn logic( + this: &mut Statement, + callback: impl FnOnce(&mut Statement) -> Result, + ) -> Result { + if this.step()? != StepResult::Row { + return Err(anyhow!( + "Single(Map) called with query that returns no rows." + )); + } + callback(this) + } + let result = logic(self, callback); + self.reset(); + result + } + + pub fn row(&mut self) -> Result { + self.single(|this| this.column::()) + } + + pub fn maybe( + &mut self, + callback: impl FnOnce(&mut Statement) -> Result, + ) -> Result> { + fn logic( + this: &mut Statement, + callback: impl FnOnce(&mut Statement) -> Result, + ) -> Result> { + if this.step()? != StepResult::Row { + return Ok(None); + } + callback(this).map(|r| Some(r)) + } + let result = logic(self, callback); + self.reset(); + result + } + + pub fn maybe_row(&mut self) -> Result> { + self.maybe(|this| this.column::()) + } +} + +impl<'a> Drop for Statement<'a> { + fn drop(&mut self) { + unsafe { + sqlite3_finalize(self.raw_statement); + self.connection + .last_error() + .expect("sqlite3 finalize failed for statement :("); + }; + } +} + +#[cfg(test)] +mod test { + use indoc::indoc; + + use crate::{connection::Connection, statement::StepResult}; + + #[test] + fn blob_round_trips() { + let connection1 = Connection::open_memory("blob_round_trips"); + connection1 + .exec(indoc! {" + CREATE TABLE blobs ( + data BLOB + );"}) + .unwrap(); + + let blob = &[0, 1, 2, 4, 8, 16, 32, 64]; + + let mut write = connection1 + .prepare("INSERT INTO blobs (data) VALUES (?);") + .unwrap(); + write.bind_blob(1, blob).unwrap(); + assert_eq!(write.step().unwrap(), StepResult::Done); + + // Read the blob from the + let connection2 = Connection::open_memory("blob_round_trips"); + let mut read = connection2.prepare("SELECT * FROM blobs;").unwrap(); + assert_eq!(read.step().unwrap(), StepResult::Row); + assert_eq!(read.column_blob(0).unwrap(), blob); + assert_eq!(read.step().unwrap(), StepResult::Done); + + // Delete the added blob and verify its deleted on the other side + connection2.exec("DELETE FROM blobs;").unwrap(); + let mut read = connection1.prepare("SELECT * FROM blobs;").unwrap(); + assert_eq!(read.step().unwrap(), StepResult::Done); + } +} diff --git a/crates/sqlez/src/thread_safe_connection.rs b/crates/sqlez/src/thread_safe_connection.rs new file mode 100644 index 0000000000..8885edc2c0 --- /dev/null +++ b/crates/sqlez/src/thread_safe_connection.rs @@ -0,0 +1,78 @@ +use std::{ops::Deref, sync::Arc}; + +use connection::Connection; +use thread_local::ThreadLocal; + +use crate::connection; + +pub struct ThreadSafeConnection { + uri: Arc, + persistent: bool, + initialize_query: Option<&'static str>, + connection: Arc>, +} + +impl ThreadSafeConnection { + pub fn new(uri: &str, persistent: bool) -> Self { + Self { + uri: Arc::from(uri), + persistent, + initialize_query: None, + connection: Default::default(), + } + } + + /// Sets the query to run every time a connection is opened. This must + /// be infallible (EG only use pragma statements) + pub fn with_initialize_query(mut self, initialize_query: &'static str) -> Self { + self.initialize_query = Some(initialize_query); + self + } + + /// Opens a new db connection with the initialized file path. This is internal and only + /// called from the deref function. + /// If opening fails, the connection falls back to a shared memory connection + fn open_file(&self) -> Connection { + Connection::open_file(self.uri.as_ref()) + } + + /// Opens a shared memory connection using the file path as the identifier. This unwraps + /// as we expect it always to succeed + fn open_shared_memory(&self) -> Connection { + Connection::open_memory(self.uri.as_ref()) + } +} + +impl Clone for ThreadSafeConnection { + fn clone(&self) -> Self { + Self { + uri: self.uri.clone(), + persistent: self.persistent, + initialize_query: self.initialize_query.clone(), + connection: self.connection.clone(), + } + } +} + +impl Deref for ThreadSafeConnection { + type Target = Connection; + + fn deref(&self) -> &Self::Target { + self.connection.get_or(|| { + let connection = if self.persistent { + self.open_file() + } else { + self.open_shared_memory() + }; + + if let Some(initialize_query) = self.initialize_query { + connection.exec(initialize_query).expect(&format!( + "Initialize query failed to execute: {}", + initialize_query + )); + } + + connection + }) + } +}