git: move progress callback into a struct

This commit is contained in:
Benjamin Saunders 2022-11-06 09:36:52 -08:00
parent a27da7d8d5
commit b55c4ae0a3
3 changed files with 115 additions and 51 deletions

View file

@ -17,7 +17,7 @@ use std::fs::OpenOptions;
use std::io::{Read, Write};
use std::sync::Arc;
use git2::{Oid, RemoteCallbacks};
use git2::Oid;
use itertools::Itertools;
use thiserror::Error;
@ -280,7 +280,7 @@ pub fn fetch(
mut_repo: &mut MutableRepo,
git_repo: &git2::Repository,
remote_name: &str,
progress: Option<&mut dyn FnMut(&Progress)>,
callbacks: RemoteCallbacks<'_>,
) -> Result<Option<String>, GitFetchError> {
let mut remote =
git_repo
@ -298,7 +298,7 @@ pub fn fetch(
let mut proxy_options = git2::ProxyOptions::new();
proxy_options.auto();
fetch_options.proxy_options(proxy_options);
let callbacks = create_remote_callbacks(progress);
let callbacks = callbacks.into_git();
fetch_options.remote_callbacks(callbacks);
let refspec: &[&str] = &[];
remote.download(refspec, Some(&mut fetch_options))?;
@ -435,7 +435,7 @@ fn push_refs(
let mut proxy_options = git2::ProxyOptions::new();
proxy_options.auto();
push_options.proxy_options(proxy_options);
let mut callbacks = create_remote_callbacks(None);
let mut callbacks = RemoteCallbacks::default().into_git();
callbacks.push_update_reference(|refname, status| {
// The status is Some if the ref update was rejected
if status.is_none() {
@ -466,39 +466,53 @@ fn push_refs(
}
}
fn create_remote_callbacks(progress_cb: Option<&mut dyn FnMut(&Progress)>) -> RemoteCallbacks<'_> {
let mut callbacks = git2::RemoteCallbacks::new();
if let Some(progress_cb) = progress_cb {
callbacks.transfer_progress(move |progress| {
progress_cb(&Progress {
bytes_downloaded: if progress.received_objects() < progress.total_objects() {
Some(progress.received_bytes() as u64)
} else {
None
},
overall: (progress.indexed_objects() + progress.indexed_deltas()) as f32
/ (progress.total_objects() + progress.total_deltas()) as f32,
#[non_exhaustive]
#[derive(Default)]
pub struct RemoteCallbacks<'a> {
pub progress: Option<&'a mut dyn FnMut(&Progress)>,
}
impl<'a> RemoteCallbacks<'a> {
fn into_git(self) -> git2::RemoteCallbacks<'a> {
let mut callbacks = git2::RemoteCallbacks::new();
if let Some(progress_cb) = self.progress {
callbacks.transfer_progress(move |progress| {
progress_cb(&Progress {
bytes_downloaded: if progress.received_objects() < progress.total_objects() {
Some(progress.received_bytes() as u64)
} else {
None
},
overall: (progress.indexed_objects() + progress.indexed_deltas()) as f32
/ (progress.total_objects() + progress.total_deltas()) as f32,
});
true
});
true
});
}
// TODO: We should expose the callbacks to the caller instead -- the library
// crate shouldn't look in $HOME etc.
callbacks.credentials(|_url, username_from_url, allowed_types| {
if allowed_types.contains(git2::CredentialType::SSH_KEY) {
if std::env::var("SSH_AUTH_SOCK").is_ok() || std::env::var("SSH_AGENT_PID").is_ok() {
return git2::Cred::ssh_key_from_agent(username_from_url.unwrap());
}
if let Ok(home_dir) = std::env::var("HOME") {
let key_path = std::path::Path::new(&home_dir).join(".ssh").join("id_rsa");
if key_path.is_file() {
return git2::Cred::ssh_key(username_from_url.unwrap(), None, &key_path, None);
}
// TODO: We should expose the callbacks to the caller instead -- the library
// crate shouldn't look in $HOME etc.
callbacks.credentials(|_url, username_from_url, allowed_types| {
if allowed_types.contains(git2::CredentialType::SSH_KEY) {
if std::env::var("SSH_AUTH_SOCK").is_ok() || std::env::var("SSH_AGENT_PID").is_ok()
{
return git2::Cred::ssh_key_from_agent(username_from_url.unwrap());
}
if let Ok(home_dir) = std::env::var("HOME") {
let key_path = std::path::Path::new(&home_dir).join(".ssh").join("id_rsa");
if key_path.is_file() {
return git2::Cred::ssh_key(
username_from_url.unwrap(),
None,
&key_path,
None,
);
}
}
}
}
git2::Cred::default()
});
callbacks
git2::Cred::default()
});
callbacks
}
}
pub struct Progress {

View file

@ -593,7 +593,13 @@ fn test_fetch_empty_repo() {
let test_data = GitRepoData::create();
let mut tx = test_data.repo.start_transaction("test");
let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
let default_branch = git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
// No default branch and no refs
assert_eq!(default_branch, None);
assert_eq!(*tx.mut_repo().view().git_refs(), btreemap! {});
@ -606,7 +612,13 @@ fn test_fetch_initial_commit() {
let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]);
let mut tx = test_data.repo.start_transaction("test");
let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
let default_branch = git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
// No default branch because the origin repo's HEAD wasn't set
assert_eq!(default_branch, None);
let repo = tx.commit();
@ -637,7 +649,13 @@ fn test_fetch_success() {
let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]);
let mut tx = test_data.repo.start_transaction("test");
git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
test_data.repo = tx.commit();
test_data.origin_repo.set_head("refs/heads/main").unwrap();
@ -648,7 +666,13 @@ fn test_fetch_success() {
);
let mut tx = test_data.repo.start_transaction("test");
let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
let default_branch = git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
// The default branch is "main"
assert_eq!(default_branch, Some("main".to_string()));
let repo = tx.commit();
@ -679,7 +703,13 @@ fn test_fetch_prune_deleted_ref() {
empty_git_commit(&test_data.git_repo, "refs/heads/main", &[]);
let mut tx = test_data.repo.start_transaction("test");
git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
// Test the setup
assert!(tx.mut_repo().get_branch("main").is_some());
@ -690,7 +720,13 @@ fn test_fetch_prune_deleted_ref() {
.delete()
.unwrap();
// After re-fetching, the branch should be deleted
git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
assert!(tx.mut_repo().get_branch("main").is_none());
}
@ -700,7 +736,13 @@ fn test_fetch_no_default_branch() {
let initial_git_commit = empty_git_commit(&test_data.origin_repo, "refs/heads/main", &[]);
let mut tx = test_data.repo.start_transaction("test");
git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
empty_git_commit(
&test_data.origin_repo,
@ -715,7 +757,13 @@ fn test_fetch_no_default_branch() {
.set_head_detached(initial_git_commit.id())
.unwrap();
let default_branch = git::fetch(tx.mut_repo(), &test_data.git_repo, "origin", None).unwrap();
let default_branch = git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"origin",
git::RemoteCallbacks::default(),
)
.unwrap();
// There is no default branch
assert_eq!(default_branch, None);
}
@ -725,7 +773,12 @@ fn test_fetch_no_such_remote() {
let test_data = GitRepoData::create();
let mut tx = test_data.repo.start_transaction("test");
let result = git::fetch(tx.mut_repo(), &test_data.git_repo, "invalid-remote", None);
let result = git::fetch(
tx.mut_repo(),
&test_data.git_repo,
"invalid-remote",
git::RemoteCallbacks::default(),
);
assert!(matches!(result, Err(GitFetchError::NoSuchRemote(_))));
}

View file

@ -4104,14 +4104,11 @@ fn git_fetch(
progress.update(Instant::now(), x);
});
}
let result = git::fetch(
mut_repo,
git_repo,
remote_name,
callback
.as_mut()
.map(|x| x as &mut dyn FnMut(&git::Progress)),
);
let mut callbacks = git::RemoteCallbacks::default();
callbacks.progress = callback
.as_mut()
.map(|x| x as &mut dyn FnMut(&git::Progress));
let result = git::fetch(mut_repo, git_repo, remote_name, callbacks);
result
}