1use std::sync::Arc;
26
27use aws_config::BehaviorVersion;
28use aws_credential_types::provider::SharedCredentialsProvider;
29use aws_sdk_s3::config::Builder as S3ConfigBuilder;
30use aws_smithy_http_client::{tls, Builder as HttpBuilder};
31use serde::Serialize;
32
33use crate::{
34 error::AppError,
35 ids::ProfileId,
36 profiles::{keychain::Secret, Profile, ProfileSource},
37 s3::ClientPool,
38};
39
40#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
49#[serde(rename_all = "camelCase")]
50pub enum ProviderKind {
51 Aws,
53 Compatible,
56}
57
58#[derive(Debug, Clone, Serialize)]
69#[serde(rename_all = "camelCase")]
70pub struct ValidationReport {
71 pub profile_id: ProfileId,
73 pub ok: bool,
75 pub account_id: Option<String>,
78 pub arn: Option<String>,
81 pub validated_at: i64,
83 pub provider_kind: ProviderKind,
85 pub error: Option<AppError>,
87}
88
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reports a time before the epoch.
fn now_unix_ms() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as i64,
        Err(_) => 0,
    }
}
100
101fn map_sts_error(status: u16, code: &str, message: &str) -> AppError {
112 match code {
113 "InvalidClientTokenId"
115 | "SignatureDoesNotMatch"
116 | "ExpiredTokenException"
117 | "ExpiredToken"
118 | "UnrecognizedClientException"
119 | "InvalidAccessKeyId" => AppError::Auth {
120 reason: code.to_string(),
121 },
122 _ if status >= 500 => AppError::Network {
124 source: message.to_string(),
125 },
126 _ => AppError::ProviderSpecific {
128 code: code.to_string(),
129 message: message.to_string(),
130 },
131 }
132}
133
134fn map_s3_list_error(status: u16, code: &str, message: &str) -> AppError {
136 match code {
137 "InvalidClientTokenId"
139 | "SignatureDoesNotMatch"
140 | "ExpiredTokenException"
141 | "ExpiredToken"
142 | "UnrecognizedClientException"
143 | "InvalidAccessKeyId" => AppError::Auth {
144 reason: code.to_string(),
145 },
146 "AccessDenied" | "Forbidden" => AppError::AccessDenied {
148 op: "ListBuckets".to_string(),
149 resource: "*".to_string(),
150 },
151 _ if status == 403 => AppError::AccessDenied {
152 op: "ListBuckets".to_string(),
153 resource: "*".to_string(),
154 },
155 _ if status >= 500 => AppError::Network {
157 source: message.to_string(),
158 },
159 _ => AppError::ProviderSpecific {
161 code: code.to_string(),
162 message: message.to_string(),
163 },
164 }
165}
166
/// Identity (account id and principal ARN) produced by the STS caller that
/// `validate_with_caller` runs.
#[derive(Clone, Debug)]
pub struct CallerIdentity {
    /// Account id of the caller.
    pub account_id: String,
    /// ARN of the calling principal.
    pub arn: String,
}
177
178pub fn validate_with_caller<F>(
189 profile_id: &ProfileId,
190 caller: F,
191) -> Result<CallerIdentity, AppError>
192where
193 F: FnOnce() -> Result<CallerIdentity, (u16, String, String)>,
194{
195 caller().map_err(|(status, code, message)| {
196 let _ = profile_id; map_sts_error(status, &code, &message)
198 })
199}
200
201async fn build_sts_client(profile: &Profile, secret: Option<&Secret>) -> aws_sdk_sts::Client {
210 let http_client = HttpBuilder::new()
211 .tls_provider(tls::Provider::Rustls(
212 aws_smithy_http_client::tls::rustls_provider::CryptoMode::Ring,
213 ))
214 .build_https();
215
216 let region_str = profile
217 .default_region
218 .clone()
219 .unwrap_or_else(|| "us-east-1".to_string());
220 let region = aws_config::Region::new(region_str);
221
222 let mut loader = aws_config::defaults(BehaviorVersion::latest())
223 .region(region)
224 .http_client(http_client);
225
226 if let Some(secret) = secret {
227 use aws_credential_types::Credentials;
228 let creds = Credentials::new(
229 &secret.access_key_id,
230 &secret.secret_access_key,
231 secret.session_token.clone(),
232 None,
233 "brows3r-manual",
234 );
235 loader = loader.credentials_provider(SharedCredentialsProvider::new(creds));
236 } else if matches!(
237 profile.source,
238 ProfileSource::AwsCredentials | ProfileSource::AwsConfig
239 ) {
240 loader = loader.profile_name(profile.display_name.as_str());
246 }
247
248 let sdk_config = loader.load().await;
249 aws_sdk_sts::Client::new(&sdk_config)
250}
251
252async fn build_s3_client_for_compat(
261 profile: &Profile,
262 secret: Option<&Secret>,
263) -> aws_sdk_s3::Client {
264 let http_client = HttpBuilder::new()
265 .tls_provider(tls::Provider::Rustls(
266 aws_smithy_http_client::tls::rustls_provider::CryptoMode::Ring,
267 ))
268 .build_https();
269
270 let region_str = profile
271 .default_region
272 .clone()
273 .unwrap_or_else(|| "us-east-1".to_string());
274 let region = aws_config::Region::new(region_str);
275
276 let endpoint_url = profile
277 .compat_flags
278 .endpoint_url
279 .clone()
280 .unwrap_or_default();
281
282 let mut loader = aws_config::defaults(BehaviorVersion::latest())
283 .region(region)
284 .http_client(http_client)
285 .endpoint_url(endpoint_url);
286
287 if let Some(secret) = secret {
288 use aws_credential_types::Credentials;
289 let creds = Credentials::new(
290 &secret.access_key_id,
291 &secret.secret_access_key,
292 secret.session_token.clone(),
293 None,
294 "brows3r-manual",
295 );
296 loader = loader.credentials_provider(SharedCredentialsProvider::new(creds));
297 } else if matches!(
298 profile.source,
299 ProfileSource::AwsCredentials | ProfileSource::AwsConfig
300 ) {
301 loader = loader.profile_name(profile.display_name.as_str());
302 }
303
304 let sdk_config = loader.load().await;
305
306 let mut s3_builder = S3ConfigBuilder::from(&sdk_config);
308 s3_builder = s3_builder.force_path_style(true);
309
310 aws_sdk_s3::Client::from_conf(s3_builder.build())
311}
312
313pub async fn validate_profile(
330 profile: &Profile,
331 secret: Option<&Secret>,
332 _pool: &Arc<ClientPool>,
333) -> Result<ValidationReport, AppError> {
334 let is_compat = profile.compat_flags.endpoint_url.is_some();
335
336 let report = if is_compat {
337 validate_compat(profile, secret).await
338 } else {
339 validate_aws(profile, secret).await
340 };
341
342 Ok(report)
343}
344
345async fn validate_aws(profile: &Profile, secret: Option<&Secret>) -> ValidationReport {
347 let client = build_sts_client(profile, secret).await;
348
349 let result = client.get_caller_identity().send().await;
350
351 match result {
352 Ok(resp) => ValidationReport {
353 profile_id: profile.id.clone(),
354 ok: true,
355 account_id: resp.account().map(str::to_owned),
356 arn: resp.arn().map(str::to_owned),
357 validated_at: now_unix_ms(),
358 provider_kind: ProviderKind::Aws,
359 error: None,
360 },
361 Err(sdk_err) => {
362 let (status, code, message) = extract_sts_error_parts(&sdk_err);
363 let mapped = map_sts_error(status, &code, &message);
364 ValidationReport {
365 profile_id: profile.id.clone(),
366 ok: false,
367 account_id: None,
368 arn: None,
369 validated_at: 0,
370 provider_kind: ProviderKind::Aws,
371 error: Some(mapped),
372 }
373 }
374 }
375}
376
377async fn validate_compat(profile: &Profile, secret: Option<&Secret>) -> ValidationReport {
379 let client = build_s3_client_for_compat(profile, secret).await;
380
381 let result = client.list_buckets().send().await;
382
383 match result {
384 Ok(_) => ValidationReport {
385 profile_id: profile.id.clone(),
386 ok: true,
387 account_id: None,
388 arn: None,
389 validated_at: now_unix_ms(),
390 provider_kind: ProviderKind::Compatible,
391 error: None,
392 },
393 Err(sdk_err) => {
394 let (status, code, message) = extract_s3_error_parts(&sdk_err);
395 let mapped = map_s3_list_error(status, &code, &message);
396 ValidationReport {
397 profile_id: profile.id.clone(),
398 ok: false,
399 account_id: None,
400 arn: None,
401 validated_at: 0,
402 provider_kind: ProviderKind::Compatible,
403 error: Some(mapped),
404 }
405 }
406 }
407}
408
409fn extract_sts_error_parts(
414 err: &aws_sdk_sts::error::SdkError<
415 aws_sdk_sts::operation::get_caller_identity::GetCallerIdentityError,
416 >,
417) -> (u16, String, String) {
418 use aws_sdk_sts::error::SdkError;
419
420 match err {
421 SdkError::ServiceError(svc) => {
422 let status = svc.raw().status().as_u16();
423 let inner = svc.err();
424 let code = inner.meta().code().unwrap_or("Unknown").to_string();
425 let message = inner.meta().message().unwrap_or("").to_string();
426 (status, code, message)
427 }
428 SdkError::ConstructionFailure(_) => (0, "ConstructionFailure".to_string(), err.to_string()),
429 SdkError::TimeoutError(_) => (0, "TimeoutError".to_string(), err.to_string()),
430 SdkError::DispatchFailure(_) => (503, "DispatchFailure".to_string(), err.to_string()),
431 SdkError::ResponseError(_) => (500, "ResponseError".to_string(), err.to_string()),
432 _ => (0, "Unknown".to_string(), err.to_string()),
433 }
434}
435
436fn extract_s3_error_parts(
437 err: &aws_sdk_s3::error::SdkError<aws_sdk_s3::operation::list_buckets::ListBucketsError>,
438) -> (u16, String, String) {
439 use aws_sdk_s3::error::SdkError;
440
441 match err {
442 SdkError::ServiceError(svc) => {
443 let status = svc.raw().status().as_u16();
444 let inner = svc.err();
445 let code = inner.meta().code().unwrap_or("Unknown").to_string();
446 let message = inner.meta().message().unwrap_or("").to_string();
447 (status, code, message)
448 }
449 SdkError::ConstructionFailure(_) => (0, "ConstructionFailure".to_string(), err.to_string()),
450 SdkError::TimeoutError(_) => (0, "TimeoutError".to_string(), err.to_string()),
451 SdkError::DispatchFailure(_) => (503, "DispatchFailure".to_string(), err.to_string()),
452 SdkError::ResponseError(_) => (500, "ResponseError".to_string(), err.to_string()),
453 _ => (0, "Unknown".to_string(), err.to_string()),
454 }
455}
456
#[cfg(test)]
mod tests {
    // Unit tests for the pure error mappers, for `validate_with_caller`
    // (driven by an injected closure, so no network), and for `now_unix_ms`.
    // The async validators that talk to real endpoints are not exercised here.
    use super::*;
    use crate::ids::ProfileId;

