// dittolive_ditto/ditto/init/mod.rs

1use std::{
2    ffi::CString,
3    future::Future,
4    os::raw::{c_uint, c_void},
5    str::FromStr,
6    sync::{Arc, Weak},
7    time::Duration,
8};
9
10use async_fn_traits::AsyncFn2;
11use async_trait::async_trait;
12use extern_c::extern_c;
13use ffi_sdk::{FsComponent, TransportConfigMode};
14
15pub use self::config::{DittoConfig, DittoConfigConnect};
16use crate::{
17    ditto::{
18        init::config::{ActualConfig, InternalConfig},
19        DittoFields,
20    },
21    error,
22    identity::DittoAuthenticator,
23    small_peer_info::SmallPeerInfo,
24    utils::{make_continuation, prelude::*},
25    warn,
26};
27
28pub(crate) mod config;
29
30impl Ditto {
31    /// Open a new Ditto instance using a [`DittoConfig`]
32    ///
33    /// # Example
34    ///
35    /// ```
36    /// # use dittolive_ditto::prelude::*;
37    /// # async fn example() -> anyhow::Result<()> {
38    /// // Load your database ID somehow, ENV is a good option
39    /// let database_id = std::env::var("DITTO_DATABASE_ID")?;
40    ///
41    /// // Choose one of the following types of connection config
42    /// let connect = DittoConfigConnect::Server {
43    ///     url: "https://example.com/your-server-url".parse().unwrap(),
44    /// };
45    /// let connect = DittoConfigConnect::SmallPeersOnly {
46    ///     private_key: Some("https://example.com/your-server-url".bytes().collect()),
47    /// };
48    /// let connect = DittoConfigConnect::SmallPeersOnly { private_key: None };
49    ///
50    /// let config = DittoConfig::new(database_id, connect);
51    /// let ditto = Ditto::open(config).await?;
52    /// # Ok(())
53    /// # }
54    /// ```
55    pub async fn open(config: DittoConfig) -> Result<Ditto, DittoError> {
56        Self::init_sdk_version();
57        let default_root_dir = default_root_directory();
58
59        let customer_config = config;
60        let mut internal_config = InternalConfig {
61            legacy_persistence_directory: None,
62        };
63
64        // SDKS-3187: When the user hasn't set an explicit persistence directory,
65        // provide the v4 default path as a legacy fallback so that v4→v5 upgrades
66        // don't silently create a new empty database. The v4 Rust SDK used the
67        // executable's parent directory directly as the default root.
68        if customer_config.persistence_directory.is_none() {
69            if let Some(root) = &default_root_dir {
70                internal_config.legacy_persistence_directory = Some(root.clone());
71            }
72        }
73
74        let actual_config = ActualConfig {
75            customer_facing: customer_config,
76            internal: internal_config,
77        };
78
79        let config_cbor: &[u8] =
80            &serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
81        let (continuation, recv) = make_continuation();
82        let default_root_dir_ref = default_root_dir
83            .as_deref()
84            .and_then(|root| root.to_str())
85            .ok_or_else(|| {
86                DittoError::new(
87                    ErrorKind::IO,
88                    "Unable to resolve a default data directory on this platform".to_string(),
89                )
90            })?;
91        let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
92            .expect("should construct CString from no-nulls &str");
93        let default_root_dir_cstr = &*default_root_dir_cstring;
94        let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
95
96        // CLIPPY: continuation.into() needed to pass CI
97        #[allow(clippy::useless_conversion)]
98        ffi_sdk::dittoffi_ditto_open_async_throws(
99            config_cbor.into(),
100            TransportConfigMode::PlatformIndependent,
101            default_root_dir_charp,
102            continuation.into(),
103        );
104
105        let ffi_result = recv.await.unwrap();
106        let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_result.into_rust_result()?;
107
108        Self::finish_open(ffi_ditto, actual_config.customer_facing)
109    }
110
111    /// Open a new Ditto instance using a [`DittoConfig`]
112    ///
113    /// This is a synchronous blocking variant of [`Ditto::open()`] that will wait until
114    /// initialization is complete.
115    ///
116    /// # Example
117    ///
118    /// ```
119    /// # use dittolive_ditto::prelude::*;
120    /// # fn example() -> anyhow::Result<()> {
121    /// // Load your database ID somehow, ENV is a good option
122    /// let database_id = std::env::var("DITTO_DATABASE_ID")?;
123    ///
124    /// // Choose one of the following types of connection config
125    /// let connect = DittoConfigConnect::Server {
126    ///     url: "https://example.com/your-server-url".parse().unwrap(),
127    /// };
128    /// let connect = DittoConfigConnect::SmallPeersOnly {
129    ///     private_key: Some("https://example.com/your-server-url".bytes().collect()),
130    /// };
131    /// let connect = DittoConfigConnect::SmallPeersOnly { private_key: None };
132    ///
133    /// let config = DittoConfig::new(database_id, connect);
134    /// let ditto = Ditto::open_sync(config)?;
135    /// # Ok(())
136    /// # }
137    /// ```
138    pub fn open_sync(config: DittoConfig) -> Result<Ditto, DittoError> {
139        Self::init_sdk_version();
140        let default_root_dir = default_root_directory();
141
142        let customer_config = config;
143        let mut internal_config = InternalConfig {
144            legacy_persistence_directory: None,
145        };
146
147        // SDKS-3187: Same legacy fallback as open() above.
148        if customer_config.persistence_directory.is_none() {
149            if let Some(root) = &default_root_dir {
150                internal_config.legacy_persistence_directory = Some(root.clone());
151            }
152        }
153
154        let actual_config = ActualConfig {
155            customer_facing: customer_config,
156            internal: internal_config,
157        };
158
159        let config_cbor: &[u8] =
160            &serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
161        let default_root_dir_ref = default_root_dir
162            .as_deref()
163            .and_then(|root| root.to_str())
164            .ok_or_else(|| {
165                DittoError::new(
166                    ErrorKind::IO,
167                    "Unable to resolve a default data directory on this platform".to_string(),
168                )
169            })?;
170        let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
171            .expect("should construct CString from no-nulls &str");
172        let default_root_dir_cstr = &*default_root_dir_cstring;
173        let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
174
175        let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_sdk::dittoffi_ditto_open_throws(
176            config_cbor.into(),
177            TransportConfigMode::PlatformIndependent,
178            default_root_dir_charp,
179        )
180        .into_rust_result()?;
181
182        Self::finish_open(ffi_ditto, actual_config.customer_facing)
183    }
184
185    fn finish_open(
186        ffi_ditto: repr_c::Box<ffi_sdk::Ditto>,
187        config: DittoConfig,
188    ) -> Result<Ditto, DittoError> {
189        let ditto: Arc<repr_c::Box<ffi_sdk::Ditto>> = Arc::new(ffi_ditto);
190        let has_auth = matches!(&config.connect, DittoConfigConnect::Server { .. });
191
192        let disk_usage = DiskUsage::new(ditto.retain(), FsComponent::Root);
193        let small_peer_info = SmallPeerInfo::new(ditto.retain());
194        let fields = Arc::new_cyclic(|weak_fields: &arc::Weak<_>| {
195            let store = Store::new(ditto.retain(), weak_fields.clone());
196            let sync = crate::sync::Sync::new(weak_fields.clone());
197            let presence = Arc::new(Presence::new(weak_fields.clone()));
198
199            DittoFields {
200                ditto: ditto.retain(),
201                has_auth,
202                config,
203                store,
204                sync,
205                presence,
206                disk_usage,
207                small_peer_info,
208            }
209        });
210
211        let ditto = Ditto {
212            fields,
213            is_shut_down_able: true,
214        };
215
216        Ok(ditto)
217    }
218}
219
/// Returns the directory that contains the currently running executable, if it can be resolved.
///
/// This is used as the SDK's default root directory. The v4 Rust SDK used this same directory
/// directly as its default persistence root.
fn default_root_directory() -> Option<PathBuf> {
    let executable = std::env::current_exe().ok()?;
    let containing_dir = executable.parent()?;
    Some(containing_dir.to_path_buf())
}
225
226impl DittoAuthenticator {
227    /// Set a callback to notify the client when the authentication is expiring.
228    ///
229    /// When using `DittoConfigConnect::Server { .. }` mode, Ditto _requires_ you to register an
230    /// "expiration handler" for authentication. This handler is called for an initial
231    /// authentication and periodically thereafter when the current authentication is near to
232    /// expiration.
233    ///
234    /// For more details about authentication, see the [Ditto Auth and Authorization docs][0].
235    ///
236    /// [0]: https://docs.ditto.live/sdk/latest/auth-and-authorization/cloud-authentication#login
237    ///
238    /// For more details about the expiration handler, see the [`DittoAuthExpirationHandler`] trait.
239    ///
240    /// # Example
241    ///
242    /// ```
243    /// # #[allow(deprecated)]
244    /// # use tracing::error;
245    /// # use dittolive_ditto::prelude::*;
246    /// # fn main() -> anyhow::Result<()> {
247    /// # let (_root, ditto) = dittolive_ditto::doctest_helpers::doctest_online_ditto();
248    /// async fn sample_get_token() -> anyhow::Result<String> {
249    ///     // e.g. reqwest::get("https://example.com/token").await?.text().await?;
250    ///     Ok("token".to_string())
251    /// }
252    ///
253    /// let auth = ditto
254    ///     .auth()
255    ///     .expect("Auth is available in Server connect mode");
256    /// auth.set_expiration_handler(async |ditto: &Ditto, duration_remaining| {
257    ///     let auth = ditto
258    ///         .auth()
259    ///         .expect("Auth is available in Server connect mode");
260    ///
261    ///     // Call your auth service to get a new token
262    ///     let token = sample_get_token().await.unwrap();
263    ///
264    ///     // Where "my-provider" is the name of the Authentication Webhook
265    ///     // you've configured on the Ditto Portal
266    ///     let result = auth.login(&token, "my-provider");
267    ///
268    ///     if let Err(login_error) = result {
269    ///         // Handle login error, e.g.:
270    ///         error!("Failed to login: {}", login_error);
271    ///     }
272    /// });
273    /// # Ok(())
274    /// # }
275    /// ```
276    pub fn set_expiration_handler<F>(&self, handler: F)
277    where
278        F: DittoAuthExpirationHandler,
279    {
280        let Some(ditto) = self.ditto_fields.upgrade() else {
281            #[allow(deprecated)] // Workaround for patched tracing
282            {
283                error!("Failed to set expiration handler, Ditto has shut down");
284            }
285            return;
286        };
287
288        let login_provider = make_login_provider(self.ditto_fields.clone(), Arc::new(handler));
289        ffi_sdk::ditto_auth_set_login_provider(&ditto.ditto, Some(login_provider));
290    }
291
292    /// Clear the expiration handler, if set.
293    pub fn clear_expiration_handler(&self) {
294        let Some(ditto) = self.ditto_fields.upgrade() else {
295            #[allow(deprecated)] // Workaround for patched tracing
296            {
297                error!("Failed to clear expiration handler, Ditto has shut down");
298            }
299            return;
300        };
301
302        ffi_sdk::ditto_auth_set_login_provider(&ditto.ditto, None);
303    }
304}
305
306/// Trait describing types which can be used as an authentication expiration handler for Ditto.
307///
308/// When using `DittoConfigConnect::Server { .. }` mode, Ditto _requires_ you to register an
309/// "expiration handler" for authentication. This handler is called for an initial authentication
310/// and periodically thereafter when the current authentication is near expiration.
311///
312/// This trait is implemented for async closures and for functions returning an
313/// `impl Future<Output = ()>`, which take the expected arguments of an auth expiration handler.
314///
315/// Expiration handlers are expected to call `auth.login(...)` with a valid authentication token.
316///
317/// For more details about authentication, see the [Ditto Auth and Authorization docs][0].
318///
319/// [0]: https://docs.ditto.live/sdk/latest/auth-and-authorization/cloud-authentication#login
320///
321/// **NOTE**: Because expiration handlers are asynchronous, they must not block the thread.
322///
323/// # Example
324///
325/// ```
326/// # use std::time::Duration;
327/// # use dittolive_ditto::prelude::*;
328/// # let (_root, ditto) = dittolive_ditto::doctest_helpers::doctest_online_ditto();
329/// let auth = ditto
330///     .auth()
331///     .expect("Auth is available for DittoConfigConnect::Server mode");
332///
333/// // Option 1: Use an async closure
334/// auth.set_expiration_handler(async |ditto: &Ditto, duration| {
335///     // Your authentication handler code here
336/// });
337///
338/// // Option 2: Use a closure returning an async block
339/// auth.set_expiration_handler(|ditto: &Ditto, duration| async {
340///     // Your authentication handler code here
341/// });
342///
343/// // Option 3: Use a custom type and trait impl
344/// struct MyAuthHandler;
345/// impl DittoAuthExpirationHandler for MyAuthHandler {
346///     async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
347///         // Your authentication handler code here
348///     }
349/// }
350/// auth.set_expiration_handler(MyAuthHandler);
351/// ```
352pub trait DittoAuthExpirationHandler: 'static + Send + Sync {
353    /// Provide an async handler that will be called when the authentication is expiring.
354    fn on_expiration(
355        &self,
356        ditto: &Ditto,
357        duration_remaining: Duration,
358    ) -> impl Send + Future<Output = ()>;
359}
360
361impl<F> DittoAuthExpirationHandler for F
362where
363    F: 'static + Send + Sync,
364    F: for<'r> AsyncFn2<&'r Ditto, Duration, Output = (), OutputFuture: Send>,
365{
366    async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
367        self(ditto, duration_remaining).await
368    }
369}
370
371#[async_trait]
372pub(crate) trait DynDittoAuthExpirationHandler: 'static + Send + Sync {
373    async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration);
374}
375
376#[async_trait]
377impl<F: DittoAuthExpirationHandler> DynDittoAuthExpirationHandler for F {
378    async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
379        self.on_expiration(ditto, duration_remaining).await
380    }
381}
382
383impl DittoAuthExpirationHandler for dyn '_ + DynDittoAuthExpirationHandler {
384    async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
385        self.dyn_on_expiration(ditto, duration_remaining).await
386    }
387}
388
389pub(crate) fn make_login_provider(
390    ditto_fields: Weak<DittoFields>,
391    auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
392) -> repr_c::Box<ffi_sdk::LoginProvider> {
393    struct LoginProviderCtx {
394        ditto_fields: Weak<DittoFields>,
395        auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
396    }
397
398    let login_provider_ctx = Arc::new(LoginProviderCtx {
399        auth_expiration_handler,
400        ditto_fields,
401    });
402
403    let ffi_ctx = Arc::as_ptr(&login_provider_ctx) as *mut c_void;
404    let ffi_retain = Some(extern_c(|ctx: *mut c_void| unsafe {
405        Arc::<LoginProviderCtx>::increment_strong_count(ctx.cast())
406    }) as unsafe extern "C" fn(_));
407    let ffi_release = Some(extern_c(|ctx: *mut c_void| unsafe {
408        Arc::<LoginProviderCtx>::decrement_strong_count(ctx.cast())
409    }) as unsafe extern "C" fn(_));
410
411    // This callback is just a "trigger" to initiate the caller's authentication handler
412    // The handler is async and spawned in a separate task
413    let ffi_handler = extern_c(|ctx: *mut c_void, secs_remaining: c_uint| {
414        let login_provider_ctx: &LoginProviderCtx = unsafe { &*ctx.cast() };
415        let auth_expiration_handler = login_provider_ctx.auth_expiration_handler.retain();
416        let Ok(ditto) = Ditto::upgrade(&login_provider_ctx.ditto_fields) else {
417            #[allow(deprecated)] // Workaround for patched tracing
418            {
419                error!("Failed to dispatch auth handler, Ditto has been shut down");
420            }
421            return;
422        };
423
424        dispatch_auth_handler(
425            auth_expiration_handler,
426            ditto,
427            Duration::from_secs(secs_remaining.into()),
428        );
429    });
430
431    unsafe {
432        ffi_sdk::ditto_auth_client_make_login_provider(
433            ffi_ctx,
434            ffi_retain,
435            ffi_release,
436            ffi_handler,
437        )
438    }
439}
440
441/// Dispatches the users's authentication expiration handler.
442///
443/// This function checks whether the current execution context is within a tokio runtime.
444///
445/// - If a tokio runtime is available, the handler is spawned onto a task.
446/// - If no tokio runtime is available, a temporary current-thread runtime is created and the
447///   dispatcher will block until the handler completes.
448///
449/// # Panics
450///
451/// This function will panic if:
452///
453/// - There is no current tokio runtime available, and
454/// - We fail to create a temporary current-thread runtime
455fn dispatch_auth_handler(
456    auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
457    ditto: Ditto,
458    duration_remaining: Duration,
459) {
460    // Check if we're in a tokio runtime context
461    match tokio::runtime::Handle::try_current() {
462        Ok(handle) => {
463            // We have a tokio runtime, spawn the handler as usual
464            handle.spawn(async move {
465                auth_expiration_handler
466                    .on_expiration(&ditto, duration_remaining)
467                    .await;
468            });
469        }
470        Err(_) => {
471            // No tokio runtime available, create a temporary current-thread runtime
472            #[allow(deprecated)] // Workaround for patched tracing
473            {
474                warn!(
475                    "No tokio runtime available for expiration handler. Creating temporary \
476                     runtime."
477                );
478            }
479
480            match tokio::runtime::Builder::new_current_thread()
481                .enable_all()
482                .build()
483            {
484                Ok(rt) => {
485                    rt.block_on(async move {
486                        auth_expiration_handler
487                            .on_expiration(&ditto, duration_remaining)
488                            .await;
489                    });
490                }
491                Err(e) => {
492                    panic!(
493                        "Failed to create tokio runtime for expiration handler: {}. Consider \
494                         running within a tokio runtime context.",
495                        e
496                    );
497                }
498            }
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    // Sanity checks for the runtime-detection strategy used by `dispatch_auth_handler`:
    // spawn onto an ambient tokio runtime if there is one, otherwise build a temporary
    // current-thread runtime.

    /// A plain `#[test]` runs outside any tokio runtime, so the dispatcher would take the
    /// fallback path. That fallback runtime must be constructible.
    #[test]
    fn test_runtime_detection_behavior() {
        assert!(
            tokio::runtime::Handle::try_current().is_err(),
            "a plain #[test] should not run inside a tokio runtime"
        );

        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("should be able to create a temporary current-thread runtime");
    }

    /// Inside `#[tokio::test]` the ambient runtime must be detected, so handlers get spawned.
    #[tokio::test]
    async fn test_with_tokio_runtime_available() {
        assert!(
            tokio::runtime::Handle::try_current().is_ok(),
            "Expected tokio runtime to be available in tokio::test"
        );
    }

    /// The same kind of temporary runtime the dispatcher builds must drive async work to
    /// completion.
    #[test]
    fn test_improved_error_handling() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("temporary runtime creation should succeed");

        let answer = rt.block_on(async { 40 + 2 });
        assert_eq!(answer, 42);
    }
}