# remend/disassemble.py
# Author: udiboy1209 — commit 2b94efd ("Fix type in x64 disassembler")
from capstone import *
from capstone.arm import *
from capstone.arm64 import *
from capstone.x86 import *
import cle
import struct
from math import e as CONST_E, pi as CONST_PI
import sympy as sp
from .util import DecodeError
def int2fp32(v):
    """Reinterpret a 32-bit integer bit pattern as a little-endian float.

    Non-int values (already floats) are returned unchanged.
    """
    if isinstance(v, int):  # idiomatic type check (was ``type(v) == int``)
        v = struct.unpack("<f", v.to_bytes(4, "little"))[0]
    return v
def int2fp64(v):
    """Reinterpret a 64-bit integer bit pattern as a little-endian double.

    Non-int values (already floats) are returned unchanged.
    """
    if isinstance(v, int):  # idiomatic type check (was ``type(v) == int``)
        v = struct.unpack("<d", v.to_bytes(8, "little"))[0]
    return v
def align4(v):
    """Round *v* down to the nearest 4-byte boundary (32-bit mask)."""
    word_mask = 0xFFFFFFFC
    return v & word_mask
class DisassemblerBase:
    """Common machinery for the per-architecture disassemblers.

    Subclasses set ``self.loader`` (a ``cle.Loader``) and implement
    :meth:`disassemble`.  This base tracks floating-point constants
    recovered from the binary and maps them to symbolic names
    (``CONST=...`` for well-known values, ``CSYM...`` for the rest) so
    that numerically equal constants disassemble to identical text.
    """

    def __init__(self, expr_constants=None, match_constants=False):
        """
        :param expr_constants: optional mapping of constant names to float
            values that recovered constants may be matched against.
        :param match_constants: when True, unknown floats must match
            ``expr_constants`` (or be a simple rational/E/pi multiple);
            otherwise a DecodeError is raised for unmatched values.
        """
        self.loader = None       # set by the child class after cle.Loader()
        self.reg_values = {}     # register id -> last known immediate value
        self.constidx = 0        # next free CSYM index
        self.constants = {}      # symbol name/index -> float value
        self.constaddrs = set()  # addresses occupied by literal-pool data
        # Fix: the default was a shared mutable dict ({}) stored on the
        # instance; use a None sentinel so each instance gets its own dict.
        self.expr_constants = {} if expr_constants is None else expr_constants
        self.match_constants = match_constants

    def get_function_bytes(self, funcname):
        """Return ``(address, raw bytes)`` of *funcname*.

        :raises DecodeError: if the symbol is not present in the binary.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
            # ARM Thumb symbols carry bit 0 set; clear it to get the
            # real (halfword-aligned) code address.
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        # Remembered so branch targets can be classified as self-branches.
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def find_constant(self, constants, value):
        """Match *value* against *constants* up to sign and reciprocal.

        Returns ``(name, prefix)`` where prefix is one of ``""``, ``"1/"``,
        ``"-"``, ``"-1/"``; returns False when nothing matches within 1e-5.
        Note: *value* must be non-zero (callers handle 0 beforehand).
        """
        for ec in constants:
            if abs(value - constants[ec]) < 1e-5:
                return ec, ""
            elif abs(1/value - constants[ec]) < 1e-5:
                return ec, "1/"
            elif abs(-value - constants[ec]) < 1e-5:
                return ec, "-"
            elif abs(-1/value - constants[ec]) < 1e-5:
                return ec, "-1/"
        return False

    def add_constant(self, value, addr=0, size=0):
        """Register float *value* and return its symbolic name.

        :param addr: binary address the value was loaded from (0 if the
            value came from registers rather than a literal pool).
        :param size: byte size of the literal at *addr*; when > 0 the
            occupied addresses are recorded in ``constaddrs`` so the
            disassembler can skip them as data.
        :raises DecodeError: in match_constants mode, for unmatched floats.
        """
        # Don't map known constants like e, pi, 0
        if value == 0:
            cname = "CONST=0"
        elif abs(value - CONST_E) < 1e-7:
            cname = "CONST=E"
        elif abs(value - CONST_PI) < 1e-7:
            cname = "CONST=pi"
        elif self.match_constants and \
                (ecmatch := self.find_constant(self.expr_constants, value)):
            # Gives the name and expression (sign/reciprocal) of the match
            ecname, ecxpr = ecmatch
            cname = f"{ecxpr}CSYM{ecname[1:]}"
            self.constants[ecname] = value
        elif size > 0 and addr in self.constaddrs and \
                (smatch := self.find_constant(self.constants, value)):
            # Re-load of an already-seen literal: reuse its symbol.
            sname, sxpr = smatch
            cname = f"{sxpr}CSYM{sname}"
        else:
            # Simple rationals (small denominator) and E/pi multiples are
            # rendered literally; everything else gets a fresh CSYM index.
            rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
            if isinstance(rep, sp.Integer) or \
                    (isinstance(rep, sp.Rational) and rep.q <= 16):
                cname = f"CONST={rep}"
            elif not self.match_constants:
                cname = f"CSYM{self.constidx}"
                self.constants[self.constidx] = value
                self.constidx += 1
            else:
                raise DecodeError(f"Cannot represent unmatched float {value}")
        if size > 0:
            self.constaddrs |= {addr+i for i in range(size)}
        return cname

    def disassemble(self, function):
        """Abstract: subclasses return the symbolised disassembly string."""
        raise NotImplementedError("Call disassemble on child classes, not base")
class DisassemblerARM32(DisassemblerBase):
    """Disassembler for 32-bit ARM (Thumb) binaries.

    Tracks immediate register loads (mov/movw/movt/adr) so that float
    constants materialised through str/strd/ldrd/vldr can be recovered
    from registers or literal pools and replaced by symbolic names.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        # Fix: mutable default dict replaced by None sentinel (handled in base).
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
        self.md.detail = True  # needed for operand/register access below
        self.loader = cle.Loader(binpath)

    def check_mov_imm(self, insn):
        """Return ``(dest_reg, imm)`` if *insn* moves an immediate into a
        register (mov/movw/movt/adr), else False."""
        if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
                           ARM_INS_MOVT, ARM_INS_ADR}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2:
            return False
        if ops[0].type != ARM_OP_REG or ops[1].type != ARM_OP_IMM:
            return False
        imm = ops[1].value.imm
        if imm < 0:
            imm = 2**32 + imm  # 2's complement
        if insn.id == ARM_INS_ADR:
            # Add PC value (in Thumb, PC reads as insn address + 4)
            imm += insn.address + 4
        return ops[0].value.reg, imm

    def check_float_store(self, insn):
        """Return the float value being stored by a str/strd of tracked
        register values, else False.  Values outside [1e-3, 100] are
        rejected as unlikely to be real constants."""
        if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
            return False
        ops = list(insn.operands)
        if insn.id == ARM_INS_STRD:
            # Double stored as a (low, high) register pair.
            dest = ops[0].value.reg
            dest2 = ops[1].value.reg
            if dest not in self.reg_values or dest2 not in self.reg_values:
                return False
            fval = int2fp64((self.reg_values[dest2] << 32) + self.reg_values[dest])
        else:
            dest = ops[0].value.reg
            if dest not in self.reg_values:
                return False
            fval = int2fp32(self.reg_values[dest])
        if abs(fval) < 1e-3 or abs(fval) > 100:
            return False
        return fval

    def check_ldrd(self, insn):
        """Return ``(double_value, addr, 8)`` for an ldrd from a literal
        pool (PC-relative or via a tracked base register), else False."""
        if insn.id != ARM_INS_LDRD:
            return False
        ops = list(insn.operands)
        if len(ops) != 3:
            return False
        if ops[2].type != ARM_OP_MEM:
            return False
        mem = ops[2].value.mem
        if mem.base == ARM_REG_PC:
            # Thumb literal load: Align4(PC) + disp, where PC = addr + 4.
            addr = align4(insn.address + 4) + mem.disp
        elif mem.base in self.reg_values:
            addr = align4(self.reg_values[mem.base]) + mem.disp
        else:
            return False
        if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
            # Out of bounds
            return False
        fhex = self.loader.memory.load(addr, 8)
        fval = struct.unpack("d", fhex)[0]
        return fval, addr, 8

    def check_vldr(self, insn):
        """Return ``(value, addr, size)`` for a vldr literal load into an
        S (4-byte) or D (8-byte) register, else False."""
        if insn.id != ARM_INS_VLDR:
            return False
        ops = list(insn.operands)
        dest = ops[0]
        if ops[1].type != ARM_OP_MEM:
            return False
        mem = ops[1].value.mem
        if mem.base == ARM_REG_PC:
            # Align4(PC) + Imm
            # For whatever reason, in Thumb PC=addr+4
            addr = align4(insn.address + 4) + mem.disp
        elif mem.base in self.reg_values:
            addr = align4(self.reg_values[mem.base]) + mem.disp
        else:
            return False
        if addr < self.loader.min_addr or addr + 8 > self.loader.max_addr:
            # Out of bounds
            return False
        if dest.value.reg >= ARM_REG_D0 and dest.value.reg <= ARM_REG_D31:
            size = 8
            fhex = self.loader.memory.load(addr, 8)
            fval = struct.unpack("d", fhex)[0]
        else:
            size = 4
            fhex = self.loader.memory.load(addr, 4)
            fval = struct.unpack("f", fhex)[0]
        return fval, addr, size

    def check_vmov(self, insn):
        """Return ``(asm, value)`` for a float-immediate move
        (fconsts/fconstd, i.e. old names for vmov.f32/f64), else False."""
        if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2 or ops[1].type != ARM_OP_FP:
            return False
        fval = ops[1].value.fp
        destname = insn.reg_name(ops[0].value.reg)
        asm = f"{insn.mnemonic} {destname}, {fval}"
        return asm, fval

    def check_branch_symbol(self, insn):
        """Return symbolised asm for b/bl/blx to a PLT stub or a
        self-branch, else False."""
        if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
            return False
        ops = list(insn.operands)
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        addr = ops[0].value.imm
        if addr > self.funcrange[0] and addr < self.funcrange[1]:
            # Self-branch
            func = f"SELF+{hex(addr - self.funcrange[0])}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # Some tail call optimized PLT stubs have extra instructions
                # that are not identified by CLE, so check with offset of 4 also.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def get_function_bytes(self, funcname):
        """ARM-specific variant of the base lookup: always clears the Thumb
        bit (bit 0) from the symbol address."""
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if faddr % 2 == 1:
            # Unaligned address, aligning
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def disassemble(self, funcname):
        """Disassemble *funcname*, replacing recovered float constants and
        PLT branch targets with symbolic names.  Returns a single
        '; '-joined string of instructions."""
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []
        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # Skip if this is a constant value and not instruction
                continue
            cname = None
            asm = None
            if vldr := self.check_vldr(insn):
                fval, faddr, fsize = vldr
                cname = self.add_constant(fval, faddr, fsize)
            elif ldrd := self.check_ldrd(insn):
                fval, faddr, fsize = ldrd
                cname = self.add_constant(fval, faddr, fsize)
            elif strfloat := self.check_float_store(insn):
                fval = strfloat
                cname = self.add_constant(fval)
            elif vmovfloat := self.check_vmov(insn):
                asm, fval = vmovfloat
                cname = self.add_constant(fval)
            elif branch := self.check_branch_symbol(insn):
                asm = branch
            # Maintain values of immediate moves.
            # Needs to be done after processing current instruction.
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                if insn.id == ARM_INS_MOVT:
                    # movt writes the high halfword on top of any prior movw.
                    if reg not in self.reg_values:
                        self.reg_values[reg] = 0
                    self.reg_values[reg] += imm << 16
                else:
                    self.reg_values[reg] = imm
            else:
                reads, writes = insn.regs_access()
                for r in writes:
                    # Remove this reg if written to
                    if r in self.reg_values:
                        del self.reg_values[r]
            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)
        fulldiss = "; ".join(disassm)
        return fulldiss
