592 lines
20 KiB
Rust
592 lines
20 KiB
Rust
//! File transfer module — chunked file transfers over dedicated QUIC streams.
|
|
|
|
use std::collections::HashMap;
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use anyhow::{Context, Result};
|
|
use sha2::{Digest, Sha256};
|
|
use tokio::fs::File;
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
use crate::net::EndpointId;
|
|
use crate::protocol::{
|
|
decode_framed, new_file_id, write_framed, FileAcceptReject, FileChunk, FileDone, FileId,
|
|
FileOffer, FileStreamMessage,
|
|
};
|
|
|
|
/// Chunk size for file transfers (64 KB).
|
|
const CHUNK_SIZE: usize = 64 * 1024;
|
|
|
|
/// State of a file transfer.
|
|
#[derive(Debug, Clone)]
|
|
#[allow(dead_code)]
|
|
pub enum TransferState {
|
|
/// We offered a file, waiting for opponent to accept.
|
|
Offering,
|
|
/// We received an offer, waiting for user to accept.
|
|
WaitingForAccept { expires_at: std::time::Instant },
|
|
/// We requested the file, waiting for sender to connect.
|
|
Requesting { expires_at: std::time::Instant },
|
|
/// Transfer is in progress.
|
|
Transferring {
|
|
bytes_transferred: u64,
|
|
total_size: u64,
|
|
start_time: std::time::Instant,
|
|
},
|
|
/// Transfer completed successfully.
|
|
Complete { completed_at: std::time::Instant },
|
|
/// Transfer was rejected by the peer.
|
|
Rejected { completed_at: std::time::Instant },
|
|
/// Transfer failed with an error.
|
|
Failed { error: String, completed_at: std::time::Instant },
|
|
}
|
|
|
|
/// Information about a tracked file transfer.
|
|
#[derive(Debug, Clone)]
|
|
#[allow(dead_code)]
|
|
pub struct TransferInfo {
|
|
pub file_id: FileId,
|
|
pub file_name: String,
|
|
pub file_size: u64,
|
|
pub state: TransferState,
|
|
pub is_outgoing: bool,
|
|
pub peer: Option<EndpointId>,
|
|
pub path: Option<PathBuf>,
|
|
pub offer: Option<FileOffer>,
|
|
}
|
|
|
|
use std::sync::{Arc, Mutex};
|
|
|
|
/// Manages file transfers.
|
|
#[derive(Clone)]
|
|
pub struct FileTransferManager {
|
|
pub transfers: Arc<Mutex<HashMap<FileId, TransferInfo>>>,
|
|
pub pending_accepts: Arc<Mutex<HashMap<FileId, tokio::sync::oneshot::Sender<bool>>>>,
|
|
#[allow(dead_code)]
|
|
pub download_dir: PathBuf,
|
|
}
|
|
|
|
impl FileTransferManager {
|
|
pub fn new(download_dir: PathBuf) -> Self {
|
|
Self {
|
|
transfers: Arc::new(Mutex::new(HashMap::new())),
|
|
pending_accepts: Arc::new(Mutex::new(HashMap::new())),
|
|
download_dir,
|
|
}
|
|
}
|
|
|
|
/// Initiate sending a file to a peer.
|
|
/// Returns the file ID and the file offer for broadcasting.
|
|
pub async fn prepare_send(&self, file_path: &Path) -> Result<(FileId, FileOffer)> {
|
|
let file_name = file_path
|
|
.file_name()
|
|
.context("No filename")?
|
|
.to_string_lossy()
|
|
.to_string();
|
|
|
|
let metadata = tokio::fs::metadata(file_path)
|
|
.await
|
|
.context("Failed to read file metadata")?;
|
|
let file_size = metadata.len();
|
|
|
|
// Compute SHA-256 checksum
|
|
let mut file = File::open(file_path).await?;
|
|
let mut hasher = Sha256::new();
|
|
let mut buf = vec![0u8; CHUNK_SIZE];
|
|
loop {
|
|
let n = file.read(&mut buf).await?;
|
|
if n == 0 {
|
|
break;
|
|
}
|
|
hasher.update(&buf[..n]);
|
|
}
|
|
let checksum: [u8; 32] = hasher.finalize().into();
|
|
|
|
let file_id = new_file_id();
|
|
|
|
// Retrieve timeout from config (default 60s)
|
|
let timeout = if let Ok(cfg) = crate::config::AppConfig::load() {
|
|
cfg.files.default_timeout_seconds
|
|
} else {
|
|
60
|
|
};
|
|
|
|
let offer = FileOffer {
|
|
file_id,
|
|
name: file_name.clone(),
|
|
size: file_size,
|
|
checksum,
|
|
timeout,
|
|
};
|
|
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
transfers.insert(
|
|
file_id,
|
|
TransferInfo {
|
|
file_id,
|
|
file_name,
|
|
file_size,
|
|
state: TransferState::WaitingForAccept {
|
|
expires_at: std::time::Instant::now() + std::time::Duration::from_secs(timeout),
|
|
},
|
|
is_outgoing: true,
|
|
peer: None,
|
|
path: Some(file_path.to_path_buf()),
|
|
offer: Some(offer.clone()),
|
|
},
|
|
);
|
|
}
|
|
|
|
Ok((file_id, offer))
|
|
}
|
|
|
|
pub fn check_timeouts(&self) {
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
let now = std::time::Instant::now();
|
|
for info in transfers.values_mut() {
|
|
match info.state {
|
|
TransferState::WaitingForAccept { expires_at } | TransferState::Requesting { expires_at } => {
|
|
if now > expires_at {
|
|
info.state = TransferState::Failed { error: "Timed out".to_string(), completed_at: now };
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
// Remove expired
|
|
let to_remove: Vec<FileId> = transfers.iter().filter_map(|(id, info)| {
|
|
match info.state {
|
|
TransferState::Complete { completed_at } |
|
|
TransferState::Rejected { completed_at } |
|
|
TransferState::Failed { completed_at, .. } => {
|
|
if now.duration_since(completed_at) > std::time::Duration::from_secs(10) {
|
|
Some(*id)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
_ => None
|
|
}
|
|
}).collect();
|
|
|
|
for id in to_remove {
|
|
transfers.remove(&id);
|
|
}
|
|
|
|
// Also cleanup pending accepts for timed out transfers?
|
|
// Actually, execute_receive handles its own timeout for pending_accepts channel.
|
|
// But for sender side, we don't have a pending accept channel causing a block?
|
|
// Sender is waiting in `execute_send`?
|
|
// Wait, execute_send sends Offer and waits for Accept/Reject.
|
|
// It uses `decode_framed(recv)`.
|
|
// If recipient never replies (timed out), sender is stuck in `decode_framed`.
|
|
// We need `execute_send` to timeout as well!
|
|
// But `check_timeouts` only updates the State in the Map.
|
|
// It doesn't interrupt the async `execute_send`.
|
|
// `execute_send` logic needs a timeout on `decode_framed`.
|
|
}
|
|
|
|
pub fn accept_transfer(&self, file_id: FileId) -> bool {
|
|
let mut pending = self.pending_accepts.lock().unwrap();
|
|
if let Some(tx) = pending.remove(&file_id) {
|
|
let _ = tx.send(true);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn reject_transfer(&self, file_id: FileId) -> bool {
|
|
let mut pending = self.pending_accepts.lock().unwrap();
|
|
if let Some(tx) = pending.remove(&file_id) {
|
|
let _ = tx.send(false);
|
|
true
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Execute the sending side of a file transfer over a QUIC bi-stream.
|
|
#[allow(dead_code)]
|
|
pub async fn execute_send(
|
|
&self,
|
|
file_id: FileId,
|
|
file_path: &Path,
|
|
offer: FileOffer,
|
|
send: &mut iroh::endpoint::SendStream,
|
|
recv: &mut iroh::endpoint::RecvStream,
|
|
) -> Result<()> {
|
|
// Send the offer
|
|
write_framed(send, &FileStreamMessage::Offer(offer)).await?;
|
|
|
|
// Wait for accept or reject
|
|
let response: FileStreamMessage = decode_framed(recv).await?;
|
|
match response {
|
|
FileStreamMessage::Accept(_) => {
|
|
// Proceed with transfer
|
|
}
|
|
FileStreamMessage::Reject(_) => {
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Rejected { completed_at: std::time::Instant::now() };
|
|
}
|
|
return Ok(());
|
|
}
|
|
_ => {
|
|
anyhow::bail!("Unexpected response to file offer");
|
|
}
|
|
}
|
|
|
|
// Stream file chunks
|
|
let mut file = File::open(file_path).await?;
|
|
let mut offset: u64 = 0;
|
|
let total_size = tokio::fs::metadata(file_path).await?.len();
|
|
let mut buf = vec![0u8; CHUNK_SIZE];
|
|
let start_time = std::time::Instant::now();
|
|
|
|
loop {
|
|
let n = file.read(&mut buf).await?;
|
|
if n == 0 {
|
|
break;
|
|
}
|
|
|
|
let chunk = FileStreamMessage::Chunk(FileChunk {
|
|
file_id,
|
|
offset,
|
|
data: buf[..n].to_vec(),
|
|
});
|
|
write_framed(send, &chunk).await?;
|
|
offset += n as u64;
|
|
|
|
// Update progress
|
|
// Scope limit the lock
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Transferring {
|
|
bytes_transferred: offset,
|
|
total_size,
|
|
start_time,
|
|
};
|
|
}
|
|
}
|
|
}
|
|
|
|
// Send done
|
|
write_framed(send, &FileStreamMessage::Done(FileDone { file_id })).await?;
|
|
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Complete { completed_at: std::time::Instant::now() };
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Handle an incoming file offer for display.
|
|
#[allow(dead_code)]
|
|
pub fn register_incoming_offer(&self, offer: &FileOffer, peer: EndpointId) {
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
transfers.insert(
|
|
offer.file_id,
|
|
TransferInfo {
|
|
file_id: offer.file_id,
|
|
file_name: offer.name.clone(),
|
|
file_size: offer.size,
|
|
state: TransferState::Offering,
|
|
is_outgoing: false,
|
|
peer: Some(peer),
|
|
path: None,
|
|
offer: None,
|
|
},
|
|
);
|
|
}
|
|
|
|
/// Handle an incoming file offer broadcast for display.
|
|
pub fn register_incoming_broadcast(
|
|
&self,
|
|
offer: &crate::protocol::FileOfferBroadcast,
|
|
peer: EndpointId,
|
|
) {
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
// Calculate expires_at based on timeout
|
|
let expires_at = std::time::Instant::now() + std::time::Duration::from_secs(offer.timeout);
|
|
|
|
transfers.insert(
|
|
offer.file_id,
|
|
TransferInfo {
|
|
file_id: offer.file_id,
|
|
file_name: offer.file_name.clone(),
|
|
file_size: offer.file_size,
|
|
// We go directly to WaitingForAccept because it's an offer we can accept
|
|
state: TransferState::WaitingForAccept { expires_at },
|
|
is_outgoing: false,
|
|
peer: Some(peer),
|
|
path: None,
|
|
offer: None,
|
|
},
|
|
);
|
|
}
|
|
|
|
/// Execute the receiving side of a file transfer over a QUIC bi-stream.
|
|
#[allow(dead_code)]
|
|
pub async fn execute_receive(
|
|
&self,
|
|
send: &mut iroh::endpoint::SendStream,
|
|
recv: &mut iroh::endpoint::RecvStream,
|
|
) -> Result<FileId> {
|
|
// Read the offer
|
|
let msg: FileStreamMessage = decode_framed(recv).await?;
|
|
let offer = match msg {
|
|
FileStreamMessage::Offer(o) => o,
|
|
_ => anyhow::bail!("Expected file offer"),
|
|
};
|
|
|
|
let file_id = offer.file_id;
|
|
|
|
let timeout_duration = std::time::Duration::from_secs(offer.timeout);
|
|
let expires_at = std::time::Instant::now() + timeout_duration;
|
|
|
|
// Check if pre-accepted (Requesting state)
|
|
let is_pre_accepted = {
|
|
let transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get(&file_id) {
|
|
matches!(info.state, TransferState::Requesting { .. })
|
|
} else {
|
|
false
|
|
}
|
|
};
|
|
|
|
// Register incoming offer (unless pre-accepted logic differs)
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
transfers.insert(
|
|
offer.file_id,
|
|
TransferInfo {
|
|
file_id: offer.file_id,
|
|
file_name: offer.name.clone(),
|
|
file_size: offer.size,
|
|
state: TransferState::WaitingForAccept { expires_at },
|
|
is_outgoing: false,
|
|
peer: None, // We don't need peer for receive technically, as we are connected
|
|
path: None,
|
|
offer: Some(offer.clone()), // Store incoming offer if useful
|
|
},
|
|
);
|
|
}
|
|
|
|
let accepted = if is_pre_accepted {
|
|
true
|
|
} else {
|
|
// Wait for acceptance
|
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
{
|
|
let mut pending = self.pending_accepts.lock().unwrap();
|
|
pending.insert(file_id, tx);
|
|
}
|
|
|
|
// Wait for signal or timeout
|
|
match tokio::time::timeout(timeout_duration, rx).await {
|
|
Ok(Ok(decision)) => decision,
|
|
Ok(Err(_)) => false, // Sender dropped?
|
|
Err(_) => {
|
|
// Timeout
|
|
false
|
|
}
|
|
}
|
|
};
|
|
|
|
// Remove from pending if still there (in case of timeout or pre-accept bypass)
|
|
{
|
|
let mut pending = self.pending_accepts.lock().unwrap();
|
|
pending.remove(&file_id);
|
|
}
|
|
|
|
if !accepted {
|
|
// Reject
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
// Check if it wasn't already updated (e.g. by manual reject)
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
// Could be Offering or WaitingForAccept depending on race
|
|
info.state = TransferState::Rejected { completed_at: std::time::Instant::now() };
|
|
}
|
|
}
|
|
|
|
write_framed(
|
|
send,
|
|
&FileStreamMessage::Reject(FileAcceptReject { file_id }),
|
|
)
|
|
.await?;
|
|
|
|
return Ok(file_id);
|
|
}
|
|
|
|
// Accepted
|
|
write_framed(
|
|
send,
|
|
&FileStreamMessage::Accept(FileAcceptReject { file_id }),
|
|
)
|
|
.await?;
|
|
|
|
// Update state to Transferring immediately to stop spinner
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Transferring {
|
|
bytes_transferred: 0,
|
|
total_size: offer.size,
|
|
start_time: std::time::Instant::now(),
|
|
};
|
|
}
|
|
}
|
|
|
|
// Receive chunks
|
|
let dest_path = self.download_dir.join(&offer.name);
|
|
|
|
// Ensure download dir exists for safety (though main normally does this)
|
|
if let Some(parent) = dest_path.parent() {
|
|
let _ = tokio::fs::create_dir_all(parent).await;
|
|
}
|
|
|
|
let mut file = tokio::fs::OpenOptions::new()
|
|
.create(true)
|
|
.write(true)
|
|
.truncate(true)
|
|
.open(&dest_path)
|
|
.await?;
|
|
|
|
let mut received: u64 = 0;
|
|
let start_time = std::time::Instant::now();
|
|
|
|
loop {
|
|
let chunk_msg: FileStreamMessage = decode_framed(recv).await?;
|
|
match chunk_msg {
|
|
FileStreamMessage::Chunk(chunk) => {
|
|
file.write_all(&chunk.data).await?;
|
|
received += chunk.data.len() as u64;
|
|
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Transferring {
|
|
bytes_transferred: received,
|
|
total_size: offer.size,
|
|
start_time,
|
|
};
|
|
}
|
|
}
|
|
}
|
|
FileStreamMessage::Done(_) => {
|
|
break;
|
|
}
|
|
_ => {
|
|
anyhow::bail!("Unexpected message during file transfer");
|
|
}
|
|
}
|
|
}
|
|
|
|
file.flush().await?;
|
|
|
|
{
|
|
let mut transfers = self.transfers.lock().unwrap();
|
|
if let Some(info) = transfers.get_mut(&file_id) {
|
|
info.state = TransferState::Complete { completed_at: std::time::Instant::now() };
|
|
}
|
|
}
|
|
|
|
Ok(file_id)
|
|
}
|
|
|
|
/// Get a summary of active/recent transfers for display.
|
|
pub fn active_transfers(&self) -> Vec<TransferInfo> {
|
|
let transfers = self.transfers.lock().unwrap();
|
|
transfers.values().cloned().collect()
|
|
}
|
|
|
|
/// Format transfer progress as a human-readable string.
|
|
#[allow(dead_code)]
|
|
pub fn format_progress(info: &TransferInfo) -> String {
|
|
let direction = if info.is_outgoing { "↑" } else { "↓" };
|
|
match &info.state {
|
|
TransferState::Offering => {
|
|
format!("{} {} (Offering)", direction, info.file_name)
|
|
}
|
|
TransferState::Requesting { .. } => {
|
|
format!("{} {} (Requesting...)", direction, info.file_name)
|
|
}
|
|
TransferState::WaitingForAccept { expires_at } => {
|
|
let now = std::time::Instant::now();
|
|
let remaining = if *expires_at > now {
|
|
expires_at.duration_since(now).as_secs()
|
|
} else {
|
|
0
|
|
};
|
|
if info.is_outgoing {
|
|
format!("{} {} (Wait Accept - {}s)", direction, info.file_name, remaining)
|
|
} else {
|
|
format!(
|
|
"{} {} (Incoming Offer - {}s)",
|
|
direction, info.file_name, remaining
|
|
)
|
|
}
|
|
}
|
|
TransferState::Transferring {
|
|
bytes_transferred,
|
|
total_size,
|
|
start_time,
|
|
} => {
|
|
let pct = if *total_size > 0 {
|
|
(*bytes_transferred as f64 / *total_size as f64 * 100.0) as u8
|
|
} else {
|
|
0
|
|
};
|
|
let elapsed = start_time.elapsed().as_secs_f64();
|
|
let speed_bps = if elapsed > 0.0 {
|
|
*bytes_transferred as f64 / elapsed
|
|
} else {
|
|
0.0
|
|
};
|
|
let speed_mbps = speed_bps / (1024.0 * 1024.0);
|
|
|
|
format!(
|
|
"{} {} {}% ({}/{}) {:.1} MB/s",
|
|
direction,
|
|
info.file_name,
|
|
pct,
|
|
format_bytes(*bytes_transferred),
|
|
format_bytes(*total_size),
|
|
speed_mbps
|
|
)
|
|
}
|
|
TransferState::Complete { .. } => {
|
|
format!("{} {} ✓ complete", direction, info.file_name)
|
|
}
|
|
TransferState::Rejected { .. } => {
|
|
format!("{} {} ✗ rejected", direction, info.file_name)
|
|
}
|
|
TransferState::Failed { error, .. } => {
|
|
format!("{} {} ✗ {}", direction, info.file_name, error)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub fn format_bytes(bytes: u64) -> String {
|
|
if bytes < 1024 {
|
|
format!("{}B", bytes)
|
|
} else if bytes < 1024 * 1024 {
|
|
format!("{:.1}KiB", bytes as f64 / 1024.0)
|
|
} else if bytes < 1024 * 1024 * 1024 {
|
|
format!("{:.1}MiB", bytes as f64 / (1024.0 * 1024.0))
|
|
} else {
|
|
format!("{:.1}GiB", bytes as f64 / (1024.0 * 1024.0 * 1024.0))
|
|
}
|
|
}
|