diff --git a/src/main.rs b/src/main.rs index a1feae2..8640c33 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1021,7 +1021,7 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { let db = iroh_blobs::store::fs::FsStore::load(&iroh_data_dir).await?; let db2 = db.clone(); trace!("load done!"); - let fut = async move { + let fut = async { trace!("running"); let mut mp: MultiProgress = MultiProgress::new(); let draw_target = if args.common.no_progress { @@ -1114,8 +1114,12 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { }; let (total_files, payload_size, stats) = select! { x = fut => match x { - Ok(x) => x, + Ok(x) => { + endpoint.close().await; + x + } Err(e) => { + endpoint.close().await; // make sure we shutdown the db before exiting db2.shutdown().await?; eprintln!("error: {e}"); @@ -1123,6 +1127,7 @@ async fn receive(args: ReceiveArgs) -> anyhow::Result<()> { } }, _ = tokio::signal::ctrl_c() => { + endpoint.close().await; db2.shutdown().await?; std::process::exit(130); } diff --git a/tests/cli.rs b/tests/cli.rs index 7c6a6bd..269c41b 100644 --- a/tests/cli.rs +++ b/tests/cli.rs @@ -78,6 +78,48 @@ fn send_recv_file() { assert_eq!(tgt_data, data); } +#[test] +fn receive_closes_endpoint_no_iroh_socket_error() { + let name = "graceful-close.bin"; + let data = vec![0xabu8; 64]; + let src_dir = tempfile::tempdir().unwrap(); + let tgt_dir = tempfile::tempdir().unwrap(); + let src_file = src_dir.path().join(name); + std::fs::write(&src_file, &data).unwrap(); + let mut send_cmd = duct::cmd( + sendme_bin(), + ["send", src_file.as_os_str().to_str().unwrap()], + ) + .dir(src_dir.path()) + .env_remove("RUST_LOG") + .stderr_to_stdout() + .reader() + .unwrap(); + let output = read_ascii_lines(3, &mut send_cmd).unwrap(); + let output = String::from_utf8(output).unwrap(); + let ticket = output.split_ascii_whitespace().last().unwrap(); + let ticket = BlobTicket::from_str(ticket).unwrap(); + let receive_output = duct::cmd(sendme_bin(), ["receive", &ticket.to_string()]) + .dir(tgt_dir.path()) + .env("RUST_LOG", "iroh::socket=error") + .stdout_capture() + .stderr_capture() + .run() + .unwrap(); + assert!(receive_output.status.success(), "{receive_output:?}"); + let stderr = String::from_utf8_lossy(&receive_output.stderr); + assert!( + !stderr.contains("Endpoint dropped"), + "unexpected iroh shutdown log on stderr: {stderr}" + ); + assert!( + !stderr.contains("Aborting ungracefully"), + "unexpected iroh shutdown log on stderr: {stderr}" + ); + let tgt_file = tgt_dir.path().join(name); + assert_eq!(std::fs::read(&tgt_file).unwrap(), data); +} + #[test] fn send_recv_dir() { fn create_file(base: &Path, i: usize, j: usize, k: usize) -> (PathBuf, Vec) {