Files
p2p-chat/src/file_transfer/mod.rs

592 lines
20 KiB
Rust

//! File transfer module — chunked file transfers over dedicated QUIC streams.
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{Duration, Instant};

use anyhow::{Context, Result};
use sha2::{Digest, Sha256};
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

use crate::net::EndpointId;
use crate::protocol::{
    decode_framed, new_file_id, write_framed, FileAcceptReject, FileChunk, FileDone, FileId,
    FileOffer, FileStreamMessage,
};
/// Size of each read buffer and each `FileChunk` payload: 64 KiB (2^16 bytes).
const CHUNK_SIZE: usize = 1 << 16;
/// Where a tracked file transfer currently is in its lifecycle.
///
/// The terminal variants (`Complete`, `Rejected`, `Failed`) record when they were
/// reached so finished entries can be pruned after a grace period.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub enum TransferState {
    /// An offer for this file exists and no decision has been recorded yet.
    Offering,
    /// The offer awaits an accept/reject decision; it lapses at `expires_at`.
    WaitingForAccept { expires_at: std::time::Instant },
    /// The file was requested locally; the sender has until `expires_at` to connect.
    Requesting { expires_at: std::time::Instant },
    /// Data is flowing: `bytes_transferred` of `total_size` bytes so far,
    /// measured from `start_time`.
    Transferring {
        bytes_transferred: u64,
        total_size: u64,
        start_time: std::time::Instant,
    },
    /// Finished successfully at `completed_at`.
    Complete { completed_at: std::time::Instant },
    /// Declined (or never answered) as of `completed_at`.
    Rejected { completed_at: std::time::Instant },
    /// Aborted at `completed_at`; `error` explains why.
    Failed {
        error: String,
        completed_at: std::time::Instant,
    },
}
/// Bookkeeping record for one file transfer, stored in the manager's map under
/// its `file_id`.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub struct TransferInfo {
    /// Identifier of the transfer, carried in every protocol message about it.
    pub file_id: FileId,
    /// Name of the file as shown to the user.
    pub file_name: String,
    /// Size of the file in bytes, as announced in the offer.
    pub file_size: u64,
    /// Current lifecycle state.
    pub state: TransferState,
    /// `true` when this side is the sender.
    pub is_outgoing: bool,
    /// Remote endpoint, when it is known.
    pub peer: Option<EndpointId>,
    /// Local source path of an outgoing file; `None` for incoming files.
    pub path: Option<PathBuf>,
    /// The offer this record was built from, when it was kept.
    pub offer: Option<FileOffer>,
}
use std::sync::{Arc, Mutex};
/// Central registry of file transfers and of incoming offers awaiting a decision.
///
/// Cloning is cheap: every clone shares the same maps through `Arc`.
#[derive(Clone)]
pub struct FileTransferManager {
    /// Every known transfer, keyed by file id.
    pub transfers: Arc<Mutex<HashMap<FileId, TransferInfo>>>,
    /// One-shot channels through which the user's accept (`true`) or reject (`false`)
    /// decision reaches a waiting receive task.
    pub pending_accepts: Arc<Mutex<HashMap<FileId, tokio::sync::oneshot::Sender<bool>>>>,
    /// Directory that received files are written into.
    #[allow(dead_code)]
    pub download_dir: PathBuf,
}
impl FileTransferManager {
pub fn new(download_dir: PathBuf) -> Self {
Self {
transfers: Arc::new(Mutex::new(HashMap::new())),
pending_accepts: Arc::new(Mutex::new(HashMap::new())),
download_dir,
}
}
/// Initiate sending a file to a peer.
/// Returns the file ID and the file offer for broadcasting.
pub async fn prepare_send(&self, file_path: &Path) -> Result<(FileId, FileOffer)> {
let file_name = file_path
.file_name()
.context("No filename")?
.to_string_lossy()
.to_string();
let metadata = tokio::fs::metadata(file_path)
.await
.context("Failed to read file metadata")?;
let file_size = metadata.len();
// Compute SHA-256 checksum
let mut file = File::open(file_path).await?;
let mut hasher = Sha256::new();
let mut buf = vec![0u8; CHUNK_SIZE];
loop {
let n = file.read(&mut buf).await?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
let checksum: [u8; 32] = hasher.finalize().into();
let file_id = new_file_id();
// Retrieve timeout from config (default 60s)
let timeout = if let Ok(cfg) = crate::config::AppConfig::load() {
cfg.files.default_timeout_seconds
} else {
60
};
let offer = FileOffer {
file_id,
name: file_name.clone(),
size: file_size,
checksum,
timeout,
};
{
let mut transfers = self.transfers.lock().unwrap();
transfers.insert(
file_id,
TransferInfo {
file_id,
file_name,
file_size,
state: TransferState::WaitingForAccept {
expires_at: std::time::Instant::now() + std::time::Duration::from_secs(timeout),
},
is_outgoing: true,
peer: None,
path: Some(file_path.to_path_buf()),
offer: Some(offer.clone()),
},
);
}
Ok((file_id, offer))
}
pub fn check_timeouts(&self) {
let mut transfers = self.transfers.lock().unwrap();
let now = std::time::Instant::now();
for info in transfers.values_mut() {
match info.state {
TransferState::WaitingForAccept { expires_at } | TransferState::Requesting { expires_at } => {
if now > expires_at {
info.state = TransferState::Failed { error: "Timed out".to_string(), completed_at: now };
}
}
_ => {}
}
}
// Remove expired
let to_remove: Vec<FileId> = transfers.iter().filter_map(|(id, info)| {
match info.state {
TransferState::Complete { completed_at } |
TransferState::Rejected { completed_at } |
TransferState::Failed { completed_at, .. } => {
if now.duration_since(completed_at) > std::time::Duration::from_secs(10) {
Some(*id)
} else {
None
}
}
_ => None
}
}).collect();
for id in to_remove {
transfers.remove(&id);
}
// Also cleanup pending accepts for timed out transfers?
// Actually, execute_receive handles its own timeout for pending_accepts channel.
// But for sender side, we don't have a pending accept channel causing a block?
// Sender is waiting in `execute_send`?
// Wait, execute_send sends Offer and waits for Accept/Reject.
// It uses `decode_framed(recv)`.
// If recipient never replies (timed out), sender is stuck in `decode_framed`.
// We need `execute_send` to timeout as well!
// But `check_timeouts` only updates the State in the Map.
// It doesn't interrupt the async `execute_send`.
// `execute_send` logic needs a timeout on `decode_framed`.
}
pub fn accept_transfer(&self, file_id: FileId) -> bool {
let mut pending = self.pending_accepts.lock().unwrap();
if let Some(tx) = pending.remove(&file_id) {
let _ = tx.send(true);
true
} else {
false
}
}
#[allow(dead_code)]
pub fn reject_transfer(&self, file_id: FileId) -> bool {
let mut pending = self.pending_accepts.lock().unwrap();
if let Some(tx) = pending.remove(&file_id) {
let _ = tx.send(false);
true
} else {
false
}
}
/// Execute the sending side of a file transfer over a QUIC bi-stream.
#[allow(dead_code)]
pub async fn execute_send(
&self,
file_id: FileId,
file_path: &Path,
offer: FileOffer,
send: &mut iroh::endpoint::SendStream,
recv: &mut iroh::endpoint::RecvStream,
) -> Result<()> {
// Send the offer
write_framed(send, &FileStreamMessage::Offer(offer)).await?;
// Wait for accept or reject
let response: FileStreamMessage = decode_framed(recv).await?;
match response {
FileStreamMessage::Accept(_) => {
// Proceed with transfer
}
FileStreamMessage::Reject(_) => {
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Rejected { completed_at: std::time::Instant::now() };
}
return Ok(());
}
_ => {
anyhow::bail!("Unexpected response to file offer");
}
}
// Stream file chunks
let mut file = File::open(file_path).await?;
let mut offset: u64 = 0;
let total_size = tokio::fs::metadata(file_path).await?.len();
let mut buf = vec![0u8; CHUNK_SIZE];
let start_time = std::time::Instant::now();
loop {
let n = file.read(&mut buf).await?;
if n == 0 {
break;
}
let chunk = FileStreamMessage::Chunk(FileChunk {
file_id,
offset,
data: buf[..n].to_vec(),
});
write_framed(send, &chunk).await?;
offset += n as u64;
// Update progress
// Scope limit the lock
{
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Transferring {
bytes_transferred: offset,
total_size,
start_time,
};
}
}
}
// Send done
write_framed(send, &FileStreamMessage::Done(FileDone { file_id })).await?;
{
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Complete { completed_at: std::time::Instant::now() };
}
}
Ok(())
}
/// Handle an incoming file offer for display.
#[allow(dead_code)]
pub fn register_incoming_offer(&self, offer: &FileOffer, peer: EndpointId) {
let mut transfers = self.transfers.lock().unwrap();
transfers.insert(
offer.file_id,
TransferInfo {
file_id: offer.file_id,
file_name: offer.name.clone(),
file_size: offer.size,
state: TransferState::Offering,
is_outgoing: false,
peer: Some(peer),
path: None,
offer: None,
},
);
}
/// Handle an incoming file offer broadcast for display.
pub fn register_incoming_broadcast(
&self,
offer: &crate::protocol::FileOfferBroadcast,
peer: EndpointId,
) {
let mut transfers = self.transfers.lock().unwrap();
// Calculate expires_at based on timeout
let expires_at = std::time::Instant::now() + std::time::Duration::from_secs(offer.timeout);
transfers.insert(
offer.file_id,
TransferInfo {
file_id: offer.file_id,
file_name: offer.file_name.clone(),
file_size: offer.file_size,
// We go directly to WaitingForAccept because it's an offer we can accept
state: TransferState::WaitingForAccept { expires_at },
is_outgoing: false,
peer: Some(peer),
path: None,
offer: None,
},
);
}
/// Execute the receiving side of a file transfer over a QUIC bi-stream.
#[allow(dead_code)]
pub async fn execute_receive(
&self,
send: &mut iroh::endpoint::SendStream,
recv: &mut iroh::endpoint::RecvStream,
) -> Result<FileId> {
// Read the offer
let msg: FileStreamMessage = decode_framed(recv).await?;
let offer = match msg {
FileStreamMessage::Offer(o) => o,
_ => anyhow::bail!("Expected file offer"),
};
let file_id = offer.file_id;
let timeout_duration = std::time::Duration::from_secs(offer.timeout);
let expires_at = std::time::Instant::now() + timeout_duration;
// Check if pre-accepted (Requesting state)
let is_pre_accepted = {
let transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get(&file_id) {
matches!(info.state, TransferState::Requesting { .. })
} else {
false
}
};
// Register incoming offer (unless pre-accepted logic differs)
{
let mut transfers = self.transfers.lock().unwrap();
transfers.insert(
offer.file_id,
TransferInfo {
file_id: offer.file_id,
file_name: offer.name.clone(),
file_size: offer.size,
state: TransferState::WaitingForAccept { expires_at },
is_outgoing: false,
peer: None, // We don't need peer for receive technically, as we are connected
path: None,
offer: Some(offer.clone()), // Store incoming offer if useful
},
);
}
let accepted = if is_pre_accepted {
true
} else {
// Wait for acceptance
let (tx, rx) = tokio::sync::oneshot::channel();
{
let mut pending = self.pending_accepts.lock().unwrap();
pending.insert(file_id, tx);
}
// Wait for signal or timeout
match tokio::time::timeout(timeout_duration, rx).await {
Ok(Ok(decision)) => decision,
Ok(Err(_)) => false, // Sender dropped?
Err(_) => {
// Timeout
false
}
}
};
// Remove from pending if still there (in case of timeout or pre-accept bypass)
{
let mut pending = self.pending_accepts.lock().unwrap();
pending.remove(&file_id);
}
if !accepted {
// Reject
{
let mut transfers = self.transfers.lock().unwrap();
// Check if it wasn't already updated (e.g. by manual reject)
if let Some(info) = transfers.get_mut(&file_id) {
// Could be Offering or WaitingForAccept depending on race
info.state = TransferState::Rejected { completed_at: std::time::Instant::now() };
}
}
write_framed(
send,
&FileStreamMessage::Reject(FileAcceptReject { file_id }),
)
.await?;
return Ok(file_id);
}
// Accepted
write_framed(
send,
&FileStreamMessage::Accept(FileAcceptReject { file_id }),
)
.await?;
// Update state to Transferring immediately to stop spinner
{
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Transferring {
bytes_transferred: 0,
total_size: offer.size,
start_time: std::time::Instant::now(),
};
}
}
// Receive chunks
let dest_path = self.download_dir.join(&offer.name);
// Ensure download dir exists for safety (though main normally does this)
if let Some(parent) = dest_path.parent() {
let _ = tokio::fs::create_dir_all(parent).await;
}
let mut file = tokio::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(&dest_path)
.await?;
let mut received: u64 = 0;
let start_time = std::time::Instant::now();
loop {
let chunk_msg: FileStreamMessage = decode_framed(recv).await?;
match chunk_msg {
FileStreamMessage::Chunk(chunk) => {
file.write_all(&chunk.data).await?;
received += chunk.data.len() as u64;
{
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Transferring {
bytes_transferred: received,
total_size: offer.size,
start_time,
};
}
}
}
FileStreamMessage::Done(_) => {
break;
}
_ => {
anyhow::bail!("Unexpected message during file transfer");
}
}
}
file.flush().await?;
{
let mut transfers = self.transfers.lock().unwrap();
if let Some(info) = transfers.get_mut(&file_id) {
info.state = TransferState::Complete { completed_at: std::time::Instant::now() };
}
}
Ok(file_id)
}
/// Get a summary of active/recent transfers for display.
pub fn active_transfers(&self) -> Vec<TransferInfo> {
let transfers = self.transfers.lock().unwrap();
transfers.values().cloned().collect()
}
/// Format transfer progress as a human-readable string.
#[allow(dead_code)]
pub fn format_progress(info: &TransferInfo) -> String {
let direction = if info.is_outgoing { "" } else { "" };
match &info.state {
TransferState::Offering => {
format!("{} {} (Offering)", direction, info.file_name)
}
TransferState::Requesting { .. } => {
format!("{} {} (Requesting...)", direction, info.file_name)
}
TransferState::WaitingForAccept { expires_at } => {
let now = std::time::Instant::now();
let remaining = if *expires_at > now {
expires_at.duration_since(now).as_secs()
} else {
0
};
if info.is_outgoing {
format!("{} {} (Wait Accept - {}s)", direction, info.file_name, remaining)
} else {
format!(
"{} {} (Incoming Offer - {}s)",
direction, info.file_name, remaining
)
}
}
TransferState::Transferring {
bytes_transferred,
total_size,
start_time,
} => {
let pct = if *total_size > 0 {
(*bytes_transferred as f64 / *total_size as f64 * 100.0) as u8
} else {
0
};
let elapsed = start_time.elapsed().as_secs_f64();
let speed_bps = if elapsed > 0.0 {
*bytes_transferred as f64 / elapsed
} else {
0.0
};
let speed_mbps = speed_bps / (1024.0 * 1024.0);
format!(
"{} {} {}% ({}/{}) {:.1} MB/s",
direction,
info.file_name,
pct,
format_bytes(*bytes_transferred),
format_bytes(*total_size),
speed_mbps
)
}
TransferState::Complete { .. } => {
format!("{} {} ✓ complete", direction, info.file_name)
}
TransferState::Rejected { .. } => {
format!("{} {} ✗ rejected", direction, info.file_name)
}
TransferState::Failed { error, .. } => {
format!("{} {}{}", direction, info.file_name, error)
}
}
}
}
/// Render a byte count using binary (1024-based) units: `B`, `KiB`, `MiB` or `GiB`.
///
/// Counts below 1 KiB are printed exactly. Larger counts are printed with one decimal
/// place in the largest unit that keeps the value at or above 1; GiB is the largest unit.
#[allow(dead_code)]
pub fn format_bytes(bytes: u64) -> String {
    const KIB: u64 = 1024;
    const MIB: u64 = KIB * 1024;
    const GIB: u64 = MIB * 1024;
    match bytes {
        b if b < KIB => format!("{}B", b),
        b if b < MIB => format!("{:.1}KiB", b as f64 / KIB as f64),
        b if b < GIB => format!("{:.1}MiB", b as f64 / MIB as f64),
        b => format!("{:.1}GiB", b as f64 / GIB as f64),
    }
}