1use std::{
24 path::PathBuf,
25 sync::{
26 atomic::{AtomicBool, Ordering},
27 Arc,
28 },
29 time::{SystemTime, UNIX_EPOCH},
30};
31
32use aws_sdk_s3::{
33 primitives::ByteStream,
34 types::{CompletedMultipartUpload, CompletedPart},
35 Client,
36};
37use tokio::{fs as tokio_fs, sync::Semaphore};
38
39use crate::{
40 error::AppError,
41 events::{EventEmitter, EventKind},
42 ids::{BucketId, ObjectKey, ProfileId},
43 locks::{LockId, LockRegistry, LockScope, ReleaseReason},
44 notifications::{os::OsNotifyChannel, NotificationLogHandle},
45 s3::multipart::{MultipartRecord, MultipartTable},
46 transfers::{
47 notify::notify_terminal,
48 progress::{emit_progress, emit_state, emit_state_with_error, ProgressThrottle},
49 TransferRegistryHandle, TransferState,
50 },
51};
52
/// Size cut-off in bytes (5 MiB): files strictly smaller than this are sent
/// with a single `PutObject`; everything else goes through multipart upload.
const SINGLE_PART_THRESHOLD: u64 = 5 << 20;

/// Smallest part size in bytes (8 MiB) ever chosen for a multipart upload.
const MIN_PART_SIZE: u64 = 8 << 20;