    fn make_profile_id() -> ProfileId {
        ProfileId::new("test-profile")
    }

    // ---- map_sts_error ---------------------------------------------------

    #[test]
    fn map_sts_invalid_token_is_auth() {
        let err = map_sts_error(403, "InvalidClientTokenId", "invalid token");
        assert!(
            matches!(err, AppError::Auth { .. }),
            "expected Auth, got {err:?}"
        );
    }

    #[test]
    fn map_sts_signature_mismatch_is_auth() {
        let err = map_sts_error(403, "SignatureDoesNotMatch", "sig mismatch");
        assert!(
            matches!(err, AppError::Auth { .. }),
            "expected Auth, got {err:?}"
        );
    }

    #[test]
    fn map_sts_expired_token_is_auth() {
        let err = map_sts_error(400, "ExpiredTokenException", "token expired");
        assert!(
            matches!(err, AppError::Auth { .. }),
            "expected Auth, got {err:?}"
        );
    }

    #[test]
    fn map_sts_expired_token_short_code_is_auth() {
        let err = map_sts_error(400, "ExpiredToken", "token expired");
        assert!(matches!(err, AppError::Auth { .. }));
    }

    #[test]
    fn map_sts_unrecognized_client_is_auth() {
        let err = map_sts_error(403, "UnrecognizedClientException", "bad client");
        assert!(matches!(err, AppError::Auth { .. }));
    }

