1use std::{
22 collections::HashMap,
23 sync::{Arc, Mutex},
24};
25
26use serde::{Deserialize, Serialize};
27
28use crate::{
29 error::AppError,
30 ids::{BucketId, ProfileId},
31};
32
/// Default lifetime of a learned capability, in seconds (30 minutes).
const DEFAULT_TTL_SECS: i64 = 1_800;
41
42#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
52#[serde(tag = "class", rename_all = "camelCase")]
53pub enum CapabilityClass {
54 Allowed,
56 Denied { iam_action: Option<String> },
61 Unsupported { provider: Option<String> },
65 StorageClassBlocked { storage_class: String },
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CapabilityRecord {
77 pub class: CapabilityClass,
79 pub learned_at: i64,
81}
82
83pub type CapabilityMap = HashMap<String, CapabilityRecord>;
92
93#[derive(Debug, Clone, PartialEq, Eq)]
99pub enum ClearScope {
100 All,
102 Bucket(BucketId),
105 Op(String),
107}
108
109type CacheKey = (ProfileId, Option<BucketId>, String);
114
/// Time source used to stamp and age cache records.
///
/// Kept behind a trait so a controllable implementation can be substituted;
/// the `Send + Sync` bounds allow sharing one instance as `Arc<dyn Clock>`.
pub(crate) trait Clock: Send + Sync {
    /// The current time in whole seconds.
    fn now_secs(&self) -> i64;
}
123
124#[derive(Default)]
125struct SystemClock;
126
127impl Clock for SystemClock {
128 fn now_secs(&self) -> i64 {
129 std::time::SystemTime::now()
130 .duration_since(std::time::UNIX_EPOCH)
131 .unwrap_or_default()
132 .as_secs() as i64
133 }
134}
135
/// Manually driven clock: its reading changes only through [`MockClock::advance`].
#[derive(Default)]
pub struct MockClock {
    // Current reading in seconds; behind a mutex so `advance` can take `&self`.
    inner: Mutex<i64>,
}

impl MockClock {
    /// Creates a shared clock whose reading starts at `secs`.
    pub fn new(secs: i64) -> Arc<Self> {
        let clock = Self {
            inner: Mutex::new(secs),
        };
        Arc::new(clock)
    }

    /// Shifts the reading by `delta` seconds (a negative `delta` moves it back).
    ///
    /// # Panics
    /// Panics if the internal mutex is poisoned.
    pub fn advance(&self, delta: i64) {
        let mut reading = self.inner.lock().unwrap();
        *reading += delta;
    }
}
154
155impl Clock for MockClock {
156 fn now_secs(&self) -> i64 {
157 *self.inner.lock().unwrap()
158 }
159}
160
161pub struct CapabilityCache {
170 ttl_secs: i64,
171 clock: Arc<dyn Clock>,
172 map: Mutex<HashMap<CacheKey, CapabilityRecord>>,
173}
174
175impl Default for CapabilityCache {
176 fn default() -> Self {
177 Self {
178 ttl_secs: DEFAULT_TTL_SECS,
179 clock: Arc::new(SystemClock),
180 map: Mutex::new(HashMap::new()),
181 }
182 }
183}
184
185impl CapabilityCache {
186 #[cfg(test)]
191 pub(crate) fn with_clock(clock: Arc<dyn Clock>) -> Self {
192 Self {
193 ttl_secs: DEFAULT_TTL_SECS,
194 clock,
195 map: Mutex::new(HashMap::new()),
196 }
197 }
198
199 pub fn record_capability(
204 &self,
205 profile: &ProfileId,
206 bucket: Option<&BucketId>,
207 op: &str,
208 class: CapabilityClass,
209 ) {
210 let key: CacheKey = (profile.clone(), bucket.cloned(), op.to_owned());
211 let record = CapabilityRecord {
212 class,
213 learned_at: self.clock.now_secs(),
214 };
215 self.map.lock().unwrap().insert(key, record);
216 }
217
218 pub fn record_from_error(
225 &self,
226 profile: &ProfileId,
227 bucket: Option<&BucketId>,
228 op: &str,
229 error: &AppError,
230 ) -> Result<bool, AppError> {
231 let class = match error {
232 AppError::AccessDenied { op: action, .. } => CapabilityClass::Denied {
233 iam_action: Some(action.clone()),
234 },
235 AppError::Unsupported { provider, .. } => CapabilityClass::Unsupported {
236 provider: Some(provider.clone()),
237 },
238 AppError::ProviderSpecific { code, message } => {
239 if code == "InvalidStorageClass" || code == "NoSuchTransition" {
241 let storage_class = extract_storage_class(message);
242 CapabilityClass::StorageClassBlocked { storage_class }
243 } else {
244 return Ok(false);
245 }
246 }
247 _ => return Ok(false),
249 };
250
251 self.record_capability(profile, bucket, op, class);
252 Ok(true)
253 }
254
255 pub fn get(
257 &self,
258 profile: &ProfileId,
259 bucket: Option<&BucketId>,
260 op: &str,
261 ) -> Option<CapabilityRecord> {
262 let key: CacheKey = (profile.clone(), bucket.cloned(), op.to_owned());
263 let map = self.map.lock().unwrap();
264 let record = map.get(&key)?;
265 let age = self.clock.now_secs() - record.learned_at;
266 if age >= self.ttl_secs {
267 return None;
268 }
269 Some(record.clone())
270 }
271
272 pub fn get_map(&self, profile: &ProfileId) -> CapabilityMap {
277 let now = self.clock.now_secs();
278 let map = self.map.lock().unwrap();
279 map.iter()
280 .filter(|((pid, _, _), _)| pid == profile)
281 .filter(|(_, record)| (now - record.learned_at) < self.ttl_secs)
282 .map(|((_, bucket, op), record)| {
283 let bucket_part = bucket.as_ref().map(|b| b.as_str()).unwrap_or("");
284 let key = format!("{bucket_part}/{op}");
285 (key, record.clone())
286 })
287 .collect()
288 }
289
290 pub fn clear(&self, profile: &ProfileId, scope: &ClearScope) {
297 let mut map = self.map.lock().unwrap();
298 map.retain(|(pid, bucket, op), _| {
299 if pid != profile {
300 return true; }
302 match scope {
303 ClearScope::All => false,
304 ClearScope::Bucket(bid) => bucket.as_ref() != Some(bid),
305 ClearScope::Op(target_op) => op != target_op,
306 }
307 });
308 }
309}
310
311#[derive(Clone, Default)]
317pub struct CapabilityHandle(pub Arc<CapabilityCache>);
318
319impl CapabilityHandle {
320 pub fn inner(&self) -> &CapabilityCache {
321 &self.0
322 }
323}
324
325impl std::ops::Deref for CapabilityHandle {
326 type Target = CapabilityCache;
327 fn deref(&self) -> &CapabilityCache {
328 &self.0
329 }
330}
331
/// Best-effort extraction of the storage class named in a provider error message.
///
/// A *token* is a maximal run of `A-Z`, `0-9` and `_` in `message`.
///
/// 1. If any token is exactly one of the well-known storage classes, return it.
///    Classes are tried in a fixed priority order (archival tiers first), not in
///    order of appearance, so `"from STANDARD to GLACIER"` yields `GLACIER`.
/// 2. Otherwise, after every occurrence of a cue phrase (`"class "`,
///    `"class\t"`, `"for "`), take the token that starts immediately after it;
///    the first one at least 3 characters long is returned.
/// 3. Otherwise return `"UNKNOWN"`.
///
/// Whole-token matching keeps an unrecognised class such as `STANDARD_X` from
/// being misreported as `STANDARD`.
fn extract_storage_class(message: &str) -> String {
    // Priority order: the first class present as a whole token wins.
    const KNOWN_CLASSES: [&str; 9] = [
        "GLACIER_IR",
        "GLACIER",
        "DEEP_ARCHIVE",
        "STANDARD_IA",
        "ONEZONE_IA",
        "INTELLIGENT_TIERING",
        "STANDARD",
        "REDUCED_REDUNDANCY",
        "EXPRESS_ONEZONE",
    ];
    // Phrases that commonly precede a storage-class name in error text.
    const CUES: [&str; 3] = ["class ", "class\t", "for "];

    fn is_token_char(c: char) -> bool {
        c.is_ascii_uppercase() || c.is_ascii_digit() || c == '_'
    }

    let tokens: Vec<&str> = message
        .split(|c: char| !is_token_char(c))
        .filter(|t| !t.is_empty())
        .collect();

    if let Some(&known) = KNOWN_CLASSES
        .iter()
        .find(|&&class| tokens.contains(&class))
    {
        return known.to_owned();
    }

    for &cue in CUES.iter() {
        // Inspect every occurrence: an early cue ("for this bucket") may be
        // followed by ordinary words while a later one names the class.
        for (pos, _) in message.match_indices(cue) {
            let rest = &message[pos + cue.len()..];
            // Token chars are ASCII, so this byte index is a char boundary.
            let end = rest
                .find(|c: char| !is_token_char(c))
                .unwrap_or(rest.len());
            let token = &rest[..end];
            if token.len() >= 3 {
                return token.to_owned();
            }
        }
    }

    "UNKNOWN".to_owned()
}
389
// Unit tests for error classification, TTL expiry, scoped clearing and the
// per-profile map snapshot.
#[cfg(test)]
mod tests {
    use super::*;

    // Shorthand constructors for the id newtypes.
    fn profile(s: &str) -> ProfileId {
        ProfileId::new(s)
    }

    fn bucket(s: &str) -> BucketId {
        BucketId::new(s)
    }

    // AccessDenied is cached as Denied, carrying the error's `op` as the IAM action.
    #[test]
    fn access_denied_maps_to_denied_with_iam_action() {
        let cache = CapabilityCache::default();
        let pid = profile("p1");
        let bid = bucket("my-bucket");

        let err = AppError::AccessDenied {
            op: "s3:PutBucketVersioning".to_owned(),
            resource: "arn:aws:s3:::my-bucket".to_owned(),
        };

        let classified = cache
            .record_from_error(&pid, Some(&bid), "PutBucketVersioning", &err)
            .unwrap();
        assert!(classified, "AccessDenied must be classified");

        let record = cache
            .get(&pid, Some(&bid), "PutBucketVersioning")
            .expect("record must be stored");
        assert_eq!(
            record.class,
            CapabilityClass::Denied {
                iam_action: Some("s3:PutBucketVersioning".to_owned()),
            }
        );
    }

    // Unsupported is cached with the provider name, here at profile level (no bucket).
    #[test]
    fn unsupported_maps_to_unsupported_with_provider() {
        let cache = CapabilityCache::default();
        let pid = profile("p2");

        let err = AppError::Unsupported {
            op: "SelectObjectContent".to_owned(),
            provider: "MinIO".to_owned(),
        };

        let classified = cache
            .record_from_error(&pid, None, "SelectObjectContent", &err)
            .unwrap();
        assert!(classified);

        let record = cache
            .get(&pid, None, "SelectObjectContent")
            .expect("record must be stored");
        assert_eq!(
            record.class,
            CapabilityClass::Unsupported {
                provider: Some("MinIO".to_owned()),
            }
        );
    }

    // InvalidStorageClass: the blocked class is parsed out of the provider message.
    #[test]
    fn invalid_storage_class_maps_to_storage_class_blocked() {
        let cache = CapabilityCache::default();
        let pid = profile("p3");
        let bid = bucket("archive-bucket");

        let err = AppError::ProviderSpecific {
            code: "InvalidStorageClass".to_owned(),
            message: "The storage class GLACIER is not supported for this transition".to_owned(),
        };

        let classified = cache
            .record_from_error(&pid, Some(&bid), "SetStorageClass", &err)
            .unwrap();
        assert!(classified);

        let record = cache
            .get(&pid, Some(&bid), "SetStorageClass")
            .expect("record must be stored");
        assert_eq!(
            record.class,
            CapabilityClass::StorageClassBlocked {
                storage_class: "GLACIER".to_owned(),
            }
        );
    }

    // NoSuchTransition is the other provider code treated as a storage-class block.
    #[test]
    fn no_such_transition_maps_to_storage_class_blocked() {
        let cache = CapabilityCache::default();
        let pid = profile("p3b");
        let bid = bucket("bucket-x");

        let err = AppError::ProviderSpecific {
            code: "NoSuchTransition".to_owned(),
            message: "Transition to STANDARD_IA failed".to_owned(),
        };

        let classified = cache
            .record_from_error(&pid, Some(&bid), "TransitionStorageClass", &err)
            .unwrap();
        assert!(classified);

        let record = cache
            .get(&pid, Some(&bid), "TransitionStorageClass")
            .expect("record must be stored");
        match &record.class {
            CapabilityClass::StorageClassBlocked { storage_class } => {
                assert_eq!(storage_class, "STANDARD_IA");
            }
            other => panic!("expected StorageClassBlocked, got {other:?}"),
        }
    }

    // Errors outside the classified set report `false` and leave nothing cached.
    #[test]
    fn not_found_returns_ok_false_and_stores_nothing() {
        let cache = CapabilityCache::default();
        let pid = profile("p4");

        let err = AppError::NotFound {
            resource: "s3://bucket/key".to_owned(),
        };

        let classified = cache
            .record_from_error(&pid, None, "GetObject", &err)
            .unwrap();
        assert!(!classified, "NotFound must not be classified");

        assert!(
            cache.get(&pid, None, "GetObject").is_none(),
            "no record must be stored for unclassifiable error"
        );
    }

    // A record whose age equals the TTL exactly is no longer returned.
    #[test]
    fn expired_record_returns_none() {
        let clock = MockClock::new(1_000_000);
        let cache = CapabilityCache::with_clock(clock.clone());
        let pid = profile("p-ttl");

        cache.record_capability(&pid, None, "ListBuckets", CapabilityClass::Allowed);

        // Visible immediately after recording.
        assert!(cache.get(&pid, None, "ListBuckets").is_some());

        // Age now equals the TTL.
        clock.advance(DEFAULT_TTL_SECS);

        assert!(
            cache.get(&pid, None, "ListBuckets").is_none(),
            "record must expire after TTL"
        );
    }

    // Companion boundary check: one second short of the TTL is still live.
    #[test]
    fn record_just_before_ttl_is_returned() {
        let clock = MockClock::new(1_000_000);
        let cache = CapabilityCache::with_clock(clock.clone());
        let pid = profile("p-ttl2");

        cache.record_capability(&pid, None, "ListBuckets", CapabilityClass::Allowed);

        clock.advance(DEFAULT_TTL_SECS - 1);

        assert!(
            cache.get(&pid, None, "ListBuckets").is_some(),
            "record must still be live one second before TTL"
        );
    }

    // ClearScope::All removes both profile-level and bucket-level entries.
    #[test]
    fn clear_all_removes_all_profile_entries() {
        let cache = CapabilityCache::default();
        let pid = profile("p-clear");
        let bid = bucket("bucket-a");

        cache.record_capability(&pid, None, "ListBuckets", CapabilityClass::Allowed);
        cache.record_capability(&pid, Some(&bid), "PutObject", CapabilityClass::Allowed);

        cache.clear(&pid, &ClearScope::All);

        assert!(cache.get(&pid, None, "ListBuckets").is_none());
        assert!(cache.get(&pid, Some(&bid), "PutObject").is_none());
    }

    // Clearing is scoped to the given profile only.
    #[test]
    fn clear_all_does_not_touch_other_profiles() {
        let cache = CapabilityCache::default();
        let p1 = profile("p-clear-a");
        let p2 = profile("p-clear-b");

        cache.record_capability(&p1, None, "ListBuckets", CapabilityClass::Allowed);
        cache.record_capability(&p2, None, "ListBuckets", CapabilityClass::Allowed);

        cache.clear(&p1, &ClearScope::All);

        assert!(cache.get(&p1, None, "ListBuckets").is_none());
        assert!(
            cache.get(&p2, None, "ListBuckets").is_some(),
            "other profile's entries must survive"
        );
    }

    // ClearScope::Bucket spares other buckets and bucket-less entries.
    #[test]
    fn clear_bucket_removes_only_that_buckets_entries() {
        let cache = CapabilityCache::default();
        let pid = profile("p-bucket-clear");
        let b_foo = bucket("foo");
        let b_bar = bucket("bar");

        cache.record_capability(&pid, Some(&b_foo), "PutObject", CapabilityClass::Allowed);
        cache.record_capability(&pid, Some(&b_bar), "PutObject", CapabilityClass::Allowed);
        cache.record_capability(&pid, None, "ListBuckets", CapabilityClass::Allowed);

        cache.clear(&pid, &ClearScope::Bucket(b_foo.clone()));

        assert!(
            cache.get(&pid, Some(&b_foo), "PutObject").is_none(),
            "foo's entry must be removed"
        );
        assert!(
            cache.get(&pid, Some(&b_bar), "PutObject").is_some(),
            "bar's entry must survive"
        );
        assert!(
            cache.get(&pid, None, "ListBuckets").is_some(),
            "profile-level entry must survive"
        );
    }

    // The snapshot holds only the requested profile's entries, keyed
    // "<bucket>/<op>" with an empty bucket part for profile-level entries.
    #[test]
    fn get_map_returns_correct_subset_for_profile() {
        let cache = CapabilityCache::default();
        let p1 = profile("map-p1");
        let p2 = profile("map-p2");
        let bid = bucket("my-bucket");

        cache.record_capability(&p1, Some(&bid), "PutObject", CapabilityClass::Allowed);
        cache.record_capability(
            &p1,
            None,
            "ListBuckets",
            CapabilityClass::Denied {
                iam_action: Some("s3:ListBuckets".to_owned()),
            },
        );
        cache.record_capability(&p2, None, "ListBuckets", CapabilityClass::Allowed);

        let map = cache.get_map(&p1);
        assert_eq!(map.len(), 2, "p1 must have exactly 2 entries");
        assert!(map.contains_key("my-bucket/PutObject"));
        assert!(map.contains_key("/ListBuckets"));

        let p2_map = cache.get_map(&p2);
        assert_eq!(p2_map.len(), 1, "p2 must have exactly 1 entry");
    }

    // The snapshot applies the same TTL cut-off as `get`.
    #[test]
    fn get_map_excludes_expired_entries() {
        let clock = MockClock::new(2_000_000);
        let cache = CapabilityCache::with_clock(clock.clone());
        let pid = profile("map-ttl");

        cache.record_capability(&pid, None, "ListBuckets", CapabilityClass::Allowed);

        clock.advance(DEFAULT_TTL_SECS);

        let map = cache.get_map(&pid);
        assert!(map.is_empty(), "get_map must exclude expired entries");
    }
}