class DisassemblerAArch64(DisassemblerBase):
    """Disassembler for AArch64 binaries.

    Tracks immediate register loads (adrp/adr/mov/movk) so float constants
    materialised through fmov or ldr can be recovered and symbolised.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        # Fix: mutable default dict replaced by None sentinel (handled in base).
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
        self.md.detail = True  # needed for operand/register access below
        self.loader = cle.Loader(binpath)

    def reg_size_type(self, reg):
        """Return ``(bit_width, datatype)`` of a register id; (0, None) for
        registers outside the W/X/S/D files."""
        if reg >= ARM64_REG_W0 and reg <= ARM64_REG_W30:
            return 32, int
        elif reg >= ARM64_REG_X0 and reg <= ARM64_REG_X30:
            return 64, int
        elif reg >= ARM64_REG_S0 and reg <= ARM64_REG_S31:
            return 32, float
        elif reg >= ARM64_REG_D0 and reg <= ARM64_REG_D31:
            return 64, float
        return 0, None

    def check_mov_imm(self, insn):
        """Return ``(dest_reg, imm)`` if *insn* loads an immediate
        (adrp/adr/mov/movk), else False."""
        if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR, ARM64_INS_MOV, ARM64_INS_MOVK}:
            return False
        ops = insn.operands
        if len(ops) != 2:
            return False
        if ops[0].type != ARM64_OP_REG or ops[1].type != ARM64_OP_IMM:
            return False
        imm = ops[1].value.imm
        # Fix: mask must exist even for unshifted movk (previously it was
        # only assigned inside the LSL branch, raising NameError below).
        mask = 0xFFFF
        if ops[1].shift.type == 1:  # LSL
            imm <<= ops[1].shift.value
            mask = 0xFFFF << ops[1].shift.value
        if insn.id == ARM64_INS_ADRP:
            # imm -= 0x400000 # Subtract global offset for some reason
            # imm = ((insn.address + 4) & (~4095)) + imm
            # Really confused about this, maybe I can use the imm directly
            pass
        elif insn.id == ARM64_INS_ADR:
            imm -= 0x400000  # Subtract global offset for some reason
            imm += insn.address + 4
        elif insn.id == ARM64_INS_MOVK:
            # movk replaces one 16-bit field, keeping the rest of the reg.
            if ops[0].value.reg in self.reg_values:
                curr = self.reg_values[ops[0].value.reg]
                imm = (imm & mask) | (curr & (~mask))
        return ops[0].value.reg, imm

    def check_fmov(self, insn):
        """Return ``(asm, value)`` for an fmov of a float immediate or of a
        tracked integer register's bit pattern, else False."""
        if insn.id != ARM64_INS_FMOV:
            return False
        ops = insn.operands
        if len(ops) != 2:
            return False
        destsize, _ = self.reg_size_type(ops[0].value.reg)
        destname = insn.reg_name(ops[0].value.reg)
        if ops[1].type == ARM64_OP_FP:
            fval = ops[1].value.fp
            asm = f"{insn.mnemonic} {destname}, {fval}"
        elif ops[1].type == ARM64_OP_REG:
            reg = ops[1].value.reg
            if reg not in self.reg_values:
                return False
            # TODO datatype
            fhex = self.reg_values[reg]
            if destsize == 64:
                if fhex < 0:
                    fhex += 2**64
                fval = int2fp64(fhex)
            elif destsize == 32:
                if fhex < 0:
                    fhex += 2**32
                fval = int2fp32(fhex)
            else:
                return False
            # Reject implausible magnitudes (likely not a real constant).
            if abs(fval) < 1e-5 or abs(fval) > 1e5:
                return False
            asm = f"{insn.mnemonic} {insn.op_str}"
        else:
            # Fix: previously fell through with fval/asm unbound -> NameError.
            return False
        return asm, fval

    def check_ldr(self, insn):
        """Return ``(value, addr, size)`` for an ldr into an S/D register
        from a tracked base register + offset, else False."""
        if insn.id != ARM64_INS_LDR:
            return False
        # Parse the textual operands; strip the trailing ']'.
        ops = insn.op_str[:-1].split(", ")
        destsize, desttype = self.reg_size_type(insn.operands[0].value.reg)
        if len(ops) < 2 or desttype != float:
            return False
        reg = ops[1]
        if reg[0] != "[" or "sp" in reg:
            return False
        basereg = ARM64_REG_X0 + int(reg[2:])  # Shitty hack, may malfunction
        if basereg not in self.reg_values:
            return False
        base = align4(self.reg_values[basereg])
        if len(ops) == 3:
            offset = ops[2][1:]  # drop the '#' prefix
            if offset.startswith("0x"):
                offset = int(offset[2:], base=16)
            else:
                offset = int(offset)
        else:
            offset = 0
        addr = base + offset
        if destsize == 64:
            fhex = self.loader.memory.load(addr, 8)
            fval = struct.unpack("d", fhex)[0]
            return fval, addr, 8
        elif destsize == 32:
            fhex = self.loader.memory.load(addr, 4)
            fval = struct.unpack("f", fhex)[0]
            return fval, addr, 4
        else:
            return False

    def check_branch_symbol(self, insn):
        """Return symbolised asm for b/bl to a PLT stub or a self-branch,
        else False."""
        if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
            return False
        ops = insn.operands
        # Fix: was ARM_OP_IMM (ARM32 constant); use the AArch64 constant.
        if len(ops) != 1 or ops[0].type != ARM64_OP_IMM:
            return False
        addr = ops[0].value.imm
        if addr > self.funcrange[0] and addr < self.funcrange[1]:
            # Self-branch
            func = f"SELF+{hex(addr - self.funcrange[0])}"
        else:
            func = self.loader.find_plt_stub_name(addr)
            if func is None:
                # Some tail call optimized PLT stubs have extra instructions
                # that are not identified by CLE, so check with offset of 4 also.
                func = self.loader.find_plt_stub_name(addr + 4)
            if func is None:
                return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def disassemble(self, funcname):
        """Disassemble *funcname*, replacing recovered float constants and
        PLT branch targets with symbolic names.  Returns a single
        '; '-joined string of instructions."""
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []
        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                # Skip if this is a constant value and not instruction
                continue
            cname = None
            asm = None
            # Maintain values of immediate moves
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                self.reg_values[reg] = imm
            else:
                reads, writes = insn.regs_access()
                for r in writes:
                    # Remove this reg if written to
                    if r in self.reg_values:
                        del self.reg_values[r]
            if fmov := self.check_fmov(insn):
                asm, fval = fmov
                cname = self.add_constant(fval)
            elif ldr := self.check_ldr(insn):
                fval, faddr, fsize = ldr
                cname = self.add_constant(fval, faddr, fsize)
            elif branch := self.check_branch_symbol(insn):
                asm = branch
            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)
        fulldiss = "; ".join(disassm)
        return fulldiss
