Replace direct register accesses with abstraction

Reviewed By: wangbj

Differential Revision: D40577287

fbshipit-source-id: 8fe94d172ffb0f1d7c2e177ec6314b1b77f4e5c1
This commit is contained in:
Jason White 2022-10-21 12:09:36 -07:00 committed by Facebook GitHub Bot
parent 6ad5309bb1
commit 8862833da4
2 changed files with 76 additions and 64 deletions

View file

@ -20,3 +20,12 @@ pub const TRAMPOLINE_SIZE: usize = 0x1000;
/// total private page size /// total private page size
pub const PRIVATE_PAGE_SIZE: usize = TRAMPOLINE_SIZE; pub const PRIVATE_PAGE_SIZE: usize = TRAMPOLINE_SIZE;
/// The size of a breakpoint instruction. On x86_64, this is just 0xcc, which is
/// one byte.
#[cfg(target_arch = "x86_64")]
pub const BREAKPOINT_SIZE: usize = 1;
/// The size of a breakpoint instruction. On aarch64, this is 4 bytes.
#[cfg(target_arch = "aarch64")]
pub const BREAKPOINT_SIZE: usize = 4;

View file

@ -72,6 +72,8 @@ use tracing::info;
use tracing::trace; use tracing::trace;
use tracing::warn; use tracing::warn;
use super::regs::Reg;
use super::regs::RegAccess;
use crate::children; use crate::children;
use crate::cp; use crate::cp;
use crate::error::Error; use crate::error::Error;
@ -512,27 +514,28 @@ impl<L: Tool> TracedTask<L> {
fn get_syscall(&self, task: &Stopped) -> Result<Syscall, TraceError> { fn get_syscall(&self, task: &Stopped) -> Result<Syscall, TraceError> {
let regs = task.getregs()?; let regs = task.getregs()?;
let nr = Sysno::from(regs.orig_rax as i32); let nr = Sysno::from(regs.orig_syscall() as i32);
let args = SyscallArgs::new(
regs.rdi as usize, let args = regs.args();
regs.rsi as usize,
regs.rdx as usize, Ok(Syscall::from_raw(
regs.r10 as usize, nr,
regs.r8 as usize, SyscallArgs::new(
regs.r9 as usize, args.0 as usize,
); args.1 as usize,
trace!( args.2 as usize,
"[retrieve_task_state] translating ptrace event SECCOMP into syscall {}", args.3 as usize,
nr args.4 as usize,
); args.5 as usize,
Ok(Syscall::from_raw(nr, args)) ),
))
} }
} }
fn set_rax(task: &Stopped, rax: u64) -> Result<u64, TraceError> { fn set_ret(task: &Stopped, ret: Reg) -> Result<Reg, TraceError> {
let mut regs = task.getregs()?; let mut regs = task.getregs()?;
let old = regs.rax; let old = regs.ret();
regs.rax = rax; *regs.ret_mut() = ret;
task.setregs(regs)?; task.setregs(regs)?;
Ok(old) Ok(old)
} }
@ -581,28 +584,25 @@ fn decode_segfault(insn_at_rip: u64) -> Option<SegfaultTrapInfo> {
fn restore_context( fn restore_context(
task: &Stopped, task: &Stopped,
context: libc::user_regs_struct, context: libc::user_regs_struct,
rax: Option<u64>, retval: Option<Reg>,
) -> Result<(), TraceError> { ) -> Result<(), TraceError> {
let mut regs = task.getregs()?; let mut regs = task.getregs()?;
if let Some(rax) = rax { if let Some(ret) = retval {
regs.rax = rax; *regs.ret_mut() = ret;
} }
regs.rip = context.rip; // Restore instruction pointer.
*regs.ip_mut() = context.ip();
regs.rdi = context.rdi; // Restore syscall arguments.
regs.rsi = context.rsi; regs.set_args(context.args());
regs.rdx = context.rdx;
regs.r10 = context.r10;
regs.r8 = context.r8;
regs.r9 = context.r9;
// This is needed when syscall is interrupted by a signal (ERESTARTSYS) // This is needed when syscall is interrupted by a signal (ERESTARTSYS)
// we need restore the original syscall number as well because it is // we need restore the original syscall number as well because it is
// possible syscall is reinjected as a different variant, like vfork -> // possible syscall is reinjected as a different variant, like vfork ->
// clone, which accepts different arguments. // clone, which accepts different arguments.
regs.orig_rax = context.orig_rax; *regs.orig_syscall_mut() = context.orig_syscall();
// NB: syscall also clobbers %rcx/%r11, but we're not required to restore // NB: syscall also clobbers %rcx/%r11, but we're not required to restore
// them, because the syscall is finished and they're supposed to change. // them, because the syscall is finished and they're supposed to change.
@ -651,14 +651,16 @@ impl<L: Tool + 'static> TracedTask<L> {
let page_addr = cp::PRIVATE_PAGE_OFFSET; let page_addr = cp::PRIVATE_PAGE_OFFSET;
regs.orig_rax = Sysno::mmap as u64; *regs.syscall_mut() = Sysno::mmap as Reg;
regs.rax = regs.orig_rax; *regs.orig_syscall_mut() = regs.syscall();
regs.rdi = page_addr; regs.set_args((
regs.rsi = cp::PRIVATE_PAGE_SIZE as u64; page_addr,
regs.rdx = (libc::PROT_READ | libc::PROT_WRITE | libc::PROT_EXEC) as u64; cp::PRIVATE_PAGE_SIZE as Reg,
regs.r10 = (libc::MAP_PRIVATE | libc::MAP_FIXED | libc::MAP_ANONYMOUS) as u64; (libc::PROT_READ | libc::PROT_WRITE | libc::PROT_EXEC) as Reg,
regs.r8 = -1i64 as u64; (libc::MAP_PRIVATE | libc::MAP_FIXED | libc::MAP_ANONYMOUS) as Reg,
regs.r9 = 0u64; -1i64 as Reg,
0,
));
task.setregs(regs)?; task.setregs(regs)?;
// Execute the injected mmap call. // Execute the injected mmap call.
@ -691,7 +693,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// Make sure we got our desired address. // Make sure we got our desired address.
assert_eq!( assert_eq!(
Errno::from_ret(task.getregs()?.rax as usize)? as u64, Errno::from_ret(task.getregs()?.ret() as usize)? as u64,
page_addr, page_addr,
"Could not mmap address {}", "Could not mmap address {}",
page_addr page_addr
@ -699,7 +701,7 @@ impl<L: Tool + 'static> TracedTask<L> {
cp::populate_mmap_page(task.pid().into(), page_addr).map_err(|err| err)?; cp::populate_mmap_page(task.pid().into(), page_addr).map_err(|err| err)?;
saved_regs.rip -= 1; // bp size *saved_regs.ip_mut() -= cp::BREAKPOINT_SIZE as Reg;
task.setregs(saved_regs)?; task.setregs(saved_regs)?;
Ok(task) Ok(task)
} }
@ -715,10 +717,10 @@ impl<L: Tool + 'static> TracedTask<L> {
let regs = task.getregs()?; let regs = task.getregs()?;
// Saved instruction memory // Saved instruction memory
let rip = AddrMut::from_raw(regs.rip as usize).unwrap(); let ip = AddrMut::from_raw(regs.ip() as usize).unwrap();
let saved: u64 = task.read_value(rip)?; let saved: u64 = task.read_value(ip)?;
// Patch the tracee at the current instruction pointer. // Patch the tracee at the current instruction pointer.
task.write_value(rip, &((saved & !(0xffffffff_u64)) | bp_syscall_bp))?; task.write_value(ip, &((saved & !(0xffffffff_u64)) | bp_syscall_bp))?;
// When resumed, the tracee will hit the first breakpoint. Then we // When resumed, the tracee will hit the first breakpoint. Then we
// wait for it to reach that breakpoint and trap/stop. // wait for it to reach that breakpoint and trap/stop.
@ -739,7 +741,7 @@ impl<L: Tool + 'static> TracedTask<L> {
saved: u64, saved: u64,
) -> Result<Stopped, TraceError> { ) -> Result<Stopped, TraceError> {
// Restore what we dirtied: // Restore what we dirtied:
task.write_value(AddrMut::from_raw(regs.rip as usize).unwrap(), &saved)?; task.write_value(AddrMut::from_raw(regs.ip() as usize).unwrap(), &saved)?;
task.setregs(regs)?; task.setregs(regs)?;
Ok(task) Ok(task)
} }
@ -863,10 +865,10 @@ impl<L: Tool + 'static> TracedTask<L> {
.resumed_by_gdb .resumed_by_gdb
.map_or(false, |action| matches!(action, ResumeAction::Step(_))); .map_or(false, |action| matches!(action, ResumeAction::Step(_)));
let mut regs = task.getregs()?; let mut regs = task.getregs()?;
let rip_minus_one = regs.rip - 1; let rip_minus_one = regs.ip() - 1;
Ok(if self.breakpoints.contains_key(&rip_minus_one) { Ok(if self.breakpoints.contains_key(&rip_minus_one) {
regs.rip = rip_minus_one; *regs.ip_mut() = rip_minus_one;
let next_state = self.resume_from_swbreak(task, regs).await?; let next_state = self.resume_from_swbreak(task, regs).await?;
HandleSignalResult::SignalSuppressed(next_state) HandleSignalResult::SignalSuppressed(next_state)
} else if resumed_by_gdb_step { } else if resumed_by_gdb_step {
@ -1043,7 +1045,7 @@ impl<L: Tool + 'static> TracedTask<L> {
Err(err) => (-(err.into_errno()?.into_raw() as i64)) as u64, Err(err) => (-(err.into_errno()?.into_raw() as i64)) as u64,
}; };
set_rax(&task, ret)?; set_ret(&task, ret)?;
} }
// Finally, resume the guest. // Finally, resume the guest.
@ -1448,7 +1450,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// our patched syscall on the first run. Please note after calling this // our patched syscall on the first run. Please note after calling this
// function, the task state will no longer be in ptrace event seccomp. // function, the task state will no longer be in ptrace event seccomp.
let mut new_regs = regs; let mut new_regs = regs;
new_regs.orig_rax = -1i64 as u64; *new_regs.orig_syscall_mut() = -1i64 as u64;
task.setregs(new_regs)?; task.setregs(new_regs)?;
let mut running = task.step(None)?; let mut running = task.step(None)?;
@ -1504,20 +1506,21 @@ impl<L: Tool + 'static> TracedTask<L> {
let oldregs = regs; let oldregs = regs;
let no = nr as u64; *regs.syscall_mut() = nr as Reg;
regs.orig_rax = no; *regs.orig_syscall_mut() = nr as Reg;
regs.rax = no; regs.set_args((
regs.rdi = args.arg0 as u64; args.arg0 as Reg,
regs.rsi = args.arg1 as u64; args.arg1 as Reg,
regs.rdx = args.arg2 as u64; args.arg2 as Reg,
regs.r10 = args.arg3 as u64; args.arg3 as Reg,
regs.r8 = args.arg4 as u64; args.arg4 as Reg,
regs.r9 = args.arg5 as u64; args.arg5 as Reg,
));
// instruction at PRIVATE_PAGE_OFFSET, see `populate_mmap_page`. // instruction at PRIVATE_PAGE_OFFSET, see `populate_mmap_page`.
// 7000_0000: 0f 05 syscall // 7000_0000: 0f 05 syscall
// 7000_0002: 0f 0b ud2 // 7000_0002: 0f 0b ud2
regs.rip = cp::PRIVATE_PAGE_OFFSET; *regs.ip_mut() = cp::PRIVATE_PAGE_OFFSET;
task.setregs(regs)?; task.setregs(regs)?;
@ -1547,7 +1550,7 @@ impl<L: Tool + 'static> TracedTask<L> {
Wait::Stopped(stopped, event) => match event { Wait::Stopped(stopped, event) => match event {
Event::Signal(_sig) if context.is_none() => { Event::Signal(_sig) if context.is_none() => {
let regs = stopped.getregs()?; let regs = stopped.getregs()?;
Ok(Ok(regs.rax as i64)) Ok(Ok(regs.ret() as i64))
} }
Event::Signal(sig) => { Event::Signal(sig) => {
let mut regs = stopped.getregs()?; let mut regs = stopped.getregs()?;
@ -1555,13 +1558,13 @@ impl<L: Tool + 'static> TracedTask<L> {
// SIGCHLD) before single step finishes (in that case rip == // SIGCHLD) before single step finishes (in that case rip ==
// 0x7000_0000u64). // 0x7000_0000u64).
debug_assert!( debug_assert!(
regs.rip == cp::PRIVATE_PAGE_OFFSET + 0x2 regs.ip() == cp::PRIVATE_PAGE_OFFSET + 0x2
|| regs.rip == cp::PRIVATE_PAGE_OFFSET || regs.ip() == cp::PRIVATE_PAGE_OFFSET
); );
// interrupted by signal, return -ERESTARTSYS so that tracee can do a // interrupted by signal, return -ERESTARTSYS so that tracee can do a
// restart_syscall. // restart_syscall.
if sig != Signal::SIGTRAP { if sig != Signal::SIGTRAP {
regs.rax = (-(Errno::ERESTARTSYS.into_raw()) as i64) as u64; *regs.ret_mut() = (-(Errno::ERESTARTSYS.into_raw()) as i64) as u64;
self.pending_signal = Some(sig); self.pending_signal = Some(sig);
} }
if let Some(context) = context { if let Some(context) = context {
@ -1571,7 +1574,7 @@ impl<L: Tool + 'static> TracedTask<L> {
// it back. // it back.
restore_context(&stopped, context, None)?; restore_context(&stopped, context, None)?;
} }
Ok(Errno::from_ret(regs.rax as usize).map(|x| x as i64)) Ok(Errno::from_ret(regs.ret() as usize).map(|x| x as i64))
} }
Event::NewChild(op, child) => { Event::NewChild(op, child) => {
let ret = child.pid().as_raw() as i64; let ret = child.pid().as_raw() as i64;
@ -1585,7 +1588,7 @@ impl<L: Tool + 'static> TracedTask<L> {
} }
Event::Syscall => { Event::Syscall => {
let regs = stopped.getregs()?; let regs = stopped.getregs()?;
Ok(Errno::from_ret(regs.rax as usize).map(|x| x as i64)) Ok(Errno::from_ret(regs.ret() as usize).map(|x| x as i64))
} }
st => panic!("untraced_syscall returned unknown state: {:?}", st), st => panic!("untraced_syscall returned unknown state: {:?}", st),
}, },
@ -1890,9 +1893,9 @@ impl<L: Tool + 'static> TracedTask<L> {
match wait { match wait {
Wait::Stopped(task, event) if event == Event::Signal(Signal::SIGTRAP) => { Wait::Stopped(task, event) if event == Event::Signal(Signal::SIGTRAP) => {
let mut regs = task.getregs()?; let mut regs = task.getregs()?;
let rip_minus_one = regs.rip - 1; let rip_minus_one = regs.ip() - 1;
if self.breakpoints.contains_key(&rip_minus_one) { if self.breakpoints.contains_key(&rip_minus_one) {
regs.rip = rip_minus_one; *regs.ip_mut() = rip_minus_one;
self.resume_from_swbreak(task, regs).await self.resume_from_swbreak(task, regs).await
} else { } else {
Ok(Wait::Stopped(task, event)) Ok(Wait::Stopped(task, event))