dittolive_ditto/ditto/init/mod.rs
1use std::{
2 ffi::CString,
3 future::Future,
4 os::raw::{c_uint, c_void},
5 str::FromStr,
6 sync::{Arc, Weak},
7 time::Duration,
8};
9
10use async_fn_traits::AsyncFn2;
11use async_trait::async_trait;
12use extern_c::extern_c;
13use ffi_sdk::{FsComponent, TransportConfigMode};
14
15pub use self::config::{DittoConfig, DittoConfigConnect};
16use crate::{
17 ditto::{
18 init::config::{ActualConfig, InternalConfig},
19 DittoFields,
20 },
21 error,
22 identity::DittoAuthenticator,
23 small_peer_info::SmallPeerInfo,
24 utils::{make_continuation, prelude::*},
25 warn,
26};
27
28pub(crate) mod config;
29
30impl Ditto {
31 /// Open a new Ditto instance using a [`DittoConfig`]
32 ///
33 /// # Example
34 ///
35 /// ```
36 /// # use dittolive_ditto::prelude::*;
37 /// # async fn example() -> anyhow::Result<()> {
38 /// // Load your database ID somehow, ENV is a good option
39 /// let database_id = std::env::var("DITTO_DATABASE_ID")?;
40 ///
41 /// // Choose one of the following types of connection config
42 /// let connect = DittoConfigConnect::Server {
43 /// url: "https://example.com/your-server-url".parse().unwrap(),
44 /// };
45 /// let connect = DittoConfigConnect::SmallPeersOnly {
46 /// private_key: Some("https://example.com/your-server-url".bytes().collect()),
47 /// };
48 /// let connect = DittoConfigConnect::SmallPeersOnly { private_key: None };
49 ///
50 /// let config = DittoConfig::new(database_id, connect);
51 /// let ditto = Ditto::open(config).await?;
52 /// # Ok(())
53 /// # }
54 /// ```
55 pub async fn open(config: DittoConfig) -> Result<Ditto, DittoError> {
56 Self::init_sdk_version();
57 let default_root_dir = default_root_directory();
58
59 let customer_config = config;
60 let mut internal_config = InternalConfig {
61 legacy_persistence_directory: None,
62 };
63
64 // SDKS-3187: When the user hasn't set an explicit persistence directory,
65 // provide the v4 default path as a legacy fallback so that v4→v5 upgrades
66 // don't silently create a new empty database. The v4 Rust SDK used the
67 // executable's parent directory directly as the default root.
68 if customer_config.persistence_directory.is_none() {
69 if let Some(root) = &default_root_dir {
70 internal_config.legacy_persistence_directory = Some(root.clone());
71 }
72 }
73
74 let actual_config = ActualConfig {
75 customer_facing: customer_config,
76 internal: internal_config,
77 };
78
79 let config_cbor: &[u8] =
80 &serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
81 let (continuation, recv) = make_continuation();
82 let default_root_dir_ref = default_root_dir
83 .as_deref()
84 .and_then(|root| root.to_str())
85 .ok_or_else(|| {
86 DittoError::new(
87 ErrorKind::IO,
88 "Unable to resolve a default data directory on this platform".to_string(),
89 )
90 })?;
91 let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
92 .expect("should construct CString from no-nulls &str");
93 let default_root_dir_cstr = &*default_root_dir_cstring;
94 let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
95
96 // CLIPPY: continuation.into() needed to pass CI
97 #[allow(clippy::useless_conversion)]
98 ffi_sdk::dittoffi_ditto_open_async_throws(
99 config_cbor.into(),
100 TransportConfigMode::PlatformIndependent,
101 default_root_dir_charp,
102 continuation.into(),
103 );
104
105 let ffi_result = recv.await.unwrap();
106 let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_result.into_rust_result()?;
107
108 Self::finish_open(ffi_ditto, actual_config.customer_facing)
109 }
110
111 /// Open a new Ditto instance using a [`DittoConfig`]
112 ///
113 /// This is a synchronous blocking variant of [`Ditto::open()`] that will wait until
114 /// initialization is complete.
115 ///
116 /// # Example
117 ///
118 /// ```
119 /// # use dittolive_ditto::prelude::*;
120 /// # fn example() -> anyhow::Result<()> {
121 /// // Load your database ID somehow, ENV is a good option
122 /// let database_id = std::env::var("DITTO_DATABASE_ID")?;
123 ///
124 /// // Choose one of the following types of connection config
125 /// let connect = DittoConfigConnect::Server {
126 /// url: "https://example.com/your-server-url".parse().unwrap(),
127 /// };
128 /// let connect = DittoConfigConnect::SmallPeersOnly {
129 /// private_key: Some("https://example.com/your-server-url".bytes().collect()),
130 /// };
131 /// let connect = DittoConfigConnect::SmallPeersOnly { private_key: None };
132 ///
133 /// let config = DittoConfig::new(database_id, connect);
134 /// let ditto = Ditto::open_sync(config)?;
135 /// # Ok(())
136 /// # }
137 /// ```
138 pub fn open_sync(config: DittoConfig) -> Result<Ditto, DittoError> {
139 Self::init_sdk_version();
140 let default_root_dir = default_root_directory();
141
142 let customer_config = config;
143 let mut internal_config = InternalConfig {
144 legacy_persistence_directory: None,
145 };
146
147 // SDKS-3187: Same legacy fallback as open() above.
148 if customer_config.persistence_directory.is_none() {
149 if let Some(root) = &default_root_dir {
150 internal_config.legacy_persistence_directory = Some(root.clone());
151 }
152 }
153
154 let actual_config = ActualConfig {
155 customer_facing: customer_config,
156 internal: internal_config,
157 };
158
159 let config_cbor: &[u8] =
160 &serde_cbor::to_vec(&actual_config).expect("should serialize well-formed DittoConfig");
161 let default_root_dir_ref = default_root_dir
162 .as_deref()
163 .and_then(|root| root.to_str())
164 .ok_or_else(|| {
165 DittoError::new(
166 ErrorKind::IO,
167 "Unable to resolve a default data directory on this platform".to_string(),
168 )
169 })?;
170 let default_root_dir_cstring = CString::from_str(default_root_dir_ref)
171 .expect("should construct CString from no-nulls &str");
172 let default_root_dir_cstr = &*default_root_dir_cstring;
173 let default_root_dir_charp: char_p::Ref<'_> = default_root_dir_cstr.into();
174
175 let ffi_ditto: repr_c::Box<ffi_sdk::Ditto> = ffi_sdk::dittoffi_ditto_open_throws(
176 config_cbor.into(),
177 TransportConfigMode::PlatformIndependent,
178 default_root_dir_charp,
179 )
180 .into_rust_result()?;
181
182 Self::finish_open(ffi_ditto, actual_config.customer_facing)
183 }
184
185 fn finish_open(
186 ffi_ditto: repr_c::Box<ffi_sdk::Ditto>,
187 config: DittoConfig,
188 ) -> Result<Ditto, DittoError> {
189 let ditto: Arc<repr_c::Box<ffi_sdk::Ditto>> = Arc::new(ffi_ditto);
190 let has_auth = matches!(&config.connect, DittoConfigConnect::Server { .. });
191
192 let disk_usage = DiskUsage::new(ditto.retain(), FsComponent::Root);
193 let small_peer_info = SmallPeerInfo::new(ditto.retain());
194 let fields = Arc::new_cyclic(|weak_fields: &arc::Weak<_>| {
195 let store = Store::new(ditto.retain(), weak_fields.clone());
196 let sync = crate::sync::Sync::new(weak_fields.clone());
197 let presence = Arc::new(Presence::new(weak_fields.clone()));
198
199 DittoFields {
200 ditto: ditto.retain(),
201 has_auth,
202 config,
203 store,
204 sync,
205 presence,
206 disk_usage,
207 small_peer_info,
208 }
209 });
210
211 let ditto = Ditto {
212 fields,
213 is_shut_down_able: true,
214 };
215
216 Ok(ditto)
217 }
218}
219
/// Returns the directory containing the currently running executable.
///
/// This is the SDK's default root directory. It yields `None` if the executable path cannot be
/// determined or has no parent component.
fn default_root_directory() -> Option<PathBuf> {
    let exe_path = std::env::current_exe().ok()?;
    let exe_dir = exe_path.parent()?;
    Some(exe_dir.to_path_buf())
}
225
226impl DittoAuthenticator {
227 /// Set a callback to notify the client when the authentication is expiring.
228 ///
229 /// When using `DittoConfigConnect::Server { .. }` mode, Ditto _requires_ you to register an
230 /// "expiration handler" for authentication. This handler is called for an initial
231 /// authentication and periodically thereafter when the current authentication is near to
232 /// expiration.
233 ///
234 /// For more details about authentication, see the [Ditto Auth and Authorization docs][0].
235 ///
236 /// [0]: https://docs.ditto.live/sdk/latest/auth-and-authorization/cloud-authentication#login
237 ///
238 /// For more details about the expiration handler, see the [`DittoAuthExpirationHandler`] trait.
239 ///
240 /// # Example
241 ///
242 /// ```
243 /// # #[allow(deprecated)]
244 /// # use tracing::error;
245 /// # use dittolive_ditto::prelude::*;
246 /// # fn main() -> anyhow::Result<()> {
247 /// # let (_root, ditto) = dittolive_ditto::doctest_helpers::doctest_online_ditto();
248 /// async fn sample_get_token() -> anyhow::Result<String> {
249 /// // e.g. reqwest::get("https://example.com/token").await?.text().await?;
250 /// Ok("token".to_string())
251 /// }
252 ///
253 /// let auth = ditto
254 /// .auth()
255 /// .expect("Auth is available in Server connect mode");
256 /// auth.set_expiration_handler(async |ditto: &Ditto, duration_remaining| {
257 /// let auth = ditto
258 /// .auth()
259 /// .expect("Auth is available in Server connect mode");
260 ///
261 /// // Call your auth service to get a new token
262 /// let token = sample_get_token().await.unwrap();
263 ///
264 /// // Where "my-provider" is the name of the Authentication Webhook
265 /// // you've configured on the Ditto Portal
266 /// let result = auth.login(&token, "my-provider");
267 ///
268 /// if let Err(login_error) = result {
269 /// // Handle login error, e.g.:
270 /// error!("Failed to login: {}", login_error);
271 /// }
272 /// });
273 /// # Ok(())
274 /// # }
275 /// ```
276 pub fn set_expiration_handler<F>(&self, handler: F)
277 where
278 F: DittoAuthExpirationHandler,
279 {
280 let Some(ditto) = self.ditto_fields.upgrade() else {
281 #[allow(deprecated)] // Workaround for patched tracing
282 {
283 error!("Failed to set expiration handler, Ditto has shut down");
284 }
285 return;
286 };
287
288 let login_provider = make_login_provider(self.ditto_fields.clone(), Arc::new(handler));
289 ffi_sdk::ditto_auth_set_login_provider(&ditto.ditto, Some(login_provider));
290 }
291
292 /// Clear the expiration handler, if set.
293 pub fn clear_expiration_handler(&self) {
294 let Some(ditto) = self.ditto_fields.upgrade() else {
295 #[allow(deprecated)] // Workaround for patched tracing
296 {
297 error!("Failed to clear expiration handler, Ditto has shut down");
298 }
299 return;
300 };
301
302 ffi_sdk::ditto_auth_set_login_provider(&ditto.ditto, None);
303 }
304}
305
306/// Trait describing types which can be used as an authentication expiration handler for Ditto.
307///
308/// When using `DittoConfigConnect::Server { .. }` mode, Ditto _requires_ you to register an
309/// "expiration handler" for authentication. This handler is called for an initial authentication
310/// and periodically thereafter when the current authentication is near expiration.
311///
312/// This trait is implemented for async closures and for functions returning an
313/// `impl Future<Output = ()>`, which take the expected arguments of an auth expiration handler.
314///
315/// Expiration handlers are expected to call `auth.login(...)` with a valid authentication token.
316///
317/// For more details about authentication, see the [Ditto Auth and Authorization docs][0].
318///
319/// [0]: https://docs.ditto.live/sdk/latest/auth-and-authorization/cloud-authentication#login
320///
321/// **NOTE**: Because expiration handlers are asynchronous, they must not block the thread.
322///
323/// # Example
324///
325/// ```
326/// # use std::time::Duration;
327/// # use dittolive_ditto::prelude::*;
328/// # let (_root, ditto) = dittolive_ditto::doctest_helpers::doctest_online_ditto();
329/// let auth = ditto
330/// .auth()
331/// .expect("Auth is available for DittoConfigConnect::Server mode");
332///
333/// // Option 1: Use an async closure
334/// auth.set_expiration_handler(async |ditto: &Ditto, duration| {
335/// // Your authentication handler code here
336/// });
337///
338/// // Option 2: Use a closure returning an async block
339/// auth.set_expiration_handler(|ditto: &Ditto, duration| async {
340/// // Your authentication handler code here
341/// });
342///
343/// // Option 3: Use a custom type and trait impl
344/// struct MyAuthHandler;
345/// impl DittoAuthExpirationHandler for MyAuthHandler {
346/// async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
347/// // Your authentication handler code here
348/// }
349/// }
350/// auth.set_expiration_handler(MyAuthHandler);
351/// ```
352pub trait DittoAuthExpirationHandler: 'static + Send + Sync {
353 /// Provide an async handler that will be called when the authentication is expiring.
354 fn on_expiration(
355 &self,
356 ditto: &Ditto,
357 duration_remaining: Duration,
358 ) -> impl Send + Future<Output = ()>;
359}
360
361impl<F> DittoAuthExpirationHandler for F
362where
363 F: 'static + Send + Sync,
364 F: for<'r> AsyncFn2<&'r Ditto, Duration, Output = (), OutputFuture: Send>,
365{
366 async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
367 self(ditto, duration_remaining).await
368 }
369}
370
371#[async_trait]
372pub(crate) trait DynDittoAuthExpirationHandler: 'static + Send + Sync {
373 async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration);
374}
375
376#[async_trait]
377impl<F: DittoAuthExpirationHandler> DynDittoAuthExpirationHandler for F {
378 async fn dyn_on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
379 self.on_expiration(ditto, duration_remaining).await
380 }
381}
382
383impl DittoAuthExpirationHandler for dyn '_ + DynDittoAuthExpirationHandler {
384 async fn on_expiration(&self, ditto: &Ditto, duration_remaining: Duration) {
385 self.dyn_on_expiration(ditto, duration_remaining).await
386 }
387}
388
/// Builds the FFI `LoginProvider` that forwards the core's auth-expiration notifications to
/// `auth_expiration_handler`.
///
/// The provider carries a type-erased context (`LoginProviderCtx`) holding a weak reference to
/// the Ditto fields and the user's handler. Each time the core fires the callback, the Ditto
/// instance is upgraded and the handler is passed to `dispatch_auth_handler`. If Ditto has already
/// shut down, an error is logged and nothing is dispatched.
pub(crate) fn make_login_provider(
    ditto_fields: Weak<DittoFields>,
    auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
) -> repr_c::Box<ffi_sdk::LoginProvider> {
    // State shared with the core through the opaque `ctx` pointer.
    struct LoginProviderCtx {
        // Held weakly: the callback upgrades it on each invocation and bails out once Ditto is
        // gone.
        ditto_fields: Weak<DittoFields>,
        auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
    }

    let login_provider_ctx = Arc::new(LoginProviderCtx {
        auth_expiration_handler,
        ditto_fields,
    });

    // Borrowed (not leaked) pointer into the `Arc` allocation. The local `login_provider_ctx`
    // Arc is dropped when this function returns, so the context outlives this call only if the
    // core takes its own reference through `ffi_retain`.
    // NOTE(review): confirm `ditto_auth_client_make_login_provider` retains `ffi_ctx` before
    // returning (rather than adopting the caller's reference); otherwise this is a
    // use-after-free once this function returns.
    let ffi_ctx = Arc::as_ptr(&login_provider_ctx) as *mut c_void;
    // Reference-counting hooks the core uses to manage its own references to the context.
    // SAFETY (assumed): the core only ever passes back `ffi_ctx`, which points into a live
    // `Arc<LoginProviderCtx>` allocation.
    let ffi_retain = Some(extern_c(|ctx: *mut c_void| unsafe {
        Arc::<LoginProviderCtx>::increment_strong_count(ctx.cast())
    }) as unsafe extern "C" fn(_));
    let ffi_release = Some(extern_c(|ctx: *mut c_void| unsafe {
        Arc::<LoginProviderCtx>::decrement_strong_count(ctx.cast())
    }) as unsafe extern "C" fn(_));

    // This callback is just a "trigger" for the caller's authentication handler.
    // `dispatch_auth_handler` either spawns the handler on the current tokio runtime or, when
    // there is none, blocks this callback until the handler completes on a temporary runtime.
    // Because this closure is turned into an `extern "C"` function, it must not panic.
    let ffi_handler = extern_c(|ctx: *mut c_void, secs_remaining: c_uint| {
        // SAFETY (assumed): `ctx` is `ffi_ctx`, kept alive by the core's retained reference for
        // the duration of this callback.
        let login_provider_ctx: &LoginProviderCtx = unsafe { &*ctx.cast() };
        // Owned handle to the handler so it can be moved into the dispatched future.
        // NOTE(review): `retain()` presumably clones the `Arc` (prelude extension) — confirm.
        let auth_expiration_handler = login_provider_ctx.auth_expiration_handler.retain();
        let Ok(ditto) = Ditto::upgrade(&login_provider_ctx.ditto_fields) else {
            #[allow(deprecated)] // Workaround for patched tracing
            {
                error!("Failed to dispatch auth handler, Ditto has been shut down");
            }
            return;
        };

        // The core reports the remaining time in whole seconds as a `c_uint`; widen to `u64`.
        dispatch_auth_handler(
            auth_expiration_handler,
            ditto,
            Duration::from_secs(secs_remaining.into()),
        );
    });

    // SAFETY (assumed): `ffi_ctx` is a valid context pointer and `ffi_retain`/`ffi_release`
    // correctly manage its `Arc` strong count; see the NOTE(review) on `ffi_ctx` about ownership.
    unsafe {
        ffi_sdk::ditto_auth_client_make_login_provider(
            ffi_ctx,
            ffi_retain,
            ffi_release,
            ffi_handler,
        )
    }
}
440
441/// Dispatches the users's authentication expiration handler.
442///
443/// This function checks whether the current execution context is within a tokio runtime.
444///
445/// - If a tokio runtime is available, the handler is spawned onto a task.
446/// - If no tokio runtime is available, a temporary current-thread runtime is created and the
447/// dispatcher will block until the handler completes.
448///
449/// # Panics
450///
451/// This function will panic if:
452///
453/// - There is no current tokio runtime available, and
454/// - We fail to create a temporary current-thread runtime
455fn dispatch_auth_handler(
456 auth_expiration_handler: Arc<dyn DynDittoAuthExpirationHandler>,
457 ditto: Ditto,
458 duration_remaining: Duration,
459) {
460 // Check if we're in a tokio runtime context
461 match tokio::runtime::Handle::try_current() {
462 Ok(handle) => {
463 // We have a tokio runtime, spawn the handler as usual
464 handle.spawn(async move {
465 auth_expiration_handler
466 .on_expiration(&ditto, duration_remaining)
467 .await;
468 });
469 }
470 Err(_) => {
471 // No tokio runtime available, create a temporary current-thread runtime
472 #[allow(deprecated)] // Workaround for patched tracing
473 {
474 warn!(
475 "No tokio runtime available for expiration handler. Creating temporary \
476 runtime."
477 );
478 }
479
480 match tokio::runtime::Builder::new_current_thread()
481 .enable_all()
482 .build()
483 {
484 Ok(rt) => {
485 rt.block_on(async move {
486 auth_expiration_handler
487 .on_expiration(&ditto, duration_remaining)
488 .await;
489 });
490 }
491 Err(e) => {
492 panic!(
493 "Failed to create tokio runtime for expiration handler: {}. Consider \
494 running within a tokio runtime context.",
495 e
496 );
497 }
498 }
499 }
500 }
501}
502
#[cfg(test)]
mod tests {
    //! Sanity checks for the runtime-detection strategy used by `dispatch_auth_handler`.

    /// Outside any tokio runtime, `Handle::try_current()` must fail. This is the condition that
    /// routes `dispatch_auth_handler` onto its temporary-runtime fallback.
    #[test]
    fn test_runtime_detection_behavior() {
        assert!(
            tokio::runtime::Handle::try_current().is_err(),
            "a plain #[test] should not run inside a tokio runtime context"
        );
    }

    /// Inside a tokio runtime, `Handle::try_current()` must succeed, so handlers are spawned
    /// onto the existing runtime.
    #[tokio::test]
    async fn test_with_tokio_runtime_available() {
        assert!(
            tokio::runtime::Handle::try_current().is_ok(),
            "Expected tokio runtime to be available in tokio::test"
        );
    }

    /// The fallback runtime built by `dispatch_auth_handler` must be constructible and able to
    /// drive a future to completion.
    #[test]
    fn test_improved_error_handling() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("should be able to build a temporary current-thread runtime");

        let value = rt.block_on(async { 40 + 2 });
        assert_eq!(value, 42);
    }
}