// brows3r_lib/profiles/validation.rs

//! Profile validation — `sts:GetCallerIdentity` for AWS, `s3:ListBuckets` probe
//! for compat providers.
//!
//! # Design
//!
//! - AWS profiles (no `endpoint_url`): call `sts:GetCallerIdentity` via
//!   `aws-sdk-sts`. This surfaces the `account` and `arn` of the caller.
//! - Compat providers (has `endpoint_url`): call `s3:ListBuckets` as the probe
//!   because STS may not be supported at those endpoints.
//!
//! # Error mapping
//!
//! SDK errors are centralized in `map_sts_error` / `map_s3_list_error`.
//! Adding a new SDK error code means adding one arm to those functions — no
//! other code changes.
//!
//! # Open/closed design
//!
//! - `ProviderKind` enum is open for new variants (`Sso`, `FederatedEnterprise`).
//! - `validate_profile` accepts an injected `&ClientPool` for API symmetry, but
//!   the pool is currently unused: each probe builds its own one-shot client.
//! - The `validate_with_caller<F>` helper exposes pure error-mapping logic to
//!   unit tests without making any AWS SDK call.
24
25use std::sync::Arc;
26
27use aws_config::BehaviorVersion;
28use aws_credential_types::provider::SharedCredentialsProvider;
29use aws_sdk_s3::config::Builder as S3ConfigBuilder;
30use aws_smithy_http_client::{tls, Builder as HttpBuilder};
31use serde::Serialize;
32
33use crate::{
34    error::AppError,
35    ids::ProfileId,
36    profiles::{keychain::Secret, Profile, ProfileSource},
37    s3::ClientPool,
38};
39
40// ---------------------------------------------------------------------------
41// ProviderKind — open for extension
42// ---------------------------------------------------------------------------
43
44/// The category of S3 provider a profile targets.
45///
46/// OCP: add `Sso`, `FederatedEnterprise`, `WebIdentity`, … as new variants
47/// without changing any existing arm.
48#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
49#[serde(rename_all = "camelCase")]
50pub enum ProviderKind {
51    /// Standard AWS — validated via `sts:GetCallerIdentity`.
52    Aws,
53    /// S3-compatible provider (MinIO, LocalStack, R2, …) — validated via
54    /// `s3:ListBuckets` probe.
55    Compatible,
56}
57
58// ---------------------------------------------------------------------------
59// ValidationReport
60// ---------------------------------------------------------------------------
61
62/// Result of a `profile_validate` call.
63///
64/// `ok = true` means the validation probe succeeded.  `ok = false` means it
65/// failed; the `error` field carries the mapped `AppError`.
66///
67/// `account_id` and `arn` are populated only for AWS profiles that succeed.
68#[derive(Debug, Clone, Serialize)]
69#[serde(rename_all = "camelCase")]
70pub struct ValidationReport {
71    /// The profile that was validated.
72    pub profile_id: ProfileId,
73    /// Whether the probe succeeded.
74    pub ok: bool,
75    /// AWS account ID returned by `GetCallerIdentity`. `None` for compat
76    /// providers or on failure.
77    pub account_id: Option<String>,
78    /// IAM ARN returned by `GetCallerIdentity`. `None` for compat providers
79    /// or on failure.
80    pub arn: Option<String>,
81    /// Unix-millisecond timestamp of when validation ran. `0` on failure.
82    pub validated_at: i64,
83    /// Provider category used for the probe.
84    pub provider_kind: ProviderKind,
85    /// Mapped error when `ok = false`. `None` on success.
86    pub error: Option<AppError>,
87}
88
89// ---------------------------------------------------------------------------
90// now_unix_ms — current time in Unix milliseconds
91// ---------------------------------------------------------------------------
92
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
/// `Duration::as_millis` yields a `u128`; rather than truncating with an
/// `as` cast, the value saturates at `i64::MAX` if it ever exceeds the
/// `i64` range.
fn now_unix_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
        .unwrap_or(0)
}
100
101// ---------------------------------------------------------------------------
102// Error mapping helpers
103// ---------------------------------------------------------------------------
104
105/// Categorize a raw STS SDK error string into the canonical `AppError`.
106///
107/// Centralized so unit tests can exercise the mapping without a live AWS call.
108///
109/// `status` is the HTTP status code (0 if unknown); `code` is the AWS error
110/// code string; `message` is the human-readable error from the SDK.
111fn map_sts_error(status: u16, code: &str, message: &str) -> AppError {
112    match code {
113        // Credential or token problems.
114        "InvalidClientTokenId"
115        | "SignatureDoesNotMatch"
116        | "ExpiredTokenException"
117        | "ExpiredToken"
118        | "UnrecognizedClientException"
119        | "InvalidAccessKeyId" => AppError::Auth {
120            reason: code.to_string(),
121        },
122        // 5xx or connection-level errors.
123        _ if status >= 500 => AppError::Network {
124            source: message.to_string(),
125        },
126        // Everything else.
127        _ => AppError::ProviderSpecific {
128            code: code.to_string(),
129            message: message.to_string(),
130        },
131    }
132}
133
134/// Categorize a raw S3 ListBuckets SDK error into the canonical `AppError`.
135fn map_s3_list_error(status: u16, code: &str, message: &str) -> AppError {
136    match code {
137        // Credential or token problems.
138        "InvalidClientTokenId"
139        | "SignatureDoesNotMatch"
140        | "ExpiredTokenException"
141        | "ExpiredToken"
142        | "UnrecognizedClientException"
143        | "InvalidAccessKeyId" => AppError::Auth {
144            reason: code.to_string(),
145        },
146        // 403 Forbidden = access denied.
147        "AccessDenied" | "Forbidden" => AppError::AccessDenied {
148            op: "ListBuckets".to_string(),
149            resource: "*".to_string(),
150        },
151        _ if status == 403 => AppError::AccessDenied {
152            op: "ListBuckets".to_string(),
153            resource: "*".to_string(),
154        },
155        // 5xx or connection-level errors.
156        _ if status >= 500 => AppError::Network {
157            source: message.to_string(),
158        },
159        // Everything else.
160        _ => AppError::ProviderSpecific {
161            code: code.to_string(),
162            message: message.to_string(),
163        },
164    }
165}
166
167// ---------------------------------------------------------------------------
168// validate_with_caller — pure error-mapping helper (testable without SDK)
169// ---------------------------------------------------------------------------
170
/// Identity reported by a successful `sts:GetCallerIdentity` call.
#[derive(Debug, Clone)]
pub struct CallerIdentity {
    /// AWS account the credentials belong to.
    pub account_id: String,
    /// ARN of the calling principal.
    pub arn: String,
}
177
178/// Pure validation logic for AWS profiles.
179///
180/// Accepts a `caller` closure that returns the STS result (or an error triple
181/// `(status, code, message)`). This makes the mapping logic testable without
182/// making any network call.
183///
184/// # Returns
185///
186/// - `Ok(CallerIdentity)` on success.
187/// - `Err(AppError)` — the mapped error on failure.
188pub fn validate_with_caller<F>(
189    profile_id: &ProfileId,
190    caller: F,
191) -> Result<CallerIdentity, AppError>
192where
193    F: FnOnce() -> Result<CallerIdentity, (u16, String, String)>,
194{
195    caller().map_err(|(status, code, message)| {
196        let _ = profile_id; // used by callers for context; keep the parameter.
197        map_sts_error(status, &code, &message)
198    })
199}
200
201// ---------------------------------------------------------------------------
202// build_sts_client — build a one-shot STS client from a profile + secret
203// ---------------------------------------------------------------------------
204
205/// Build an `aws_sdk_sts::Client` for the given profile.
206///
207/// This is intentionally NOT pooled — STS clients are only used during
208/// validation, not on the hot path. We build a fresh one per validation call.
209async fn build_sts_client(profile: &Profile, secret: Option<&Secret>) -> aws_sdk_sts::Client {
210    let http_client = HttpBuilder::new()
211        .tls_provider(tls::Provider::Rustls(
212            aws_smithy_http_client::tls::rustls_provider::CryptoMode::Ring,
213        ))
214        .build_https();
215
216    let region_str = profile
217        .default_region
218        .clone()
219        .unwrap_or_else(|| "us-east-1".to_string());
220    let region = aws_config::Region::new(region_str);
221
222    let mut loader = aws_config::defaults(BehaviorVersion::latest())
223        .region(region)
224        .http_client(http_client);
225
226    if let Some(secret) = secret {
227        use aws_credential_types::Credentials;
228        let creds = Credentials::new(
229            &secret.access_key_id,
230            &secret.secret_access_key,
231            secret.session_token.clone(),
232            None,
233            "brows3r-manual",
234        );
235        loader = loader.credentials_provider(SharedCredentialsProvider::new(creds));
236    } else if matches!(
237        profile.source,
238        ProfileSource::AwsCredentials | ProfileSource::AwsConfig
239    ) {
240        // Pin the loader to the selected named profile so the SDK resolves the
241        // SSO / assume-role / credential-process provider that the profile
242        // declares. Without this the loader falls back to the default chain
243        // (env → `default` profile → IMDS) and validation tests the wrong
244        // identity.
245        loader = loader.profile_name(profile.display_name.as_str());
246    }
247
248    let sdk_config = loader.load().await;
249    aws_sdk_sts::Client::new(&sdk_config)
250}
251
252// ---------------------------------------------------------------------------
253// build_s3_client_for_compat — build a one-shot S3 client for compat probe
254// ---------------------------------------------------------------------------
255
256/// Build an `aws_sdk_s3::Client` for a compat provider validation probe.
257///
258/// Uses the endpoint URL from `compat_flags` and path-style addressing.
259/// Injected credentials come from `secret` (manual profile) or the SDK chain.
260async fn build_s3_client_for_compat(
261    profile: &Profile,
262    secret: Option<&Secret>,
263) -> aws_sdk_s3::Client {
264    let http_client = HttpBuilder::new()
265        .tls_provider(tls::Provider::Rustls(
266            aws_smithy_http_client::tls::rustls_provider::CryptoMode::Ring,
267        ))
268        .build_https();
269
270    let region_str = profile
271        .default_region
272        .clone()
273        .unwrap_or_else(|| "us-east-1".to_string());
274    let region = aws_config::Region::new(region_str);
275
276    let endpoint_url = profile
277        .compat_flags
278        .endpoint_url
279        .clone()
280        .unwrap_or_default();
281
282    let mut loader = aws_config::defaults(BehaviorVersion::latest())
283        .region(region)
284        .http_client(http_client)
285        .endpoint_url(endpoint_url);
286
287    if let Some(secret) = secret {
288        use aws_credential_types::Credentials;
289        let creds = Credentials::new(
290            &secret.access_key_id,
291            &secret.secret_access_key,
292            secret.session_token.clone(),
293            None,
294            "brows3r-manual",
295        );
296        loader = loader.credentials_provider(SharedCredentialsProvider::new(creds));
297    } else if matches!(
298        profile.source,
299        ProfileSource::AwsCredentials | ProfileSource::AwsConfig
300    ) {
301        loader = loader.profile_name(profile.display_name.as_str());
302    }
303
304    let sdk_config = loader.load().await;
305
306    // Always use path-style for compat providers.
307    let mut s3_builder = S3ConfigBuilder::from(&sdk_config);
308    s3_builder = s3_builder.force_path_style(true);
309
310    aws_sdk_s3::Client::from_conf(s3_builder.build())
311}
312
313// ---------------------------------------------------------------------------
314// validate_profile — main entry point
315// ---------------------------------------------------------------------------
316
317/// Validate a profile by running the appropriate probe.
318///
319/// - AWS profiles (no `endpoint_url`): `sts:GetCallerIdentity`.
320/// - Compat providers (has `endpoint_url`): `s3:ListBuckets`.
321///
322/// The `pool` parameter is accepted for API symmetry and future use; the
323/// compat path constructs a fresh client to avoid registering a transient
324/// profile into the shared pool.
325///
326/// Always returns `Ok(report)`. SDK-level probe failures are captured inside
327/// the `ValidationReport` as `ok = false` + `error`. The `Err` path is
328/// reserved for future catastrophic conditions (e.g. missing required config).
329pub async fn validate_profile(
330    profile: &Profile,
331    secret: Option<&Secret>,
332    _pool: &Arc<ClientPool>,
333) -> Result<ValidationReport, AppError> {
334    let is_compat = profile.compat_flags.endpoint_url.is_some();
335
336    let report = if is_compat {
337        validate_compat(profile, secret).await
338    } else {
339        validate_aws(profile, secret).await
340    };
341
342    Ok(report)
343}
344
345/// Inner: AWS path using `sts:GetCallerIdentity`.
346async fn validate_aws(profile: &Profile, secret: Option<&Secret>) -> ValidationReport {
347    let client = build_sts_client(profile, secret).await;
348
349    let result = client.get_caller_identity().send().await;
350
351    match result {
352        Ok(resp) => ValidationReport {
353            profile_id: profile.id.clone(),
354            ok: true,
355            account_id: resp.account().map(str::to_owned),
356            arn: resp.arn().map(str::to_owned),
357            validated_at: now_unix_ms(),
358            provider_kind: ProviderKind::Aws,
359            error: None,
360        },
361        Err(sdk_err) => {
362            let (status, code, message) = extract_sts_error_parts(&sdk_err);
363            let mapped = map_sts_error(status, &code, &message);
364            ValidationReport {
365                profile_id: profile.id.clone(),
366                ok: false,
367                account_id: None,
368                arn: None,
369                validated_at: 0,
370                provider_kind: ProviderKind::Aws,
371                error: Some(mapped),
372            }
373        }
374    }
375}
376
377/// Inner: compat provider path using `s3:ListBuckets`.
378async fn validate_compat(profile: &Profile, secret: Option<&Secret>) -> ValidationReport {
379    let client = build_s3_client_for_compat(profile, secret).await;
380
381    let result = client.list_buckets().send().await;
382
383    match result {
384        Ok(_) => ValidationReport {
385            profile_id: profile.id.clone(),
386            ok: true,
387            account_id: None,
388            arn: None,
389            validated_at: now_unix_ms(),
390            provider_kind: ProviderKind::Compatible,
391            error: None,
392        },
393        Err(sdk_err) => {
394            let (status, code, message) = extract_s3_error_parts(&sdk_err);
395            let mapped = map_s3_list_error(status, &code, &message);
396            ValidationReport {
397                profile_id: profile.id.clone(),
398                ok: false,
399                account_id: None,
400                arn: None,
401                validated_at: 0,
402                provider_kind: ProviderKind::Compatible,
403                error: Some(mapped),
404            }
405        }
406    }
407}
408
409// ---------------------------------------------------------------------------
410// Error part extractors — convert SDK error types to (status, code, message)
411// ---------------------------------------------------------------------------
412
413fn extract_sts_error_parts(
414    err: &aws_sdk_sts::error::SdkError<
415        aws_sdk_sts::operation::get_caller_identity::GetCallerIdentityError,
416    >,
417) -> (u16, String, String) {
418    use aws_sdk_sts::error::SdkError;
419
420    match err {
421        SdkError::ServiceError(svc) => {
422            let status = svc.raw().status().as_u16();
423            let inner = svc.err();
424            let code = inner.meta().code().unwrap_or("Unknown").to_string();
425            let message = inner.meta().message().unwrap_or("").to_string();
426            (status, code, message)
427        }
428        SdkError::ConstructionFailure(_) => (0, "ConstructionFailure".to_string(), err.to_string()),
429        SdkError::TimeoutError(_) => (0, "TimeoutError".to_string(), err.to_string()),
430        SdkError::DispatchFailure(_) => (503, "DispatchFailure".to_string(), err.to_string()),
431        SdkError::ResponseError(_) => (500, "ResponseError".to_string(), err.to_string()),
432        _ => (0, "Unknown".to_string(), err.to_string()),
433    }
434}
435
436fn extract_s3_error_parts(
437    err: &aws_sdk_s3::error::SdkError<aws_sdk_s3::operation::list_buckets::ListBucketsError>,
438) -> (u16, String, String) {
439    use aws_sdk_s3::error::SdkError;
440
441    match err {
442        SdkError::ServiceError(svc) => {
443            let status = svc.raw().status().as_u16();
444            let inner = svc.err();
445            let code = inner.meta().code().unwrap_or("Unknown").to_string();
446            let message = inner.meta().message().unwrap_or("").to_string();
447            (status, code, message)
448        }
449        SdkError::ConstructionFailure(_) => (0, "ConstructionFailure".to_string(), err.to_string()),
450        SdkError::TimeoutError(_) => (0, "TimeoutError".to_string(), err.to_string()),
451        SdkError::DispatchFailure(_) => (503, "DispatchFailure".to_string(), err.to_string()),
452        SdkError::ResponseError(_) => (500, "ResponseError".to_string(), err.to_string()),
453        _ => (0, "Unknown".to_string(), err.to_string()),
454    }
455}
456
457// ---------------------------------------------------------------------------
458// Tests
459// ---------------------------------------------------------------------------
460
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ids::ProfileId;

    /// Fixed profile id shared by the `validate_with_caller` tests.
    fn make_profile_id() -> ProfileId {
        ProfileId::new("test-profile")
    }

    /// Owned `(status, code, message)` triple, as a failing STS caller returns.
    fn failure(status: u16, code: &str, message: &str) -> (u16, String, String) {
        (status, code.to_owned(), message.to_owned())
    }

    // ------------------------------------------------------------------
    // Unit: map_sts_error — every mapped variant
    // ------------------------------------------------------------------

    #[test]
    fn map_sts_invalid_token_is_auth() {
        let mapped = map_sts_error(403, "InvalidClientTokenId", "invalid token");
        assert!(
            matches!(mapped, AppError::Auth { .. }),
            "expected Auth, got {mapped:?}"
        );
    }

    #[test]
    fn map_sts_signature_mismatch_is_auth() {
        let mapped = map_sts_error(403, "SignatureDoesNotMatch", "sig mismatch");
        assert!(
            matches!(mapped, AppError::Auth { .. }),
            "expected Auth, got {mapped:?}"
        );
    }

    #[test]
    fn map_sts_expired_token_is_auth() {
        let mapped = map_sts_error(400, "ExpiredTokenException", "token expired");
        assert!(
            matches!(mapped, AppError::Auth { .. }),
            "expected Auth, got {mapped:?}"
        );
    }

    #[test]
    fn map_sts_expired_token_short_code_is_auth() {
        let mapped = map_sts_error(400, "ExpiredToken", "token expired");
        assert!(matches!(mapped, AppError::Auth { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_sts_unrecognized_client_is_auth() {
        let mapped = map_sts_error(403, "UnrecognizedClientException", "bad client");
        assert!(matches!(mapped, AppError::Auth { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_sts_invalid_access_key_is_auth() {
        let mapped = map_sts_error(403, "InvalidAccessKeyId", "bad key");
        assert!(matches!(mapped, AppError::Auth { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_sts_5xx_is_network() {
        let mapped = map_sts_error(503, "ServiceUnavailable", "AWS is down");
        assert!(
            matches!(mapped, AppError::Network { .. }),
            "expected Network, got {mapped:?}"
        );
    }

    #[test]
    fn map_sts_500_is_network() {
        let mapped = map_sts_error(500, "InternalFailure", "internal");
        assert!(matches!(mapped, AppError::Network { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_sts_unknown_code_is_provider_specific() {
        let mapped = map_sts_error(400, "SomeOtherError", "details");
        assert!(
            matches!(mapped, AppError::ProviderSpecific { .. }),
            "expected ProviderSpecific, got {mapped:?}"
        );
    }

    // ------------------------------------------------------------------
    // Unit: map_s3_list_error — every mapped variant
    // ------------------------------------------------------------------

    #[test]
    fn map_s3_403_access_denied_code() {
        let mapped = map_s3_list_error(403, "AccessDenied", "access denied");
        assert!(
            matches!(mapped, AppError::AccessDenied { .. }),
            "expected AccessDenied, got {mapped:?}"
        );
    }

    #[test]
    fn map_s3_403_status_without_code() {
        let mapped = map_s3_list_error(403, "SomeUnknownCode", "forbidden");
        assert!(matches!(mapped, AppError::AccessDenied { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_s3_invalid_token_is_auth() {
        let mapped = map_s3_list_error(403, "InvalidClientTokenId", "bad token");
        assert!(matches!(mapped, AppError::Auth { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_s3_5xx_is_network() {
        let mapped = map_s3_list_error(502, "BadGateway", "proxy error");
        assert!(matches!(mapped, AppError::Network { .. }), "got {mapped:?}");
    }

    #[test]
    fn map_s3_unknown_is_provider_specific() {
        let mapped = map_s3_list_error(400, "BucketRegionError", "wrong region");
        assert!(
            matches!(mapped, AppError::ProviderSpecific { .. }),
            "got {mapped:?}"
        );
    }

    // ------------------------------------------------------------------
    // Unit: validate_with_caller — mapping through closure injection
    // ------------------------------------------------------------------

    #[test]
    fn validate_with_caller_success_passes_through() {
        let profile = make_profile_id();
        let identity = validate_with_caller(&profile, || {
            Ok(CallerIdentity {
                account_id: "123456789012".to_string(),
                arn: "arn:aws:iam::123456789012:user/test".to_string(),
            })
        })
        .expect("expected success");

        assert_eq!(identity.account_id, "123456789012");
        assert_eq!(identity.arn, "arn:aws:iam::123456789012:user/test");
    }

    #[test]
    fn validate_with_caller_invalid_token_maps_to_auth() {
        let profile = make_profile_id();
        let mapped = validate_with_caller(&profile, || {
            Err(failure(403, "InvalidClientTokenId", "bad token"))
        })
        .unwrap_err();

        assert!(
            matches!(mapped, AppError::Auth { .. }),
            "InvalidClientTokenId must map to Auth; got {mapped:?}"
        );
    }

    #[test]
    fn validate_with_caller_5xx_maps_to_network() {
        let profile = make_profile_id();
        let mapped =
            validate_with_caller(&profile, || Err(failure(503, "ServiceUnavailable", "down")))
                .unwrap_err();

        assert!(
            matches!(mapped, AppError::Network { .. }),
            "5xx must map to Network; got {mapped:?}"
        );
    }

    #[test]
    fn validate_with_caller_unknown_code_maps_to_provider_specific() {
        let profile = make_profile_id();
        let mapped = validate_with_caller(&profile, || {
            Err(failure(400, "Throttling", "too many requests"))
        })
        .unwrap_err();

        assert!(
            matches!(mapped, AppError::ProviderSpecific { .. }),
            "unknown code must map to ProviderSpecific; got {mapped:?}"
        );
    }

    // ------------------------------------------------------------------
    // Unit: now_unix_ms sanity check
    // ------------------------------------------------------------------

    #[test]
    fn now_unix_ms_returns_reasonable_value() {
        // 2024-01-01T00:00:00Z in Unix milliseconds.
        const JAN_1_2024_MS: i64 = 1_704_067_200_000;
        let ts = now_unix_ms();
        assert!(ts > JAN_1_2024_MS, "timestamp looks stale: {ts}");
    }
}