|
from capstone import * |
|
from capstone.arm import * |
|
from capstone.arm64 import * |
|
from capstone.x86 import * |
|
import cle |
|
import struct |
|
from math import e as CONST_E, pi as CONST_PI |
|
import sympy as sp |
|
|
|
from .util import DecodeError |
|
|
|
def int2fp32(v):
    """Reinterpret an integer bit pattern as a little-endian 32-bit float.

    Args:
        v: An ``int`` holding the raw bit pattern, or any other value
            (for example an already-decoded float).

    Returns:
        The decoded ``float`` when ``v`` is an ``int``; any other value is
        returned unchanged.
    """
    # bool is an int subclass; it is passed through untouched, exactly as the
    # previous exact-type comparison did.
    if isinstance(v, int) and not isinstance(v, bool):
        # Keep only the low 32 bits so that negative (two's complement) or
        # over-wide register values decode instead of raising OverflowError.
        raw = (v & 0xFFFFFFFF).to_bytes(4, "little")
        (v,) = struct.unpack("<f", raw)
    return v
|
def int2fp64(v):
    """Reinterpret an integer bit pattern as a little-endian 64-bit float.

    Args:
        v: An ``int`` holding the raw bit pattern, or any other value
            (for example an already-decoded float).

    Returns:
        The decoded ``float`` when ``v`` is an ``int``; any other value is
        returned unchanged.
    """
    # bool is an int subclass; it is passed through untouched, exactly as the
    # previous exact-type comparison did.
    if isinstance(v, int) and not isinstance(v, bool):
        # Keep only the low 64 bits so that negative (two's complement) or
        # over-wide values decode instead of raising OverflowError.
        raw = (v & 0xFFFFFFFFFFFFFFFF).to_bytes(8, "little")
        (v,) = struct.unpack("<d", raw)
    return v
|
|
|
def align4(v):
    """Round ``v`` down to the nearest multiple of four.

    Only the two low bits are cleared, so values wider than 32 bits (such as
    64-bit addresses) keep their upper bits.
    """
    return v & ~0x3