    #[test]
    fn map_sts_invalid_access_key_is_auth() {
        let err = map_sts_error(403, "InvalidAccessKeyId", "bad key");
        assert!(matches!(err, AppError::Auth { .. }));
    }

    #[test]
    fn map_sts_auth_code_wins_over_5xx_status() {
        // Credential codes are checked before the status-based network guard.
        let err = map_sts_error(500, "ExpiredToken", "token expired");
        assert!(
            matches!(err, AppError::Auth { .. }),
            "expected Auth, got {err:?}"
        );
    }

    #[test]
    fn map_sts_5xx_is_network() {
        let err = map_sts_error(503, "ServiceUnavailable", "AWS is down");
        assert!(
            matches!(err, AppError::Network { .. }),
            "expected Network, got {err:?}"
        );
    }

    #[test]
    fn map_sts_500_is_network() {
        let err = map_sts_error(500, "InternalFailure", "internal");
        assert!(matches!(err, AppError::Network { .. }));
    }

    #[test]
    fn map_sts_499_is_not_network() {
        let err = map_sts_error(499, "ClientClosedRequest", "closed");
        assert!(
            matches!(err, AppError::ProviderSpecific { .. }),
            "expected ProviderSpecific, got {err:?}"
        );
    }

    #[test]
    fn map_sts_403_without_auth_code_is_provider_specific() {
        // Unlike the S3 mapper, the STS mapper has no AccessDenied arm.
        let err = map_sts_error(403, "AccessDenied", "not authorized");
        assert!(
            matches!(err, AppError::ProviderSpecific { .. }),
            "expected ProviderSpecific, got {err:?}"
        );
    }