/// Upper bound on the number of parts one multipart upload is split into.
const MAX_PARTS: u64 = 9_999;
65
/// Payload emitted with `EventKind::ObjectsUpdated` once an upload succeeds.
///
/// Serialized with camelCase field names (`profileId`, `bucket`, `prefix`).
#[derive(Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
struct ObjectsUpdatedPayload {
    /// Profile the upload ran under, as its string form.
    profile_id: String,
    /// Bucket that received the object, as its string form.
    bucket: String,
    /// Everything in the object key up to and including its last `/`;
    /// empty when the key contains no `/`.
    prefix: String,
}
80
81pub async fn upload_object<E, C>(
105 client: Arc<Client>,
106 bucket: BucketId,
107 key: String,
108 source_path: PathBuf,
109 request_id: String,
110 channel: &E,
111 registry: TransferRegistryHandle,
112 lock_registry: Arc<LockRegistry>,
113 multipart_table: Arc<MultipartTable>,
114 transfer_concurrency_per_part: u32,
115 profile_id: ProfileId,
116 lock_ttl_secs: u64,
117 cancel_flag: Arc<AtomicBool>,
118 log: NotificationLogHandle,
119 os_notifier: &crate::notifications::os::OsNotifier<C>,
120) -> Result<(), AppError>
121where
122 E: EventEmitter,
123 C: OsNotifyChannel,
124{
125 let now = now_secs();
129 let scope = LockScope {
130 profile: profile_id.clone(),
131 bucket: Some(bucket.clone()),
132 prefix: None,
133 key: Some(ObjectKey::new(key.clone())),
134 };
135
136 let lock_id = lock_registry.acquire(scope, "upload", lock_ttl_secs, now)?;
137
138 {
142 let mut reg = registry.0.write().await;
143 let _ = reg.update(&request_id, |t| {
144 t.state = TransferState::Running;
145 });
146 }
147 let _ = emit_state(channel, &request_id, TransferState::Running);
148
149 let meta = match tokio_fs::metadata(&source_path).await {
153 Ok(m) => m,
154 Err(e) => {
155 let err = AppError::Network {
156 source: format!("metadata failed: {e}"),
157 };
158 cleanup_on_error(
159 &request_id,
160 ®istry,
161 &lock_registry,
162 &lock_id,
163 channel,
164 err.clone(),
165 &log,
166 os_notifier,
167 )
168 .await;
169 return Err(err);
170 }
171 };
172 let total_bytes = meta.len();
173
174 {
175 let mut reg = registry.0.write().await;
176 let _ = reg.update(&request_id, |t| {
177 t.total_bytes = Some(total_bytes);
178 });
179 }
180
181 let result = if total_bytes < SINGLE_PART_THRESHOLD {
185 single_part_upload(
186 &client,
187 &bucket,
188 &key,
189 &source_path,
190 total_bytes,
191 &request_id,
192 channel,
193 ®istry,
194 &cancel_flag,
195 )
196 .await
197 } else {
198 multipart_upload(
199 &client,
200 &bucket,
201 &key,
202 &source_path,
203 total_bytes,
204 &request_id,
205 channel,
206 ®istry,
207 &multipart_table,
208 transfer_concurrency_per_part,
209 &profile_id,
210 &cancel_flag,
211 )
212 .await
213 };
214
215 match &result {
219 Ok(()) => {
220 let finished_at = now_ms();
221 {
222 let mut reg = registry.0.write().await;
223 let _ = reg.update(&request_id, |t| {
224 t.state = TransferState::Done;
225 t.finished_at = Some(finished_at);
226 });
227 }
228 let _ = emit_state(channel, &request_id, TransferState::Done);
229
230 let prefix = key
232 .rfind('/')
233 .map(|i| key[..=i].to_owned())
234 .unwrap_or_default();
235 let _ = crate::events::emit(
236 channel,
237 EventKind::ObjectsUpdated,
238 ObjectsUpdatedPayload {
239 profile_id: profile_id.as_str().to_owned(),
240 bucket: bucket.as_str().to_owned(),
241 prefix,
242 },
243 );
244
245 if let Ok(lock) = lock_registry.release(&lock_id) {
246 let _ = crate::locks::emit_released(channel, &lock, ReleaseReason::Success);
247 }
248
249 if let Some(transfer) = registry.0.read().await.get(&request_id).cloned() {
251 let _ = notify_terminal(&transfer, channel, &log, os_notifier).await;
252 }
253 }
254 Err(AppError::Cancelled) => {
255 let finished_at = now_ms();
256 {
257 let mut reg = registry.0.write().await;
258 let _ = reg.update(&request_id, |t| {
259 t.state = TransferState::Canceled;
260 t.finished_at = Some(finished_at);
261 });
262 }
263 let _ = emit_state(channel, &request_id, TransferState::Canceled);
264
265 if let Ok(lock) = lock_registry.release(&lock_id) {
266 let _ = crate::locks::emit_released(channel, &lock, ReleaseReason::Cancel);
267 }
268
269 if let Some(transfer) = registry.0.read().await.get(&request_id).cloned() {
272 let _ = notify_terminal(&transfer, channel, &log, os_notifier).await;
273 }
274 }
275 Err(e) => {
276 let finished_at = now_ms();
277 {
278 let mut reg = registry.0.write().await;
279 let _ = reg.update(&request_id, |t| {
280 t.state = TransferState::Failed;
281 t.finished_at = Some(finished_at);
282 t.error = Some(e.clone());
283 });
284 }
285 let _ =
286 emit_state_with_error(channel, &request_id, TransferState::Failed, Some(e.clone()));
287
288 if let Ok(lock) = lock_registry.release(&lock_id) {
289 let _ = crate::locks::emit_released(channel, &lock, ReleaseReason::Failure);
290 }
291
292 if let Some(transfer) = registry.0.read().await.get(&request_id).cloned() {
294 let _ = notify_terminal(&transfer, channel, &log, os_notifier).await;
295 }
296 }
297 }
298
299 result
300}
301
302async fn single_part_upload<E: EventEmitter>(
307 client: &Client,
308 bucket: &BucketId,
309 key: &str,
310 source_path: &PathBuf,
311 total_bytes: u64,
312 request_id: &str,
313 channel: &E,
314 registry: &TransferRegistryHandle,
315 cancel_flag: &AtomicBool,
316) -> Result<(), AppError> {
317 if cancel_flag.load(Ordering::Acquire) {
318 return Err(AppError::Cancelled);
319 }
320
321 let file_bytes = tokio_fs::read(source_path)
322 .await
323 .map_err(|e| AppError::Network {
324 source: format!("read source file failed: {e}"),
325 })?;
326
327 let body = ByteStream::from(file_bytes);
328
329 client
330 .put_object()
331 .bucket(bucket.as_str())
332 .key(key)
333 .content_length(total_bytes as i64)
334 .body(body)
335 .send()
336 .await
337 .map_err(|e| AppError::Network {
338 source: format!("put_object failed: {e}"),
339 })?;
340
341 let mut throttle = ProgressThrottle::new();
343 let now = now_ms();
344 let _ = emit_progress(
345 channel,
346 request_id,
347 total_bytes,
348 Some(total_bytes),
349 0,
350 0,
351 &mut throttle,
352 now,
353 );
354
355 {
357 let mut reg = registry.0.write().await;
358 let _ = reg.update(request_id, |t| {
359 t.transferred_bytes = total_bytes;
360 });
361 }
362
363 Ok(())
364}
365
366#[allow(clippy::too_many_arguments)]
371async fn multipart_upload<E: EventEmitter>(
372 client: &Client,
373 bucket: &BucketId,
374 key: &str,
375 source_path: &PathBuf,
376 total_bytes: u64,
377 request_id: &str,
378 channel: &E,
379 registry: &TransferRegistryHandle,
380 multipart_table: &MultipartTable,
381 transfer_concurrency_per_part: u32,
382 profile_id: &ProfileId,
383 cancel_flag: &AtomicBool,
384) -> Result<(), AppError> {
385 if cancel_flag.load(Ordering::Acquire) {
389 return Err(AppError::Cancelled);
390 }
391
392 let create_resp = client
393 .create_multipart_upload()
394 .bucket(bucket.as_str())
395 .key(key)
396 .send()
397 .await
398 .map_err(|e| AppError::Network {
399 source: format!("create_multipart_upload failed: {e}"),
400 })?;
401
402 let upload_id = create_resp
403 .upload_id()
404 .ok_or_else(|| AppError::Internal {
405 trace_id: "create_multipart_upload returned no upload_id".to_owned(),
406 })?
407 .to_owned();
408
409 let record = MultipartRecord {
413 upload_id: upload_id.clone(),
414 started_at: now_ms(),
415 source: "brows3r".to_owned(),
416 profile_id: profile_id.clone(),
417 bucket: bucket.clone(),
418 key: key.to_owned(),
419 };
420 multipart_table.record(&record)?;
421
422 let part_size = std::cmp::max(MIN_PART_SIZE, (total_bytes + MAX_PARTS - 1) / MAX_PARTS);
426 let parts_total = ((total_bytes + part_size - 1) / part_size) as u32;
427
428 {
429 let mut reg = registry.0.write().await;
430 let _ = reg.update(request_id, |t| {
431 t.parts_total = parts_total;
432 });
433 }
434
435 let file_bytes = tokio_fs::read(source_path)
439 .await
440 .map_err(|e| AppError::Network {
441 source: format!("read source file failed: {e}"),
442 })?;
443
444 let part_chunks: Vec<(usize, Vec<u8>)> = file_bytes
447 .chunks(part_size as usize)
448 .enumerate()
449 .map(|(i, c)| (i, c.to_vec()))
450 .collect();
451 drop(file_bytes);
453
454 let semaphore = Arc::new(Semaphore::new(transfer_concurrency_per_part as usize));
455 let client_arc = Arc::new(client.clone());
456
457 let mut part_tasks = Vec::new();
458
459 for (i, chunk_data) in part_chunks {
460 if cancel_flag.load(Ordering::Acquire) {
461 abort_multipart(client, bucket, key, &upload_id, multipart_table, profile_id).await;
463 return Err(AppError::Cancelled);
464 }
465
466 let part_number = (i + 1) as i32;
467
468 let permit = semaphore
469 .clone()
470 .acquire_owned()
471 .await
472 .expect("semaphore must not be closed");
473
474 let client_clone = Arc::clone(&client_arc);
475 let bucket_str = bucket.as_str().to_owned();
476 let key_str = key.to_owned();
477 let upload_id_clone = upload_id.clone();
478
479 let task = tokio::spawn(async move {
480 let _permit = permit; let body = ByteStream::from(chunk_data);
482 let resp = client_clone
483 .upload_part()
484 .bucket(&bucket_str)
485 .key(&key_str)
486 .upload_id(&upload_id_clone)
487 .part_number(part_number)
488 .body(body)
489 .send()
490 .await
491 .map_err(|e| AppError::Network {
492 source: format!("upload_part {part_number} failed: {e}"),
493 })?;
494
495 let etag = resp
496 .e_tag()
497 .ok_or_else(|| AppError::Internal {
498 trace_id: format!("upload_part {part_number} returned no ETag"),
499 })?
500 .to_owned();
501
502 Ok::<(i32, String), AppError>((part_number, etag))
503 });
504
505 part_tasks.push(task);
506 }
507
508 let mut completed_parts: Vec<(i32, String)> = Vec::with_capacity(part_tasks.len());
512 let mut transferred_bytes: u64 = 0;
513 let mut throttle = ProgressThrottle::new();
514
515 for task in part_tasks {
516 if cancel_flag.load(Ordering::Acquire) {
517 abort_multipart(client, bucket, key, &upload_id, multipart_table, profile_id).await;
518 return Err(AppError::Cancelled);
519 }
520
521 match task.await {
522 Ok(Ok((part_number, etag))) => {
523 completed_parts.push((part_number, etag));
524 transferred_bytes += std::cmp::min(
525 part_size,
526 total_bytes.saturating_sub((part_number as u64 - 1) * part_size),
527 );
528
529 let parts_done = completed_parts.len() as u32;
530 {
531 let mut reg = registry.0.write().await;
532 let _ = reg.update(request_id, |t| {
533 t.transferred_bytes = transferred_bytes;
534 t.parts_done = parts_done;
535 });
536 }
537
538 let now = now_ms();
539 let _ = emit_progress(
540 channel,
541 request_id,
542 transferred_bytes,
543 Some(total_bytes),
544 parts_done,
545 parts_total,
546 &mut throttle,
547 now,
548 );
549 }
550 Ok(Err(e)) => {
551 abort_multipart(client, bucket, key, &upload_id, multipart_table, profile_id).await;
552 return Err(e);
553 }
554 Err(join_err) => {
555 abort_multipart(client, bucket, key, &upload_id, multipart_table, profile_id).await;
556 return Err(AppError::Internal {
557 trace_id: format!("part task join failed: {join_err}"),
558 });
559 }
560 }
561 }
562
563 completed_parts.sort_by_key(|(n, _)| *n);
567
568 let completed = CompletedMultipartUpload::builder()
569 .set_parts(Some(
570 completed_parts
571 .into_iter()
572 .map(|(n, etag)| CompletedPart::builder().part_number(n).e_tag(etag).build())
573 .collect(),
574 ))
575 .build();
576
577 client
578 .complete_multipart_upload()
579 .bucket(bucket.as_str())
580 .key(key)
581 .upload_id(&upload_id)
582 .multipart_upload(completed)
583 .send()
584 .await
585 .map_err(|e| {
586 let bucket_str = bucket.as_str().to_owned();
588 let key_str = key.to_owned();
589 let upload_id_str = upload_id.clone();
590 let client_clone = Arc::clone(&client_arc);
591 let _ = tokio::spawn(async move {
592 let _ = client_clone
593 .abort_multipart_upload()
594 .bucket(&bucket_str)
595 .key(&key_str)
596 .upload_id(&upload_id_str)
597 .send()
598 .await;
599 });
600 AppError::Network {
601 source: format!("complete_multipart_upload failed: {e}"),
602 }
603 })?;
604
605 let _ = multipart_table.remove(profile_id, bucket, key);
607
608 Ok(())
609}
610
611async fn abort_multipart(
616 client: &Client,
617 bucket: &BucketId,
618 key: &str,
619 upload_id: &str,
620 multipart_table: &MultipartTable,
621 profile_id: &ProfileId,
622) {
623 let _ = client
624 .abort_multipart_upload()
625 .bucket(bucket.as_str())
626 .key(key)
627 .upload_id(upload_id)
628 .send()
629 .await;
630
631 let _ = multipart_table.remove(profile_id, bucket, key);
633}
634
635async fn cleanup_on_error<E, C>(
640 request_id: &str,
641 registry: &TransferRegistryHandle,
642 lock_registry: &LockRegistry,
643 lock_id: &LockId,
644 channel: &E,
645 error: AppError,
646 log: &NotificationLogHandle,
647 os_notifier: &crate::notifications::os::OsNotifier<C>,
648) where
649 E: EventEmitter,
650 C: OsNotifyChannel,
651{
652 let finished_at = now_ms();
653 let error_for_emit = error.clone();
654 {
655 let mut reg = registry.0.write().await;
656 let _ = reg.update(request_id, |t| {
657 t.state = TransferState::Failed;
658 t.finished_at = Some(finished_at);
659 t.error = Some(error);
660 });
661 }
662
663 let _ = emit_state_with_error(
664 channel,
665 request_id,
666 TransferState::Failed,
667 Some(error_for_emit),
668 );
669
670 if let Ok(lock) = lock_registry.release(lock_id) {
671 let _ = crate::locks::emit_released(channel, &lock, ReleaseReason::Failure);
672 }
673
674 if let Some(transfer) = registry.0.read().await.get(request_id).cloned() {
676 let _ = notify_terminal(&transfer, channel, log, os_notifier).await;
677 }
678}
679
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn now_secs() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs() as i64,
        Err(_) => 0,
    }
}
690
/// Current wall-clock time as whole milliseconds since the Unix epoch.
///
/// Returns `0` if the system clock reads earlier than the epoch.
fn now_ms() -> i64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as i64,
        Err(_) => 0,
    }
}
697
698pub fn compute_part_size(total_bytes: u64) -> u64 {
706 std::cmp::max(MIN_PART_SIZE, (total_bytes + MAX_PARTS - 1) / MAX_PARTS)
707}
708
#[cfg(test)]
mod tests {
    //! Unit tests for part-size selection, the single-part threshold, and the
    //! key-to-prefix expression used when emitting `ObjectsUpdated`.

