safeptrace: Remove unnecessary copying of registers

Summary: Might lead to a minor speed up, but probably not due to the magic of compiler optimizations. I noticed this a while back and meant to fix it before publishing the `safeptrace` crate (so this won't be a breaking change later).

Reviewed By: VladimirMakaev

Differential Revision: D41699769

fbshipit-source-id: ef3e5f24468ddb4b69b8628e28b1da8cfdcbce7b
This commit is contained in:
Jason White 2022-12-07 07:38:13 -08:00 committed by Facebook GitHub Bot
parent 866529e6e8
commit bcce15d17d
2 changed files with 20 additions and 20 deletions

View file

@ -535,7 +535,7 @@ fn set_ret(task: &Stopped, ret: Reg) -> Result<Reg, TraceError> {
let mut regs = task.getregs()?; let mut regs = task.getregs()?;
let old = regs.ret(); let old = regs.ret();
*regs.ret_mut() = ret; *regs.ret_mut() = ret;
task.setregs(regs)?; task.setregs(&regs)?;
Ok(old) Ok(old)
} }
@ -611,7 +611,7 @@ fn restore_context(
// them, because the syscall is finished and they're supposed to change. // them, because the syscall is finished and they're supposed to change.
// TL&DR: do not restore %rcx/%r11 here. // TL&DR: do not restore %rcx/%r11 here.
task.setregs(regs) task.setregs(&regs)
} }
impl<L: Tool + 'static> TracedTask<L> { impl<L: Tool + 'static> TracedTask<L> {
@ -671,7 +671,7 @@ impl<L: Tool + 'static> TracedTask<L> {
0, 0,
)); ));
task.setregs(regs)?; task.setregs(&regs)?;
// Execute the injected mmap call. // Execute the injected mmap call.
let mut running = task.step(None)?; let mut running = task.step(None)?;
@ -713,7 +713,7 @@ impl<L: Tool + 'static> TracedTask<L> {
cp::populate_mmap_page(task.pid().into(), page_addr).map_err(|err| err)?; cp::populate_mmap_page(task.pid().into(), page_addr).map_err(|err| err)?;
// Restore our saved registers, including our instruction pointer. // Restore our saved registers, including our instruction pointer.
task.setregs(*saved_regs)?; task.setregs(saved_regs)?;
Ok(task) Ok(task)
} }
@ -767,7 +767,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// PTRACE_POKEDATA. // PTRACE_POKEDATA.
let ip = AddrMut::from_raw(regs.ip() as usize).unwrap(); let ip = AddrMut::from_raw(regs.ip() as usize).unwrap();
task.write_value(ip, &saved)?; task.write_value(ip, &saved)?;
task.setregs(regs)?; task.setregs(&regs)?;
Ok(()) Ok(())
} }
@ -806,7 +806,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// Restore registers again after we've injected syscalls so that we // Restore registers again after we've injected syscalls so that we
// don't leave the return value register (%rax) in a dirty state. // don't leave the return value register (%rax) in a dirty state.
task.setregs(regs)?; task.setregs(&regs)?;
Ok(task) Ok(task)
} }
@ -973,12 +973,12 @@ impl<L: Tool + 'static> TracedTask<L> {
Ok(match trap_info { Ok(match trap_info {
Some(SegfaultTrapInfo::Cpuid) => { Some(SegfaultTrapInfo::Cpuid) => {
let regs = self.handle_cpuid(regs).await?; let regs = self.handle_cpuid(regs).await?;
task.setregs(regs)?; task.setregs(&regs)?;
HandleSignalResult::SignalSuppressed(task.resume(None)?.next_state().await?) HandleSignalResult::SignalSuppressed(task.resume(None)?.next_state().await?)
} }
Some(SegfaultTrapInfo::Rdtscs(req)) => { Some(SegfaultTrapInfo::Rdtscs(req)) => {
let regs = self.handle_rdtscs(regs, req).await?; let regs = self.handle_rdtscs(regs, req).await?;
task.setregs(regs)?; task.setregs(&regs)?;
HandleSignalResult::SignalSuppressed(task.resume(None)?.next_state().await?) HandleSignalResult::SignalSuppressed(task.resume(None)?.next_state().await?)
} }
None => HandleSignalResult::SignalToDeliver(task, Signal::SIGSEGV), None => HandleSignalResult::SignalToDeliver(task, Signal::SIGSEGV),
@ -1519,7 +1519,7 @@ impl<L: Tool + 'static> TracedTask<L> {
{ {
let mut new_regs = regs; let mut new_regs = regs;
*new_regs.orig_syscall_mut() = -1i64 as u64; *new_regs.orig_syscall_mut() = -1i64 as u64;
task.setregs(new_regs)?; task.setregs(&new_regs)?;
} }
#[cfg(target_arch = "aarch64")] #[cfg(target_arch = "aarch64")]
@ -1534,7 +1534,7 @@ impl<L: Tool + 'static> TracedTask<L> {
match running.next_state().await? { match running.next_state().await? {
Wait::Stopped(task, Event::Signal(Signal::SIGTRAP)) => { Wait::Stopped(task, Event::Signal(Signal::SIGTRAP)) => {
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
task.setregs(regs)?; task.setregs(&regs)?;
break Ok(task); break Ok(task);
} }
Wait::Stopped(task, Event::Signal(sig)) => { Wait::Stopped(task, Event::Signal(sig)) => {
@ -1595,7 +1595,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// `populate_mmap_page` for details. // `populate_mmap_page` for details.
*regs.ip_mut() = cp::PRIVATE_PAGE_OFFSET as Reg; *regs.ip_mut() = cp::PRIVATE_PAGE_OFFSET as Reg;
task.setregs(regs)?; task.setregs(&regs)?;
// Step to run the syscall instruction. // Step to run the syscall instruction.
let wait = task.step(None)?.next_state().await?; let wait = task.step(None)?.next_state().await?;
@ -1878,7 +1878,7 @@ impl<L: Tool + 'static> TracedTask<L> {
task: Stopped, task: Stopped,
regs: libc::user_regs_struct, regs: libc::user_regs_struct,
) -> Result<Wait, TraceError> { ) -> Result<Wait, TraceError> {
task.setregs(regs)?; task.setregs(&regs)?;
// Task could be hitting a breakpoint, after previously suspended by // Task could be hitting a breakpoint, after previously suspended by
// a different task, need to notify this task is fully stopped. // a different task, need to notify this task is fully stopped.
@ -2093,8 +2093,8 @@ impl<L: Tool + 'static> TracedTask<L> {
fn write_registers(&self, core_regs: CoreRegs) -> Result<(), TraceError> { fn write_registers(&self, core_regs: CoreRegs) -> Result<(), TraceError> {
let task = self.assume_stopped(); let task = self.assume_stopped();
let (regs, fpregs) = core_regs.into_parts(); let (regs, fpregs) = core_regs.into_parts();
task.setregs(regs)?; task.setregs(&regs)?;
task.setfpregs(fpregs)?; task.setfpregs(&fpregs)?;
Ok(()) Ok(())
} }
} }

View file

@ -538,10 +538,10 @@ impl Stopped {
Ok(unsafe { regs.assume_init() }) Ok(unsafe { regs.assume_init() })
} }
fn setregset<T>(&self, which: i32, regs: T) -> Result<(), Error> { fn setregset<T>(&self, which: i32, regs: &T) -> Result<(), Error> {
let iov = libc::iovec { let iov = libc::iovec {
iov_base: &regs as *const _ as *mut _, iov_base: regs as *const _ as *mut _,
iov_len: core::mem::size_of_val(&regs), iov_len: core::mem::size_of::<T>(),
}; };
unsafe { unsafe {
@ -566,7 +566,7 @@ impl Stopped {
} }
/// Sets the general purpose registers. /// Sets the general purpose registers.
pub fn setregs(&self, regs: Regs) -> Result<(), Error> { pub fn setregs(&self, regs: &Regs) -> Result<(), Error> {
self.setregset(libc::NT_PRSTATUS, regs) self.setregset(libc::NT_PRSTATUS, regs)
} }
@ -576,7 +576,7 @@ impl Stopped {
} }
/// Sets the floating point registers. /// Sets the floating point registers.
pub fn setfpregs(&self, regs: FpRegs) -> Result<(), Error> { pub fn setfpregs(&self, regs: &FpRegs) -> Result<(), Error> {
self.setregset(libc::NT_PRFPREG, regs) self.setregset(libc::NT_PRFPREG, regs)
} }
@ -614,7 +614,7 @@ impl Stopped {
#[cfg(target_arch = "aarch64")] #[cfg(target_arch = "aarch64")]
pub fn set_syscall(&self, nr: i32) -> Result<(), Error> { pub fn set_syscall(&self, nr: i32) -> Result<(), Error> {
const NT_ARM_SYSTEM_CALL: i32 = 0x404; const NT_ARM_SYSTEM_CALL: i32 = 0x404;
self.setregset(NT_ARM_SYSTEM_CALL, nr) self.setregset(NT_ARM_SYSTEM_CALL, &nr)
} }
/// Gets info about the signal that caused the process to be stopped. /// Gets info about the signal that caused the process to be stopped.