Skip to main content

brows3r_lib/media_server/
mod.rs

1//! Loopback media server — streams S3 objects over `http://127.0.0.1:<port>`.
2//!
3//! # Overview
4//!
5//! At app start, [`start_on_localhost`] spawns an `axum` HTTP server bound to
6//! `127.0.0.1:0` (OS-assigned port).  The caller receives a
7//! [`MediaServerHandle`] that carries the port, the shared
8//! [`TokenRegistryHandle`], and a shutdown sender.
9//!
10//! The server exposes two routes:
11//!
12//! - `GET /m/{token}` — validates the token, then streams the S3 object (with
13//!   optional byte-range support so that `<video>` seeking works).
14//! - `GET /healthz` — returns `200 OK`; used by diagnostics and tests.
15//!
16//! # Range support
17//!
18//! The server forwards a `Range: bytes=START-END` header to the S3
19//! `get_object` call and returns a `206 Partial Content` response with the
20//! matching `Content-Range` header.
21//!
22//! # OCP
23//!
24//! - Adding a new route is one `.route(...)` call in `build_router`.
25//! - The range parser ([`parse_range`]) is a pure function, independently
26//!   testable, and reusable by other routes.
27//! - [`TokenRegistry`] is decoupled from the server; swapping the storage
28//!   backend requires only changing [`TokenRegistryHandle`].
29
30pub mod tokens;
31
32use std::sync::{Arc, Mutex};
33
34use axum::{
35    body::Body,
36    extract::{Path, State as AxumState},
37    http::{HeaderMap, StatusCode},
38    response::{IntoResponse, Response},
39    routing::get,
40    Router,
41};
42
43use tokio::{net::TcpListener, sync::oneshot};
44use tokio_util::io::ReaderStream;
45
46pub use tokens::{TokenRegistry, TokenRegistryHandle};
47
48use crate::{error::AppError, s3::S3ClientPoolHandle};
49
50// ---------------------------------------------------------------------------
51// RangeSpec — parsed byte range
52// ---------------------------------------------------------------------------
53
/// Byte range requested through an HTTP `Range: bytes=…` header.
///
/// There is one variant per single-range form produced by [`parse_range`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RangeSpec {
    /// `bytes=START-END` — a closed interval; `start` and `end` are both
    /// inclusive byte offsets.
    Bounded { start: u64, end: u64 },
    /// `bytes=START-` — everything from offset `start` through the end of
    /// the file.
    From { start: u64 },
    /// `bytes=-SUFFIX` — only the trailing `last` bytes.
    Suffix { last: u64 },
}
64
65/// Parse `Range: bytes=<spec>` into a [`RangeSpec`].
66///
67/// Returns `None` for absent, malformed, or non-bytes range headers.
68pub fn parse_range(header: &str) -> Option<RangeSpec> {
69    let spec = header.strip_prefix("bytes=")?;
70    if let Some(last_str) = spec.strip_prefix('-') {
71        // bytes=-N  (suffix)
72        let last: u64 = last_str.parse().ok()?;
73        return Some(RangeSpec::Suffix { last });
74    }
75    let mut parts = spec.splitn(2, '-');
76    let start_str = parts.next()?;
77    let end_str = parts.next()?;
78    let start: u64 = start_str.parse().ok()?;
79    if end_str.is_empty() {
80        // bytes=START-
81        Some(RangeSpec::From { start })
82    } else {
83        // bytes=START-END
84        let end: u64 = end_str.parse().ok()?;
85        Some(RangeSpec::Bounded { start, end })
86    }
87}
88
89/// Convert a [`RangeSpec`] into the `Range` header value forwarded to S3.
90fn range_spec_to_s3(spec: &RangeSpec) -> String {
91    match spec {
92        RangeSpec::Bounded { start, end } => format!("bytes={start}-{end}"),
93        RangeSpec::From { start } => format!("bytes={start}-"),
94        RangeSpec::Suffix { last } => format!("bytes=-{last}"),
95    }
96}
97
98// ---------------------------------------------------------------------------
99// AppState — shared across all axum handlers
100// ---------------------------------------------------------------------------
101
102#[derive(Clone)]
103struct AppState {
104    registry: TokenRegistryHandle,
105    pool: S3ClientPoolHandle,
106}
107
108// ---------------------------------------------------------------------------
109// Handlers
110// ---------------------------------------------------------------------------
111
112/// `GET /healthz` — returns 200 OK for diagnostics.
113async fn healthz() -> impl IntoResponse {
114    (StatusCode::OK, "ok")
115}
116
117/// `GET /m/:token` — validate token and stream the S3 object.
118async fn serve_media(
119    Path(token): Path<String>,
120    headers: HeaderMap,
121    AxumState(state): AxumState<AppState>,
122) -> Response {
123    // 1. Token lookup — distinguish unknown (404) from expired (403).
124    let record = match state.registry.lookup_with_status(&token) {
125        Err(()) => {
126            return (StatusCode::NOT_FOUND, "token not found").into_response();
127        }
128        Ok(None) => {
129            return (StatusCode::FORBIDDEN, "token expired").into_response();
130        }
131        Ok(Some(r)) => r,
132    };
133
134    // 2. Parse optional Range header.
135    let range_spec = headers
136        .get(axum::http::header::RANGE)
137        .and_then(|v| v.to_str().ok())
138        .and_then(parse_range);
139
140    // 3. Build S3 client and call get_object.
141    let client = match state
142        .pool
143        .inner
144        .get_or_build(&record.profile_id, &record.region)
145        .await
146    {
147        Some(c) => c,
148        None => {
149            eprintln!(
150                "[media_server] no S3 client for profile {}",
151                record.profile_id
152            );
153            return (StatusCode::INTERNAL_SERVER_ERROR, "s3 client unavailable").into_response();
154        }
155    };
156
157    let mut req = client
158        .get_object()
159        .bucket(record.bucket.as_str())
160        .key(&record.key);
161
162    if let Some(ref spec) = range_spec {
163        req = req.range(range_spec_to_s3(spec));
164    }
165
166    let output = match req.send().await {
167        Ok(o) => o,
168        Err(e) => {
169            eprintln!("[media_server] get_object error: {e}");
170            return (StatusCode::BAD_GATEWAY, "s3 error").into_response();
171        }
172    };
173
174    // 4. Build response.
175    let content_type = output
176        .content_type()
177        .unwrap_or("application/octet-stream")
178        .to_owned();
179
180    let content_length = output.content_length();
181    let content_range = output.content_range().map(|s| s.to_owned());
182
183    let stream = ReaderStream::new(output.body.into_async_read());
184    let body = Body::from_stream(stream);
185
186    let status = if range_spec.is_some() {
187        StatusCode::PARTIAL_CONTENT
188    } else {
189        StatusCode::OK
190    };
191
192    // CORS: the loopback origin (127.0.0.1:<port>) is different from the
193    // WebView origin (localhost:1420 in dev, tauri://localhost in release).
194    // <img>/<video>/<audio>/<iframe> with crossorigin-less src bypass CORS,
195    // but pdf.js uses fetch() under the hood and the browser rejects the
196    // response without an explicit Access-Control-Allow-Origin header.
197    //
198    // We mint loopback URLs ourselves and the tokens are unguessable, so a
199    // permissive `*` here doesn't expose anything an attacker couldn't get
200    // by guessing the token first.
201    let mut builder = axum::response::Response::builder()
202        .status(status)
203        .header(axum::http::header::CONTENT_TYPE, content_type)
204        .header(axum::http::header::ACCESS_CONTROL_ALLOW_ORIGIN, "*");
205
206    if let Some(len) = content_length {
207        if len >= 0 {
208            builder = builder.header(axum::http::header::CONTENT_LENGTH, len);
209        }
210    }
211    if let Some(cr) = content_range {
212        builder = builder.header("Content-Range", cr);
213    }
214
215    builder.body(body).unwrap_or_else(|_| {
216        (StatusCode::INTERNAL_SERVER_ERROR, "response build error").into_response()
217    })
218}
219
220// ---------------------------------------------------------------------------
221// Router builder
222// ---------------------------------------------------------------------------
223
224fn build_router(state: AppState) -> Router {
225    Router::new()
226        .route("/healthz", get(healthz))
227        // axum 0.8 changed path-param syntax: ":token" → "{token}".
228        .route("/m/{token}", get(serve_media))
229        .with_state(state)
230}
231
232// ---------------------------------------------------------------------------
233// MediaServerHandle — returned to the caller after startup
234// ---------------------------------------------------------------------------
235
236/// Handle to the running loopback media server.
237///
238/// Tauri manages this as app state so commands can read `port` and access the
239/// `registry` to mint / revoke tokens.
240///
241/// The `shutdown` sender is wrapped in a `Mutex<Option<_>>` so that:
242/// 1. `MediaServerHandle` is `Sync` (required by Tauri managed state).
243/// 2. Shutdown can be triggered exactly once by taking the sender.
244pub struct MediaServerHandle {
245    /// The OS-assigned port the server is bound to.
246    pub port: u16,
247    /// Shared token registry — commands mint and revoke tokens here.
248    pub registry: TokenRegistryHandle,
249    /// Session identifier minted at app start; all tokens are tagged with this.
250    /// `revoke_session` is called with this on app exit.
251    pub session_id: String,
252    /// Send `()` to trigger graceful shutdown (consume with `.lock().take()`).
253    pub shutdown: Mutex<Option<oneshot::Sender<()>>>,
254}
255
256// ---------------------------------------------------------------------------
257// start_on_localhost
258// ---------------------------------------------------------------------------
259
260/// Start the loopback media server and return a [`MediaServerHandle`].
261///
262/// Binds to `127.0.0.1:0` (OS-assigned port), spawns the server on the
263/// current Tokio runtime, and returns immediately.
264///
265/// The caller is responsible for calling `handle.shutdown.send(())` on app
266/// exit to stop the server, and calling `registry.revoke_session(session_id)`
267/// to sweep all live tokens.
268///
269/// # Arguments
270///
271/// - `pool` — S3 client pool passed through to each request handler.
272/// - `registry` — shared token registry.
273/// - `session_id` — UUID v4 string minted at app start; tags all tokens.
274pub async fn start_on_localhost(
275    pool: S3ClientPoolHandle,
276    registry: TokenRegistryHandle,
277    session_id: String,
278) -> Result<MediaServerHandle, AppError> {
279    let listener = TcpListener::bind("127.0.0.1:0")
280        .await
281        .map_err(|e| AppError::Internal {
282            trace_id: format!("media server bind: {e}"),
283        })?;
284
285    let port = listener
286        .local_addr()
287        .map_err(|e| AppError::Internal {
288            trace_id: format!("media server local_addr: {e}"),
289        })?
290        .port();
291
292    let state = AppState {
293        registry: Arc::clone(&registry),
294        pool,
295    };
296
297    let router = build_router(state);
298
299    let (tx, rx) = oneshot::channel::<()>();
300
301    tokio::spawn(async move {
302        let server = axum::serve(listener, router).with_graceful_shutdown(async move {
303            let _ = rx.await;
304        });
305        if let Err(e) = server.await {
306            eprintln!("[media_server] exited with error: {e}");
307        }
308    });
309
310    eprintln!("[media_server] listening on 127.0.0.1:{port}");
311
312    Ok(MediaServerHandle {
313        port,
314        registry,
315        session_id,
316        shutdown: Mutex::new(Some(tx)),
317    })
318}
319
320// ---------------------------------------------------------------------------
321// Tests — range parser
322// ---------------------------------------------------------------------------
323
#[cfg(test)]
mod tests {
    use super::*;