|
|
|
class DisassemblerBase:
    """State and constant-naming helpers shared by the disassembler classes.

    The base class creates neither a loader nor a disassembler engine:
    ``self.loader`` starts as ``None`` and has to be set by a subclass before
    ``get_function_bytes`` is called, and ``disassemble`` must be overridden.
    """

    def __init__(self, expr_constants=None, match_constants=False):
        """Set up the bookkeeping used while naming constants.

        Args:
            expr_constants: Mapping from constant name to numeric value that
                ``add_constant`` searches when ``match_constants`` is true.
                Omitting it (or passing ``None``) gives a fresh empty dict.
            match_constants: When true, a float that is neither a recognised
                literal nor found in ``expr_constants`` makes
                ``add_constant`` raise instead of allocating a new symbol.
        """
        self.loader = None
        # Register id -> last known immediate value; filled in by subclasses.
        self.reg_values = {}
        # Index used for the next freshly allocated ``CSYM<n>`` name.
        self.constidx = 0
        # Constant key -> float value recorded by ``add_constant``.
        self.constants = {}
        # Byte addresses already identified as holding constant data.
        self.constaddrs = set()
        # A None default keeps instances from sharing one dict object.
        self.expr_constants = {} if expr_constants is None else expr_constants
        self.match_constants = match_constants

    def get_function_bytes(self, funcname):
        """Load the raw bytes of the function named ``funcname``.

        Also stores ``(start, end)`` of the function in ``self.funcrange``.

        Returns:
            ``(address, bytes)`` of the function.

        Raises:
            DecodeError: If the loader has no symbol called ``funcname``.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        # Odd symbol addresses are rounded down for every class except the
        # x64 one. NOTE(review): presumably this strips the ARM Thumb bit.
        if (not isinstance(self, DisassemblerX64)) and faddr % 2 == 1:
            faddr = faddr - 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def find_constant(self, constants, value):
        """Search ``constants`` for ``value`` up to sign and reciprocal.

        Args:
            constants: Mapping from key to numeric value.
            value: The number to look for.

        Returns:
            ``(key, prefix)`` for the first entry that lies within 1e-5 of
            ``value``, ``1/value``, ``-value`` or ``-1/value`` (prefix ``""``,
            ``"1/"``, ``"-"`` or ``"-1/"`` respectively), or ``False`` when
            nothing matches.
        """
        tolerance = 1e-5
        # A zero value has no reciprocal, so only the direct and negated forms
        # are tried for it (this used to raise ZeroDivisionError).
        has_reciprocal = value != 0
        candidates = [("", value)]
        if has_reciprocal:
            candidates.append(("1/", 1 / value))
        candidates.append(("-", -value))
        if has_reciprocal:
            candidates.append(("-1/", -1 / value))

        for key, known in constants.items():
            for prefix, candidate in candidates:
                if abs(candidate - known) < tolerance:
                    return key, prefix
        return False

    def add_constant(self, value, addr=0, size=0):
        """Choose the textual name for a float constant and record it.

        The first matching rule wins:

        1. ``0``, e and pi (the latter two within 1e-7) become ``CONST=0``,
           ``CONST=E`` and ``CONST=pi``.
        2. With ``match_constants`` set, a match in ``expr_constants``
           becomes ``<prefix>CSYM<name minus its first character>`` and the
           value is stored in ``self.constants`` under the matched name.
        3. If ``size > 0``, ``addr`` was already marked as constant data and
           the value matches an entry of ``self.constants``, that entry's
           name is reused.
        4. Otherwise the value is simplified with sympy; an integer or a
           rational with denominator <= 16 becomes ``CONST=<value>``. Any
           other value gets a new ``CSYM<n>`` name, unless
           ``match_constants`` is set.

        Args:
            value: The float constant.
            addr: Address the constant was loaded from, if any.
            size: Size in bytes of the constant in memory; ``0`` means the
                constant does not come from memory.

        Returns:
            The chosen name.

        Raises:
            DecodeError: If ``match_constants`` is set and no rule applies.
        """
        if value == 0:
            cname = "CONST=0"
        elif abs(value - CONST_E) < 1e-7:
            cname = "CONST=E"
        elif abs(value - CONST_PI) < 1e-7:
            cname = "CONST=pi"
        elif self.match_constants and (
            ecmatch := self.find_constant(self.expr_constants, value)
        ):
            ecname, ecxpr = ecmatch
            cname = f"{ecxpr}CSYM{ecname[1:]}"
            self.constants[ecname] = value
        elif size > 0 and addr in self.constaddrs and (
            smatch := self.find_constant(self.constants, value)
        ):
            sname, sxpr = smatch
            cname = f"{sxpr}CSYM{sname}"
        else:
            rep = sp.nsimplify(value, [sp.E, sp.pi], tolerance=1e-7)
            is_small_rational = isinstance(rep, sp.Rational) and rep.q <= 16
            if isinstance(rep, sp.Integer) or is_small_rational:
                cname = f"CONST={rep}"
            elif not self.match_constants:
                cname = f"CSYM{self.constidx}"
                self.constants[self.constidx] = value
                self.constidx += 1
            else:
                raise DecodeError(f"Cannot represent unmatched float {value}")

        if size > 0:
            # Remember every byte of the constant so these addresses can be
            # recognised as data later on.
            self.constaddrs.update(range(addr, addr + size))
        return cname

    def disassemble(self, function):
        """Placeholder; concrete subclasses provide the implementation."""
        raise NotImplementedError("Call disassemble on child classes, not base")
|
|
|
|
|
class DisassemblerARM32(DisassemblerBase):
    """Disassemble one function of a 32-bit ARM binary and name its floats.

    The capstone engine is created in Thumb mode with details enabled, and
    the binary is opened with ``cle.Loader``.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Open ``binpath`` and prepare the Thumb disassembler.

        Args:
            binpath: Path of the binary handed to ``cle.Loader``.
            expr_constants: Known constants (name -> value); omitted or
                ``None`` means an empty mapping.
            match_constants: Forwarded to ``DisassemblerBase``.
        """
        # Normalise here so that no dict object is shared between instances.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM, CS_MODE_THUMB)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def _literal_address(self, insn, mem):
        """Resolve a memory operand based on PC or on a tracked register.

        Returns the 4-byte-aligned base plus the displacement, or ``None``
        when the base register holds no known value.
        """
        if mem.base == ARM_REG_PC:
            base = insn.address + 4
        elif mem.base in self.reg_values:
            base = self.reg_values[mem.base]
        else:
            return None
        return align4(base) + mem.disp

    def _in_image(self, addr, size):
        """Tell whether ``size`` bytes at ``addr`` lie inside the loader's address range."""
        return self.loader.min_addr <= addr and addr + size <= self.loader.max_addr

    def check_mov_imm(self, insn):
        """Decode a register <- immediate instruction (MOV/MOVW/MOVT/ADR).

        Returns:
            ``(register id, immediate)`` or ``False`` when ``insn`` is not a
            two-operand register/immediate form of those instructions.
            Negative immediates are wrapped to unsigned 32-bit values.
        """
        if insn.id not in {ARM_INS_MOV, ARM_INS_MOVW,
                           ARM_INS_MOVT, ARM_INS_ADR}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2:
            return False
        dst, src = ops
        if dst.type != ARM_OP_REG or src.type != ARM_OP_IMM:
            return False
        imm = src.value.imm
        if imm < 0:
            imm += 2**32
        if insn.id == ARM_INS_ADR:
            # The ADR immediate is taken relative to the instruction address
            # plus four. NOTE(review): presumably the Thumb PC read-ahead.
            imm += insn.address + 4
        return dst.value.reg, imm

    def check_float_store(self, insn):
        """Interpret a STR/STRD of tracked register values as a float store.

        For STRD the first operand supplies the low word and the second the
        high word of a double; for STR the first operand is read as a single.

        Returns:
            The float value, or ``False`` if the instruction is not STR/STRD,
            a source register is untracked, or the magnitude falls outside
            ``[1e-3, 100]``.
        """
        if insn.id not in {ARM_INS_STR, ARM_INS_STRD}:
            return False
        ops = list(insn.operands)
        word = 0xFFFFFFFF  # registers are 32 bits wide
        if insn.id == ARM_INS_STRD:
            low_reg = ops[0].value.reg
            high_reg = ops[1].value.reg
            if low_reg not in self.reg_values or high_reg not in self.reg_values:
                return False
            low = self.reg_values[low_reg] & word
            high = self.reg_values[high_reg] & word
            fval = int2fp64((high << 32) | low)
        else:
            src_reg = ops[0].value.reg
            if src_reg not in self.reg_values:
                return False
            fval = int2fp32(self.reg_values[src_reg] & word)
        # Only values in this magnitude window are treated as float constants.
        if abs(fval) < 1e-3 or abs(fval) > 100:
            return False
        return fval

    def check_ldrd(self, insn):
        """Read the double loaded by a three-operand LDRD from memory.

        Returns:
            ``(value, address, 8)`` or ``False`` when the instruction does not
            match, the base register is unknown, or the address lies outside
            the loader's address range.
        """
        if insn.id != ARM_INS_LDRD:
            return False
        ops = list(insn.operands)
        if len(ops) != 3 or ops[2].type != ARM_OP_MEM:
            return False
        addr = self._literal_address(insn, ops[2].value.mem)
        if addr is None or not self._in_image(addr, 8):
            return False
        raw = self.loader.memory.load(addr, 8)
        # Explicit little-endian so the result does not depend on the host.
        fval = struct.unpack("<d", raw)[0]
        return fval, addr, 8

    def check_vldr(self, insn):
        """Read the float loaded by a VLDR from memory.

        A destination in D0..D31 is read as an 8-byte double; any other
        destination is read as a 4-byte single.

        Returns:
            ``(value, address, size)`` or ``False`` when the instruction does
            not match, the base register is unknown, or the address lies
            outside the loader's address range.
        """
        if insn.id != ARM_INS_VLDR:
            return False
        ops = list(insn.operands)
        if len(ops) < 2 or ops[1].type != ARM_OP_MEM:
            return False
        addr = self._literal_address(insn, ops[1].value.mem)
        if addr is None:
            return False
        dest_reg = ops[0].value.reg
        if ARM_REG_D0 <= dest_reg <= ARM_REG_D31:
            size, fmt = 8, "<d"
        else:
            size, fmt = 4, "<f"
        # Bound the check by the number of bytes actually read.
        if not self._in_image(addr, size):
            return False
        raw = self.loader.memory.load(addr, size)
        fval = struct.unpack(fmt, raw)[0]
        return fval, addr, size

    def check_vmov(self, insn):
        """Decode a float-immediate load (ids ARM_INS_FCONSTS/ARM_INS_FCONSTD).

        Returns:
            ``(asm text, value)`` or ``False`` if ``insn`` is not such an
            instruction with a floating-point second operand.
        """
        if insn.id not in {ARM_INS_FCONSTS, ARM_INS_FCONSTD}:
            return False
        ops = list(insn.operands)
        if len(ops) != 2 or ops[1].type != ARM_OP_FP:
            return False
        fval = ops[1].value.fp
        destname = insn.reg_name(ops[0].value.reg)
        return f"{insn.mnemonic} {destname}, {fval}", fval

    def check_branch_symbol(self, insn):
        """Rewrite the immediate target of B/BL/BLX symbolically.

        Targets strictly inside the current function become
        ``SELF+<hex offset>``; other targets are looked up as PLT stubs at
        the target address and, failing that, four bytes after it.

        Returns:
            ``"<mnemonic> <name>"`` with the name in angle brackets, or
            ``False`` if the instruction does not match or no name is found.
        """
        if insn.id not in {ARM_INS_B, ARM_INS_BL, ARM_INS_BLX}:
            return False
        ops = list(insn.operands)
        if len(ops) != 1 or ops[0].type != ARM_OP_IMM:
            return False
        target = ops[0].value.imm
        start, end = self.funcrange
        if start < target < end:
            func = f"SELF+{hex(target - start)}"
        else:
            func = self.loader.find_plt_stub_name(target)
            if func is None:
                func = self.loader.find_plt_stub_name(target + 4)
            if func is None:
                return False
        return f"{insn.mnemonic} <{func}>"

    def get_function_bytes(self, funcname):
        """Load the raw bytes of ``funcname`` and record ``self.funcrange``.

        An odd symbol address is rounded down by one before loading.

        Returns:
            ``(address, bytes)`` of the function.

        Raises:
            DecodeError: If the loader has no symbol called ``funcname``.
        """
        func = self.loader.find_symbol(funcname)
        if not func:
            raise DecodeError(f"Function {funcname} not found in binary")
        faddr = func.rebased_addr
        if faddr % 2 == 1:
            faddr -= 1
        fbytes = self.loader.memory.load(faddr, func.size)
        self.funcrange = faddr, faddr + func.size
        return faddr, fbytes

    def _track_registers(self, insn):
        """Update ``self.reg_values`` with the effect of ``insn``."""
        movimm = self.check_mov_imm(insn)
        if movimm:
            reg, imm = movimm
            if insn.id == ARM_INS_MOVT:
                # MOVT replaces the upper half-word and keeps the lower one.
                # An untracked register is treated as zero, as before.
                low = self.reg_values.get(reg, 0) & 0xFFFF
                self.reg_values[reg] = low | ((imm & 0xFFFF) << 16)
            else:
                self.reg_values[reg] = imm
            return
        # Any other write makes the tracked value stale.
        _reads, writes = insn.regs_access()
        for reg in writes:
            self.reg_values.pop(reg, None)

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        Instructions whose address was already identified as constant data
        are skipped. Recognised float constants are appended to the
        instruction text as ``, <name>`` and branches to known targets are
        rewritten symbolically.

        Raises:
            DecodeError: From ``get_function_bytes`` or ``add_constant``.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        lines = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                continue

            cname = None
            asm = None

            # Constant detection runs before register tracking, so it sees
            # the register state as it was before this instruction.
            if vldr := self.check_vldr(insn):
                fval, faddr, fsize = vldr
                cname = self.add_constant(fval, faddr, fsize)
            elif ldrd := self.check_ldrd(insn):
                fval, faddr, fsize = ldrd
                cname = self.add_constant(fval, faddr, fsize)
            elif strfloat := self.check_float_store(insn):
                cname = self.add_constant(strfloat)
            elif vmovfloat := self.check_vmov(insn):
                asm, fval = vmovfloat
                cname = self.add_constant(fval)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            self._track_registers(insn)

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            lines.append(asm)

        return "; ".join(lines)