    #[test]
    fn map_sts_unknown_code_is_provider_specific() {
        let err = map_sts_error(400, "SomeOtherError", "details");
        assert!(
            matches!(err, AppError::ProviderSpecific { .. }),
            "expected ProviderSpecific, got {err:?}"
        );
    }

    // ---- map_s3_list_error -----------------------------------------------

    #[test]
    fn map_s3_403_access_denied_code() {
        let err = map_s3_list_error(403, "AccessDenied", "access denied");
        assert!(
            matches!(err, AppError::AccessDenied { .. }),
            "expected AccessDenied, got {err:?}"
        );
    }

    #[test]
    fn map_s3_access_denied_targets_list_buckets() {
        match map_s3_list_error(403, "AccessDenied", "access denied") {
            AppError::AccessDenied { op, resource, .. } => {
                assert_eq!(op, "ListBuckets");
                assert_eq!(resource, "*");
            }
            other => panic!("expected AccessDenied, got {other:?}"),
        }
    }

    #[test]
    fn map_s3_forbidden_code_is_access_denied_without_403() {
        // The code arm applies regardless of the HTTP status.
        let err = map_s3_list_error(400, "Forbidden", "forbidden");
        assert!(
            matches!(err, AppError::AccessDenied { .. }),
            "expected AccessDenied, got {err:?}"
        );
    }