    // -- parse_range: accepted forms ------------------------------------

    #[test]
    fn parse_range_bounded() {
        let expected = RangeSpec::Bounded {
            start: 0,
            end: 1023,
        };
        assert_eq!(parse_range("bytes=0-1023"), Some(expected));
    }

    #[test]
    fn parse_range_open_ended() {
        assert_eq!(
            parse_range("bytes=500-"),
            Some(RangeSpec::From { start: 500 })
        );
    }

    #[test]
    fn parse_range_suffix() {
        assert_eq!(
            parse_range("bytes=-500"),
            Some(RangeSpec::Suffix { last: 500 })
        );
    }

    // -- parse_range: rejected input ------------------------------------

    #[test]
    fn parse_range_missing_prefix_returns_none() {
        assert_eq!(parse_range("0-1023"), None);
    }

    #[test]
    fn parse_range_malformed_returns_none() {
        assert_eq!(parse_range("bytes=abc-def"), None);
    }

    #[test]
    fn parse_range_non_bytes_unit_returns_none() {
        assert_eq!(parse_range("items=0-10"), None);
    }

    // -- range_spec_to_s3 -----------------------------------------------

    #[test]
    fn range_spec_to_s3_bounded() {
        let spec = RangeSpec::Bounded {
            start: 0,
            end: 1023,
        };
        assert_eq!(range_spec_to_s3(&spec), "bytes=0-1023");
    }

    #[test]
    fn range_spec_to_s3_from() {
        let spec = RangeSpec::From { start: 500 };
        assert_eq!(range_spec_to_s3(&spec), "bytes=500-");
    }

    #[test]
    fn range_spec_to_s3_suffix() {
        let spec = RangeSpec::Suffix { last: 500 };
        assert_eq!(range_spec_to_s3(&spec), "bytes=-500");
    }
}