|
|
|
class DisassemblerAArch64(DisassemblerBase):
    """Disassemble one function of an AArch64 binary and name its floats.

    The capstone engine is created for ARM64 with details enabled, and the
    binary is opened with ``cle.Loader``.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Open ``binpath`` and prepare the ARM64 disassembler.

        Args:
            binpath: Path of the binary handed to ``cle.Loader``.
            expr_constants: Known constants (name -> value); omitted or
                ``None`` means an empty mapping.
            match_constants: Forwarded to ``DisassemblerBase``.
        """
        # Normalise here so that no dict object is shared between instances.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_ARM64, CS_MODE_ARM)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def reg_size_type(self, reg):
        """Classify a register id by bit width and Python value type.

        Returns:
            ``(32, int)`` for W0..W30, ``(64, int)`` for X0..X30,
            ``(32, float)`` for S0..S31, ``(64, float)`` for D0..D31 and
            ``(0, None)`` for anything else. The ranges are tested by
            comparing capstone's register ids numerically.
        """
        if ARM64_REG_W0 <= reg <= ARM64_REG_W30:
            return 32, int
        if ARM64_REG_X0 <= reg <= ARM64_REG_X30:
            return 64, int
        if ARM64_REG_S0 <= reg <= ARM64_REG_S31:
            return 32, float
        if ARM64_REG_D0 <= reg <= ARM64_REG_D31:
            return 64, float
        return 0, None

    def check_mov_imm(self, insn):
        """Decode a register <- immediate instruction (ADRP/ADR/MOV/MOVK).

        Returns:
            ``(register id, value)`` or ``False`` when ``insn`` is not a
            two-operand register/immediate form of those instructions.
        """
        if insn.id not in {ARM64_INS_ADRP, ARM64_INS_ADR,
                           ARM64_INS_MOV, ARM64_INS_MOVK}:
            return False

        ops = insn.operands
        if len(ops) != 2:
            return False
        dst, src = ops
        if dst.type != ARM64_OP_REG or src.type != ARM64_OP_IMM:
            return False

        imm = src.value.imm
        shift = 0
        # NOTE(review): shift type 1 is presumably capstone's LSL; confirm.
        if src.shift.type == 1:
            shift = src.shift.value
            imm <<= shift
        # 16-bit field written by MOVK; always defined so that an unshifted
        # MOVK works as well.
        mask = 0xFFFF << shift

        if insn.id == ARM64_INS_ADRP:
            # The ADRP immediate is used as-is.
            pass
        elif insn.id == ARM64_INS_ADR:
            # NOTE(review): both adjustments are kept exactly as found; the
            # reason for the 0x400000 offset is not visible in this file.
            imm -= 0x400000
            imm += insn.address + 4
        elif insn.id == ARM64_INS_MOVK:
            # MOVK only replaces one 16-bit field of a tracked value.
            if dst.value.reg in self.reg_values:
                curr = self.reg_values[dst.value.reg]
                imm = (imm & mask) | (curr & ~mask)

        return dst.value.reg, imm

    def check_fmov(self, insn):
        """Extract the float moved by a two-operand FMOV.

        A floating-point immediate is used directly. A register source is
        reinterpreted from its tracked integer value according to the
        destination width and accepted only when its magnitude lies within
        ``[1e-5, 1e5]``.

        Returns:
            ``(asm text, value)`` or ``False`` when nothing can be extracted.
        """
        if insn.id != ARM64_INS_FMOV:
            return False
        ops = insn.operands
        if len(ops) != 2:
            return False

        destsize, _ = self.reg_size_type(ops[0].value.reg)
        destname = insn.reg_name(ops[0].value.reg)
        src = ops[1]

        if src.type == ARM64_OP_FP:
            fval = src.value.fp
            return f"{insn.mnemonic} {destname}, {fval}", fval

        if src.type != ARM64_OP_REG:
            # Other operand kinds carry no value that can be extracted here.
            return False

        reg = src.value.reg
        if reg not in self.reg_values:
            return False

        bits = self.reg_values[reg]
        # Masking converts negative values to their two's complement pattern.
        if destsize == 64:
            fval = int2fp64(bits & 0xFFFFFFFFFFFFFFFF)
        elif destsize == 32:
            fval = int2fp32(bits & 0xFFFFFFFF)
        else:
            return False

        if abs(fval) < 1e-5 or abs(fval) > 1e5:
            return False
        return f"{insn.mnemonic} {insn.op_str}", fval

    def check_ldr(self, insn):
        """Read the float loaded by ``LDR <Sn|Dn>, [Xn{, #imm}]`` from memory.

        The base register and offset are parsed from the operand text. Only
        a plain base with an optional immediate offset is accepted, and the
        base register must hold a tracked value.

        Returns:
            ``(value, address, size)`` or ``False`` when the instruction does
            not match, uses ``sp`` or an unsupported addressing form, or the
            address lies outside the loader's address range.
        """
        if insn.id != ARM64_INS_LDR:
            return False
        # Drop the final character (the closing bracket in the plain form).
        parts = insn.op_str[:-1].split(", ")
        destsize, desttype = self.reg_size_type(insn.operands[0].value.reg)
        if len(parts) < 2 or desttype != float:
            return False
        if len(parts) > 3:
            # Register-indexed forms with a shift/extend are not resolved.
            return False
        base_text = parts[1]
        if not base_text.startswith("[") or "sp" in base_text:
            return False

        offset = 0
        try:
            basereg = ARM64_REG_X0 + int(base_text[2:])
            if len(parts) == 3:
                if not parts[2].startswith("#"):
                    # A register offset is not an immediate.
                    return False
                text = parts[2][1:]
                is_hex = text.lstrip("-").lower().startswith("0x")
                offset = int(text, 16) if is_hex else int(text)
        except ValueError:
            # Pre-/post-indexed and other unexpected spellings end up here.
            return False

        if basereg not in self.reg_values:
            return False
        addr = align4(self.reg_values[basereg]) + offset

        if destsize == 64:
            size, fmt = 8, "<d"
        elif destsize == 32:
            size, fmt = 4, "<f"
        else:
            return False
        if addr < self.loader.min_addr or addr + size > self.loader.max_addr:
            return False
        raw = self.loader.memory.load(addr, size)
        # Explicit little-endian so the result does not depend on the host.
        fval = struct.unpack(fmt, raw)[0]
        return fval, addr, size

    def check_branch_symbol(self, insn):
        """Rewrite the immediate target of B/BL symbolically.

        Targets strictly inside the current function become
        ``SELF+<hex offset>``; other targets are looked up as PLT stubs at
        the target address and, failing that, four bytes after it.

        Returns:
            ``"<mnemonic> <name>"`` with the name in angle brackets, or
            ``False`` if the instruction does not match or no name is found.
        """
        if insn.id not in {ARM64_INS_BL, ARM64_INS_B}:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != ARM64_OP_IMM:
            return False
        target = ops[0].value.imm
        start, end = self.funcrange
        if start < target < end:
            func = f"SELF+{hex(target - start)}"
        else:
            func = self.loader.find_plt_stub_name(target)
            if func is None:
                func = self.loader.find_plt_stub_name(target + 4)
            if func is None:
                return False
        return f"{insn.mnemonic} <{func}>"

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        Instructions whose address was already identified as constant data
        are skipped. Recognised float constants are appended to the
        instruction text as ``, <name>`` and branches to known targets are
        rewritten symbolically.

        Raises:
            DecodeError: From ``get_function_bytes`` or ``add_constant``.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        lines = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            if insn.address in self.constaddrs:
                continue

            cname = None
            asm = None

            # Register tracking runs first, so the checks below see the state
            # after this instruction.
            if movimm := self.check_mov_imm(insn):
                reg, imm = movimm
                self.reg_values[reg] = imm
            else:
                # Any other write makes the tracked value stale.
                _reads, writes = insn.regs_access()
                for written in writes:
                    self.reg_values.pop(written, None)

            if fmov := self.check_fmov(insn):
                asm, fval = fmov
                cname = self.add_constant(fval)
            elif ldr := self.check_ldr(insn):
                fval, faddr, fsize = ldr
                cname = self.add_constant(fval, faddr, fsize)
            elif branch := self.check_branch_symbol(insn):
                asm = branch

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            lines.append(asm)

        return "; ".join(lines)