class DisassemblerX64(DisassemblerBase):
    """Disassembler for x86-64 binaries.

    Recovers float/double constants from RIP-relative memory operands and
    symbolises direct PLT calls.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        # Fix: mutable default dict replaced by None sentinel (handled in base).
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_X86, CS_MODE_64)
        self.md.detail = True  # needed for operand access below
        self.loader = cle.Loader(binpath)

    def check_call_symbol(self, insn):
        """Return 'call <name>' for a direct call to a PLT stub, else False."""
        if insn.id != X86_INS_CALL:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != X86_OP_IMM:
            return False
        addr = ops[0].value.imm
        func = self.loader.find_plt_stub_name(addr)
        if func is None:
            return False
        asm = f"{insn.mnemonic} <{func}>"
        return asm

    def check_fload(self, insn):
        """Return ``(value, addr, size)`` for a RIP-relative float (4B) or
        double (8B) memory operand, else False."""
        # Cannot rely on the instruction ID because any x86 instruction
        # can take a memory operand.
        ops = insn.operands
        memops = [op for op in ops
                  if (op.type == X86_OP_MEM and
                      op.value.mem.base == X86_REG_RIP)]
        if len(memops) != 1:
            return False
        mem, size = memops[0].value.mem, memops[0].size
        # Fix: only 4-byte floats and 8-byte doubles are decodable.  The old
        # guard (`size > 8`) let 1/2-byte operands through, crashing
        # struct.unpack with a length mismatch.
        if size not in (4, 8):
            return False
        # RIP-relative: effective address = next instruction address + disp.
        addr = insn.address + insn.size + mem.disp
        fhex = self.loader.memory.load(addr, size)
        fval = struct.unpack("f" if size == 4 else "d", fhex)[0]
        return fval, addr, size

    def disassemble(self, funcname):
        """Disassemble *funcname*, replacing recovered float constants and
        PLT call targets with symbolic names.  Returns a single
        '; '-joined string of instructions."""
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        disassm = []
        for insn in self.md.disasm(funcbytes, funcaddr):
            asm = None
            cname = None
            if fload := self.check_fload(insn):
                fval, faddr, fsize = fload
                cname = self.add_constant(fval, faddr, fsize)
            elif call := self.check_call_symbol(insn):
                asm = call
            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            disassm.append(asm)
        fulldiss = "; ".join(disassm)
        return fulldiss
# Command-line entry point: dump the preprocessed disassembly of one
# function from a binary, plus the table of recovered constants.
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Pre-process assembly to replace constants and dump")
    parser.add_argument("--bin", required=True)
    parser.add_argument("--func", required=True)
    # Fix: choices= makes argparse reject unknown architectures with a clear
    # error instead of the NameError the old if/elif chain produced when
    # `D` was left unbound.
    parser.add_argument("--arch", required=True,
                        choices=["arm32", "aarch64", "x64"])
    args = parser.parse_args()
    if args.arch == "arm32":
        D = DisassemblerARM32(args.bin)
    elif args.arch == "aarch64":
        D = DisassemblerAArch64(args.bin)
    else:
        D = DisassemblerX64(args.bin)
    diss = D.disassemble(args.func)
    print(diss)
    print(D.constants)