    use super::*;

    /// A file exactly at the single-part threshold still gets the minimum
    /// part size.
    #[test]
    fn part_size_small_file_uses_minimum() {
        assert_eq!(compute_part_size(5 * 1024 * 1024), MIN_PART_SIZE);
    }

    #[test]
    fn part_size_50mb_is_minimum() {
        assert_eq!(compute_part_size(50 * 1024 * 1024), MIN_PART_SIZE);
    }

    #[test]
    fn part_size_500mb_is_minimum() {
        assert_eq!(compute_part_size(500 * 1024 * 1024), MIN_PART_SIZE);
    }

    /// 5 GB split into 9,999 parts is well under 8 MB per part, so the
    /// minimum still applies.
    #[test]
    fn part_size_5gb_is_still_minimum() {
        let total = 5u64 * 1024 * 1024 * 1024;
        let ps = compute_part_size(total);
        assert_eq!(ps, MIN_PART_SIZE, "5 GB should still use 8 MB part size");
        let num_parts = (total + ps - 1) / ps;
        assert!(num_parts <= MAX_PARTS, "must not exceed 9999 parts");
    }

    /// 200 GB cannot fit in 9,999 parts of 8 MB, so the part size has to grow.
    #[test]
    fn part_size_huge_file_caps_below_10000_parts() {
        let total = 200u64 * 1024 * 1024 * 1024;
        let ps = compute_part_size(total);
        let num_parts = (total + ps - 1) / ps;
        assert!(num_parts <= MAX_PARTS, "must not exceed 9999 parts");
        assert!(
            ps > MIN_PART_SIZE,
            "200 GB must scale the part size above the 8 MB minimum"
        );
    }

    #[test]
    fn single_part_threshold_is_5mb() {
        assert_eq!(SINGLE_PART_THRESHOLD, 5 * 1024 * 1024);
    }

    /// Pins the prefix expression `upload_object` uses: everything up to and
    /// including the last `/`.
    #[test]
    fn prefix_extracted_correctly() {
        let key = "data/2024/file.bin";
        let prefix = key
            .rfind('/')
            .map(|i| key[..=i].to_owned())
            .unwrap_or_default();
        assert_eq!(prefix, "data/2024/");
    }

    /// A key without any `/` has an empty prefix.
    #[test]
    fn prefix_for_root_key_is_empty() {
        let key = "rootfile.bin";
        let prefix = key
            .rfind('/')
            .map(|i| key[..=i].to_owned())
            .unwrap_or_default();
        assert_eq!(prefix, "");
    }
}