    #[test]
    fn map_s3_403_status_without_code() {
        let err = map_s3_list_error(403, "SomeUnknownCode", "forbidden");
        assert!(matches!(err, AppError::AccessDenied { .. }));
    }

    #[test]
    fn map_s3_invalid_token_is_auth() {
        let err = map_s3_list_error(403, "InvalidClientTokenId", "bad token");
        assert!(matches!(err, AppError::Auth { .. }));
    }

    #[test]
    fn map_s3_signature_mismatch_is_auth() {
        let err = map_s3_list_error(403, "SignatureDoesNotMatch", "bad signature");
        assert!(
            matches!(err, AppError::Auth { .. }),
            "expected Auth, got {err:?}"
        );
    }

    #[test]
    fn map_s3_5xx_is_network() {
        let err = map_s3_list_error(502, "BadGateway", "proxy error");
        assert!(matches!(err, AppError::Network { .. }));
    }

    #[test]
    fn map_s3_unknown_is_provider_specific() {
        let err = map_s3_list_error(400, "BucketRegionError", "wrong region");
        assert!(matches!(err, AppError::ProviderSpecific { .. }));
    }

    // ---- validate_with_caller --------------------------------------------

    #[test]
    fn validate_with_caller_success_passes_through() {
        let id = make_profile_id();
        let result = validate_with_caller(&id, || {
            Ok(CallerIdentity {
                account_id: "123456789012".to_string(),
                arn: "arn:aws:iam::123456789012:user/test".to_string(),
            })
        });
        let identity = result.expect("expected success");
        assert_eq!(identity.account_id, "123456789012");
        assert_eq!(identity.arn, "arn:aws:iam::123456789012:user/test");
    }

    #[test]
    fn validate_with_caller_invalid_token_maps_to_auth() {
        let id = make_profile_id();
        let result = validate_with_caller(&id, || {
            Err((
                403u16,
                "InvalidClientTokenId".to_string(),
                "bad token".to_string(),
            ))
        });
        let err = result.unwrap_err();
        assert!(
            matches!(err, AppError::Auth { .. }),
            "InvalidClientTokenId must map to Auth; got {err:?}"
        );
    }

    #[test]
    fn validate_with_caller_5xx_maps_to_network() {
        let id = make_profile_id();
        let result = validate_with_caller(&id, || {
            Err((503u16, "ServiceUnavailable".to_string(), "down".to_string()))
        });
        let err = result.unwrap_err();
        assert!(
            matches!(err, AppError::Network { .. }),
            "5xx must map to Network; got {err:?}"
        );
    }

    #[test]
    fn validate_with_caller_unknown_code_maps_to_provider_specific() {
        let id = make_profile_id();
        let result = validate_with_caller(&id, || {
            Err((
                400u16,
                "Throttling".to_string(),
                "too many requests".to_string(),
            ))
        });
        let err = result.unwrap_err();
        assert!(
            matches!(err, AppError::ProviderSpecific { .. }),
            "unknown code must map to ProviderSpecific; got {err:?}"
        );
    }

    // ---- now_unix_ms -----------------------------------------------------

    #[test]
    fn now_unix_ms_returns_reasonable_value() {
        let ts = now_unix_ms();
        // Lower bound: 2024-01-01T00:00:00Z. Catches a seconds-vs-millis mix-up.
        assert!(ts > 1_704_067_200_000i64, "timestamp looks stale: {ts}");
        // Upper bound: 2100-01-01T00:00:00Z. Catches a micros-vs-millis mix-up.
        assert!(
            ts < 4_102_444_800_000i64,
            "timestamp looks too far in the future: {ts}"
        );
    }
}