|
|
|
class DisassemblerX64(DisassemblerBase):
    """Disassemble one function of an x86-64 binary and name its floats.

    The capstone engine is created for 64-bit x86 with details enabled, and
    the binary is opened with ``cle.Loader``.
    """

    def __init__(self, binpath, expr_constants=None, match_constants=False):
        """Open ``binpath`` and prepare the x86-64 disassembler.

        Args:
            binpath: Path of the binary handed to ``cle.Loader``.
            expr_constants: Known constants (name -> value); omitted or
                ``None`` means an empty mapping.
            match_constants: Forwarded to ``DisassemblerBase``.
        """
        # Normalise here so that no dict object is shared between instances.
        if expr_constants is None:
            expr_constants = {}
        super().__init__(expr_constants=expr_constants,
                         match_constants=match_constants)
        self.md = Cs(CS_ARCH_X86, CS_MODE_64)
        self.md.detail = True
        self.loader = cle.Loader(binpath)

    def check_call_symbol(self, insn):
        """Rewrite a CALL with an immediate target that is a PLT stub.

        Returns:
            ``"<mnemonic> <name>"`` with the name in angle brackets, or
            ``False`` if ``insn`` is not such a call or the loader knows no
            stub name for the target.
        """
        if insn.id != X86_INS_CALL:
            return False
        ops = insn.operands
        if len(ops) != 1 or ops[0].type != X86_OP_IMM:
            return False
        func = self.loader.find_plt_stub_name(ops[0].value.imm)
        if func is None:
            return False
        return f"{insn.mnemonic} <{func}>"

    def check_fload(self, insn):
        """Read the float referenced by a single RIP-relative memory operand.

        The operand address is the end of the instruction plus the
        displacement. Four-byte operands are read as singles and eight-byte
        operands as doubles; no other size is interpreted.

        Returns:
            ``(value, address, size)`` or ``False`` when ``insn`` does not
            have exactly one RIP-relative memory operand of 4 or 8 bytes, or
            the address lies outside the loader's address range.
        """
        rip_operands = [
            op for op in insn.operands
            if op.type == X86_OP_MEM and op.value.mem.base == X86_REG_RIP
        ]
        if len(rip_operands) != 1:
            return False
        operand = rip_operands[0]
        size = operand.size
        # Byte- and word-sized operands cannot hold a float or a double;
        # unpacking them used to raise struct.error.
        if size not in (4, 8):
            return False
        addr = insn.address + insn.size + operand.value.mem.disp
        if addr < self.loader.min_addr or addr + size > self.loader.max_addr:
            return False
        raw = self.loader.memory.load(addr, size)
        # Explicit little-endian so the result does not depend on the host.
        fval = struct.unpack("<f" if size == 4 else "<d", raw)[0]
        return fval, addr, size

    def disassemble(self, funcname):
        """Disassemble ``funcname`` into a single ``"; "``-joined string.

        Recognised float constants are appended to the instruction text as
        ``, <name>`` and calls to PLT stubs are rewritten symbolically.

        Raises:
            DecodeError: From ``get_function_bytes`` or ``add_constant``.
        """
        funcaddr, funcbytes = self.get_function_bytes(funcname)
        lines = []

        for insn in self.md.disasm(funcbytes, funcaddr):
            asm = None
            cname = None
            if fload := self.check_fload(insn):
                fval, faddr, fsize = fload
                cname = self.add_constant(fval, faddr, fsize)
            elif call := self.check_call_symbol(insn):
                asm = call

            if not asm:
                asm = f"{insn.mnemonic} {insn.op_str}"
            if cname:
                asm += f", {cname}"
            lines.append(asm)

        return "; ".join(lines)
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: disassemble one function of a binary, then
    # print the rewritten assembly followed by the collected constants.
    import argparse

    # Value accepted by --arch -> disassembler class handling it.
    arch_classes = {
        "arm32": DisassemblerARM32,
        "aarch64": DisassemblerAArch64,
        "x64": DisassemblerX64,
    }

    # The text is a description, so it is passed by keyword; the first
    # positional parameter of ArgumentParser is the program name.
    parser = argparse.ArgumentParser(
        description="Pre-process assembly to replace constants and dump"
    )
    parser.add_argument("--bin", required=True)
    parser.add_argument("--func", required=True)
    # Restricting the choices makes argparse reject an unknown architecture
    # up front instead of failing later on an undefined variable.
    parser.add_argument("--arch", required=True, choices=sorted(arch_classes))
    args = parser.parse_args()

    D = arch_classes[args.arch](args.bin)
    diss = D.disassemble(args.func)
    print(diss)
    print(D.constants)
|
|