#!/usr/bin/env python
#===- lib/hwasan/scripts/hwasan_symbolize ----------------------------------===#
#
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
#
#===------------------------------------------------------------------------===#
#
# HWAddressSanitizer offline symbolization script.
#
#===------------------------------------------------------------------------===#

from __future__ import print_function
from __future__ import unicode_literals

import glob
import os
import re
import sys
import string
import subprocess
import argparse
import mmap
import struct
import os

# Python 2 only: route print through a UTF-8 writer, imitating the Python 3.x
# behaviour of defaulting to UTF-8. This matters when symbols are non-ASCII.
if sys.version_info[0] < 3:
  import codecs
  sys.stdout = codecs.getwriter("utf-8")(sys.stdout)

# Faulting address and pointer tag taken from the most recent report lines;
# filled in by save_access_address and read by process_stack_history.
last_access_address = last_access_tag = None

# Below, a parser for a subset of ELF. It only supports 64 bit, little-endian,
# and only parses what is necessary to find the build ids. It uses a memoryview
# into an mmap to avoid copying.
# ELF file header: its size, and the byte offsets of the fields handle_elf
# reads from it.
Ehdr_size = 64
e_shnum_offset = 60  # e_shnum: number of section headers (uint16)
e_shoff_offset = 40  # e_shoff: file offset of the section header table (uint64)

# Section header: its size, and the byte offsets of the fields handle_Shdr
# reads from it.
Shdr_size = 64
sh_type_offset = 4  # sh_type: section type (uint32)
sh_offset_offset = 24  # sh_offset: file offset of the section data (uint64)
sh_size_offset = 32  # sh_size: size of the section data in bytes (uint64)
SHT_NOTE = 7  # sh_type value of the only sections that are inspected

# Note header: three uint32 fields (n_namesz, n_descsz, n_type).
Nhdr_size = 12
NT_GNU_BUILD_ID = 3  # n_type value of the note whose descriptor is the build id

def align_up(size, alignment):
  """Round size up to the next multiple of alignment.

  The mask arithmetic only produces a multiple of alignment when alignment is
  a power of two.
  """
  low_bits = alignment - 1
  return (size + low_bits) & ~low_bits

def handle_Nhdr(mv, sh_size):
  """Scan the notes of one note section for the GNU build id.

  Args:
    mv: memoryview starting at the first byte of the note section.
    sh_size: size of the section in bytes, as declared by its section header.

  Returns:
    The build id as a hex string, or None when the section holds no
    NT_GNU_BUILD_ID note named "GNU" or ends in the middle of a note.
  """
  # A truncated file leaves mv shorter than the declared size. Never read past
  # whichever bound is smaller, otherwise struct.unpack_from raises.
  limit = min(sh_size, len(mv))
  offset = 0
  while offset + Nhdr_size <= limit:
    n_namesz, n_descsz, n_type = struct.unpack_from('<III', buffer=mv,
                                                    offset=offset)
    name_start = offset + Nhdr_size
    if (n_type == NT_GNU_BUILD_ID and n_namesz == 4 and
        mv[name_start: name_start + 4] == b"GNU\x00"):
      desc_start = name_start + 4
      desc_end = desc_start + n_descsz
      if desc_end > limit:
        # The descriptor is cut off; a partial build id would never match.
        return None
      return mv[desc_start: desc_end].hex()
    # Name and descriptor are each padded to a 4-byte boundary.
    offset = name_start + align_up(n_namesz, 4) + align_up(n_descsz, 4)
  return None

def handle_Shdr(mv):
  """Return (sh_offset, sh_size) of a note section, else (None, None).

  mv is a memoryview over a single section header.
  """
  section_type = struct.unpack_from('<I', buffer=mv, offset=sh_type_offset)[0]
  if section_type == SHT_NOTE:
    data_offset = struct.unpack_from('<Q', buffer=mv,
                                     offset=sh_offset_offset)[0]
    data_size = struct.unpack_from('<Q', buffer=mv, offset=sh_size_offset)[0]
    return data_offset, data_size
  return None, None

def handle_elf(mv):
  """Find the GNU build id of a memory-mapped ELF file.

  Args:
    mv: memoryview over the whole file. get_buildid only passes files that
      are at least Ehdr_size bytes long.

  Returns:
    The build id as a hex string, or None when the file is not a 64-bit
    little-endian ELF or none of its note sections carries a build id.
  """
  # \x02 is ELFCLASS64, \x01 is ELFDATA2LSB. HWASan currently only works on
  # 64-bit little endian platforms (x86_64 and ARM64). If this changes, we will
  # have to extend the parsing code.
  if mv[:6] != b'\x7fELF\x02\x01':
    return None
  e_shnum, = struct.unpack_from('<H', buffer=mv, offset=e_shnum_offset)
  e_shoff, = struct.unpack_from('<Q', buffer=mv, offset=e_shoff_offset)
  file_size = len(mv)
  for i in range(e_shnum):
    start = e_shoff + i * Shdr_size
    if start + Shdr_size > file_size:
      # The section header table runs past the end of a truncated file; every
      # later header lies even further out, so stop instead of letting
      # struct.unpack_from raise on a short buffer.
      break
    sh_offset, sh_size = handle_Shdr(mv[start: start + Shdr_size])
    if sh_offset is None:
      continue
    result = handle_Nhdr(mv[sh_offset: sh_offset + sh_size], sh_size)
    if result is not None:
      return result
  return None

def get_buildid(filename):
  """Return the GNU build id of the file at filename, or None.

  None is returned when the file is smaller than an ELF header, is not an ELF
  file that handle_elf understands, has no build id note, or cannot be opened
  or mapped at all (broken symlink, missing permission, ...). The last case
  matters when a whole directory tree is indexed: one unreadable entry must
  not abort the walk.
  """
  try:
    # The content is binary; "rb" avoids wrapping the descriptor in a text
    # decoder that is never used.
    with open(filename, "rb") as fd:
      # mmap rejects empty files, and anything shorter than the ELF header
      # cannot be parsed anyway.
      if os.fstat(fd.fileno()).st_size < Ehdr_size:
        return None
      with mmap.mmap(fd.fileno(), 0, access=mmap.ACCESS_READ) as m:
        with memoryview(m) as mv:
          return handle_elf(mv)
  except (IOError, OSError):
    return None

class Symbolizer:
  """Answers address queries through a long-lived llvm-symbolizer subprocess.

  The subprocess is started on the first query. Each query is written as one
  line to its stdin and the answer is read line by line from its stdout; an
  empty line (or end of stream) ends the answer.
  """

  def __init__(self, path, binary_prefixes, paths_to_cut):
    """Store the configuration; the subprocess is not started here.

    Args:
      path: llvm-symbolizer executable to run.
      binary_prefixes: directories searched for binaries with symbols.
      paths_to_cut: literal path prefixes stripped from reported source paths.
    """
    self.__pipe = None
    self.__path = path
    self.__binary_prefixes = binary_prefixes
    self.__paths_to_cut = paths_to_cut
    self.__log = False
    # Binary names already reported as missing, so each is warned about once.
    self.__warnings = set()
    # Build id (hex string) -> path of a file carrying it; see build_index.
    self.__index = {}

  def enable_logging(self, enable):
    """Turn echoing of the symbolizer conversation to stderr on or off."""
    self.__log = enable

  def __open_pipe(self):
    """Start the llvm-symbolizer subprocess unless it is already running."""
    if not self.__pipe:
      opt = {}
      if sys.version_info.major > 2:
        # Python 3 pipes are binary unless an encoding is requested.
        opt['encoding'] = 'utf-8'
      self.__pipe = subprocess.Popen([self.__path, "--inlining", "--functions"],
                                     stdin=subprocess.PIPE, stdout=subprocess.PIPE,
                                     **opt)

  class __EOF(Exception):
    """Raised by __read when the current symbolizer answer is complete."""
    pass

  def __write(self, s):
    """Send one command line to the symbolizer and flush it."""
    print(s, file=self.__pipe.stdin)
    self.__pipe.stdin.flush()
    if self.__log:
      print("#>>  |%s|" % (s,), file=sys.stderr)

  def __read(self):
    """Return the next non-empty answer line, stripped of trailing whitespace.

    Raises:
      Symbolizer.__EOF: the line was empty, which marks the end of an answer
        (readline also returns an empty string at end of stream).
    """
    s = self.__pipe.stdout.readline().rstrip()
    if self.__log:
      print("# << |%s|" % (s,), file=sys.stderr)
    if s == '':
      raise Symbolizer.__EOF
    return s

  def __process_source_path(self, file_name):
    """Shorten a "file:line" location for display.

    Everything up to and including a configured source prefix is removed, and
    sanitizer runtime and crtstuff locations are replaced by fixed markers.
    """
    for path_to_cut in self.__paths_to_cut:
      # The prefixes are literal paths, not patterns. Escaping keeps characters
      # such as '+' or '(' in a directory name from changing (or breaking) the
      # regular expression.
      file_name = re.sub(".*" + re.escape(path_to_cut), "", file_name)
    # NOTE(review): only .cc and .h runtime sources are collapsed below; files
    # with another extension (e.g. .cpp) are left as they are - confirm that
    # this is intended.
    file_name = re.sub(".*hwasan_[a-z_]*.(cc|h):[0-9]*", "[hwasan_rtl]", file_name)
    file_name = re.sub(".*asan_[a-z_]*.(cc|h):[0-9]*", "[asan_rtl]", file_name)
    file_name = re.sub(".*crtstuff.c:0", "???:0", file_name)
    return file_name

  def __process_binary_name(self, name, buildid):
    """Map a binary name from the report to a local file with symbols.

    Lookup order: the build id index, then name below each prefix (also trying
    the "apex/com.google.android." spelling of "apex/com.android." paths), then
    the bare file name below each prefix.

    Returns:
      The path of the first existing candidate, or None after printing a
      one-time warning for name to stderr.
    """
    if name.startswith('/'):
      # os.path.join would discard the prefix if name stayed absolute.
      name = name[1:]
    if buildid is not None and buildid in self.__index:
      return self.__index[buildid]

    for p in self.__binary_prefixes:
      full_path = os.path.join(p, name)
      if os.path.exists(full_path):
        return full_path
      apex_prefix = "apex/com.android."
      if name.startswith(apex_prefix):
        full_path = os.path.join(p, "apex/com.google.android." + name[len(apex_prefix):])
        if os.path.exists(full_path):
          return full_path
    # Try stripping extra path components as the last resort.
    for p in self.__binary_prefixes:
      full_path = os.path.join(p, os.path.basename(name))
      if os.path.exists(full_path):
        return full_path
    if name not in self.__warnings:
      print("Could not find symbols for", name, file=sys.stderr)
      self.__warnings.add(name)
    return None

  def iter_locals(self, binary, addr, buildid):
    """Yield the local variables of the frame containing addr in binary.

    Note the argument order: addr comes before buildid here, unlike in
    iter_call_stack.

    Yields:
      (function_name, file_line, local_name, offset, size, tag_offset) tuples.
      The three numbers are ints, or None where the symbolizer printed '??'.
      Nothing is yielded when the binary cannot be located.
    """
    self.__open_pipe()
    binary = self.__process_binary_name(binary, buildid)
    if not binary:
      return
    self.__write("FRAME %s %s" % (binary, addr))
    try:
      while True:
        # Each local is reported as four consecutive lines.
        function_name = self.__read()
        local_name = self.__read()
        file_line = self.__read()
        extra = self.__read().split()

        file_line = self.__process_source_path(file_line)
        offset = None if extra[0] == '??' else int(extra[0])
        size = None if extra[1] == '??' else int(extra[1])
        tag_offset = None if extra[2] == '??' else int(extra[2])
        yield (function_name, file_line, local_name, offset, size, tag_offset)
    except Symbolizer.__EOF:
      pass

  def iter_call_stack(self, binary, buildid, addr):
    """Yield the (possibly inlined) frames for addr in binary.

    Yields:
      (function_name, file_line) tuples in the order the symbolizer prints
      them. Nothing is yielded when the binary cannot be located.
    """
    self.__open_pipe()
    binary = self.__process_binary_name(binary, buildid)
    if not binary:
      return
    self.__write("CODE %s %s" % (binary, addr))
    try:
      while True:
        # Each frame is reported as two consecutive lines.
        function_name = self.__read()
        file_line = self.__read()
        file_line = self.__process_source_path(file_line)
        yield (function_name, file_line)
    except Symbolizer.__EOF:
      pass

  def build_index(self):
    """Record the build id of every file found below the binary prefixes.

    When several files share a build id, the one visited last wins.
    """
    for p in self.__binary_prefixes:
      for dname, _, fnames in os.walk(p):
        for fn in fnames:
          filename = os.path.join(dname, fn)
          bid = get_buildid(filename)
          if bid is not None:
            self.__index[bid] = filename

def symbolize_line(line, symbolizer_path):
  """Print line, expanded into symbolized frames if it is a stack frame.

  A frame line that can be symbolized is replaced by one line for the first
  frame the symbolizer reports, followed by one "->" line for each further
  frame. Every other line is printed unchanged apart from trailing whitespace.

  Note that symbolizer_path is not used: lookups go through the module-level
  `symbolizer` object.
  """
  #0 0x7f6e35cf2e45  (/blah/foo.so+0x11fe45) (BuildId: 4abce4cd41ea5c2f34753297b7e774d9)
  frame = re.match(r'^(.*?)#([0-9]+)( *)(0x[0-9a-f]*) *\((.*)\+(0x[0-9a-f]+)\)'
                   r'(?:\s*\(BuildId: ([0-9a-f]+)\))?', line, re.UNICODE)
  if not frame:
    print(line.rstrip())
    return

  module = frame.group(5)
  module_offset = int(frame.group(6), 16)
  build_id = frame.group(7)
  resolved = list(symbolizer.iter_call_stack(module, build_id, module_offset))
  if not resolved:
    print(line.rstrip())
    return

  first = resolved[0]
  print("%s#%s%s%s in %s" % (frame.group(1), frame.group(2), frame.group(3),
                             first[0], first[1]))
  # Continuation lines start in the column of '#' and pad so that the function
  # name lines up with the one on the first line.
  lead = ' ' * frame.end(1)
  pad = ' ' * (frame.start(4) - frame.end(1) - 2)
  for extra in resolved[1:]:
    print("%s->%s%s in %s" % (lead, pad, extra[0], extra[1]))

def save_access_address(line):
  """Remember the faulting address and pointer tag announced by line.

  Updates the module-level last_access_address when line is the tag-mismatch
  header, and last_access_tag when line is the access description carrying
  the "ptr/mem" tags. Lines matching neither leave both values untouched.
  """
  global last_access_address, last_access_tag
  header = re.match(r'^(.*?)HWAddressSanitizer: tag-mismatch on address (0x[0-9a-f]+) ', line, re.UNICODE)
  if header is not None:
    last_access_address = int(header.group(2), 16)
  access = re.match(r'^(.*?) of size [0-9]+ at 0x[0-9a-f]* tags: ([0-9a-f]+)/[0-9a-f]+ \(ptr/mem\)', line, re.UNICODE)
  if access is not None:
    # Group 2 is the pointer tag, i.e. the first of the two tags.
    last_access_tag = int(access.group(2), 16)

def process_stack_history(line, symbolizer, ignore_tags=False):
  """Handle one line of the stack history part of a report.

  For a "record_addr: ... record: ..." line, the locals of the recorded frame
  are queried from symbolizer and every local that could contain the faulting
  address is printed.

  Args:
    line: the report line.
    symbolizer: object providing iter_locals(binary, addr, buildid).
    ignore_tags: when true, report locals even if their tag does not match
      the tag of the faulting pointer.

  Returns:
    True when the line was consumed (the section header or a record line),
    False for any other line, and None while no faulting address and tag have
    been recorded yet.
  """
  if last_access_address is None or last_access_tag is None:
    return None
  if re.match(r'Previously allocated frames:', line, re.UNICODE):
    return True
  # record_addr:0x1234ABCD record:0x1234ABCD (/path/to/binary+0x1234ABCD) (BuildId: 4abce4cd41ea5c2f34753297b7e774d9)
  entry = re.match(r'^(.*?)record_addr:(0x[0-9a-f]+) +record:(0x[0-9a-f]+) +\((.*)\+(0x[0-9a-f]+)\)'
                   r'(?:\s*\(BuildId: ([0-9a-f]+)\))?', line, re.UNICODE)
  if not entry:
    return False

  record_addr = int(entry.group(2), 16)
  record = int(entry.group(3), 16)
  module = entry.group(4)
  module_offset = int(entry.group(5), 16)
  build_id = entry.group(6)

  fp_mask = (1 << 20) - 1
  base_tag = (record_addr >> 3) & 0xFF
  # The bits above bit 47 of the record hold the frame pointer in 16-byte units.
  fp = (record >> 48) << 4

  for function_name, file_line, local_name, frame_offset, size, tag_offset in (
      symbolizer.iter_locals(module, module_offset, build_id)):
    if frame_offset is None or size is None:
      continue
    # Only the low 20 bits of the distance are compared.
    obj_offset = (last_access_address - fp - frame_offset) & fp_mask
    if obj_offset >= size:
      continue
    if not (ignore_tags or
            (tag_offset is not None and
             base_tag ^ tag_offset == last_access_tag)):
      continue
    print('')
    print('Potentially referenced stack object:')
    print('  %d bytes inside variable "%s" in stack frame of function "%s"' % (obj_offset, local_name, function_name))
    print('  at %s' % (file_line,))
  return True

# Command line. The HWASan report itself is read from stdin.
parser = argparse.ArgumentParser(
    description='Symbolize a HWAddressSanitizer report read from stdin.')
parser.add_argument('-d', action='store_true',
                    help='echo the conversation with llvm-symbolizer to stderr')
parser.add_argument('-v', action='store_true',
                    help='print the symbol directories, source prefixes and '
                         'llvm-symbolizer binary in use before processing')
parser.add_argument('--ignore-tags', action='store_true',
                    help='report stack objects even when their tag does not '
                         'match the tag of the faulting pointer')
parser.add_argument('--symbols', action='append',
                    help='directory containing binaries with symbols; may be '
                         'given more than once')
parser.add_argument('--source', action='append',
                    help='source path prefix to strip from reported file '
                         'names; may be given more than once')
parser.add_argument('--index', action='store_true',
                    help='scan the symbol directories up front and locate '
                         'binaries by build id')
parser.add_argument('--symbolizer',
                    help='path of the llvm-symbolizer binary to run')
# Remaining positional arguments are accepted but not used by this script.
parser.add_argument('args', nargs=argparse.REMAINDER,
                    help='extra arguments (accepted, not used)')
args = parser.parse_args()

# Unstripped binaries location.
# Explicit --symbols directories win. Without them the search covers
# $ANDROID_PRODUCT_OUT/symbols (only when that variable is set) and then the
# filesystem root.
binary_prefixes = args.symbols or []
if not binary_prefixes:
  if 'ANDROID_PRODUCT_OUT' in os.environ:
    product_out = os.path.join(os.environ['ANDROID_PRODUCT_OUT'], 'symbols')
    binary_prefixes.append(product_out)
  binary_prefixes.append('/')

# Fail early: every symbols location has to be an existing directory.
for p in binary_prefixes:
  if not os.path.isdir(p):
    print("Symbols path does not exist or is not a directory:", p, file=sys.stderr)
    sys.exit(1)

# Source location.
# Prefixes removed from reported source paths. Explicit --source values win;
# the defaults are the current directory and, when set, $ANDROID_BUILD_TOP,
# each with a trailing '/'.
paths_to_cut = args.source or []
if not paths_to_cut:
  paths_to_cut.append(os.getcwd() + '/')
  if 'ANDROID_BUILD_TOP' in os.environ:
    paths_to_cut.append(os.environ['ANDROID_BUILD_TOP'] + '/')

# llvm-symbolizer binary.
# 1. --symbolizer flag
# 2. environment variable (LLVM_SYMBOLIZER_PATH, then HWASAN_SYMBOLIZER_PATH)
# 3. unsuffixed binary in the directory of this script (dirname of sys.argv[0])
# 4. if inside Android platform, prebuilt binary at a known path
# 5. first "llvm-symbolizer", then "llvm-symbolizer-$VER" with the
#    highest available version in $PATH
# Each step below only runs while no earlier step has produced a path.
symbolizer_path = args.symbolizer
if not symbolizer_path:
  if 'LLVM_SYMBOLIZER_PATH' in os.environ:
    symbolizer_path = os.environ['LLVM_SYMBOLIZER_PATH']
  elif 'HWASAN_SYMBOLIZER_PATH' in os.environ:
    symbolizer_path = os.environ['HWASAN_SYMBOLIZER_PATH']

if not symbolizer_path:
  s = os.path.join(os.path.dirname(sys.argv[0]), 'llvm-symbolizer')
  if os.path.exists(s):
    symbolizer_path = s

if not symbolizer_path:
  # "Inside Android platform" is detected through $ANDROID_BUILD_TOP.
  if 'ANDROID_BUILD_TOP' in os.environ:
    s = os.path.join(os.environ['ANDROID_BUILD_TOP'], 'prebuilts/clang/host/linux-x86/llvm-binutils-stable/llvm-symbolizer')
    if os.path.exists(s):
      symbolizer_path = s

if not symbolizer_path:
  # Unsuffixed name: the first $PATH entry that contains it wins.
  for path in os.environ["PATH"].split(os.pathsep):
    p = os.path.join(path, 'llvm-symbolizer')
    if os.path.exists(p):
      symbolizer_path = p
      break

def extract_version(s):
  """Return the numeric version suffix of s for use as a sort key.

  The version is whatever follows the last '-' in s, e.g. 14.0 for
  "llvm-symbolizer-14".

  Returns:
    The suffix as a float, or 0 when s has no '-' or the suffix is not a
    plain number (such as "14.0.6" or "git"), so that one oddly named binary
    cannot make the sort fail.
  """
  idx = s.rfind('-')
  if idx == -1:
    return 0
  try:
    return float(s[idx + 1:])
  except ValueError:
    return 0

if not symbolizer_path:
  # Last resort: version-suffixed binaries. The first $PATH entry that has any
  # is used, and within it the highest version wins.
  for path in os.environ["PATH"].split(os.pathsep):
    candidates = glob.glob(os.path.join(path, 'llvm-symbolizer-*'))
    if len(candidates) > 0:
      candidates.sort(key = extract_version, reverse = True)
      symbolizer_path = candidates[0]
      break

# Every lookup step may have come up empty. os.path.exists(None) raises
# TypeError, so report the missing symbolizer explicitly first.
if not symbolizer_path:
  print("Could not find llvm-symbolizer; pass --symbolizer or set "
        "LLVM_SYMBOLIZER_PATH.", file=sys.stderr)
  sys.exit(1)

if not os.path.exists(symbolizer_path):
  print("Symbolizer path does not exist:", symbolizer_path, file=sys.stderr)
  sys.exit(1)

# -v: show the effective configuration before any report line is processed.
if args.v:
  print("Looking for symbols in:")
  for s in binary_prefixes:
    print("  %s" % (s,))
  print("Stripping source path prefixes:")
  for s in paths_to_cut:
    print("  %s" % (s,))
  print("Using llvm-symbolizer binary in:\n  %s" % (symbolizer_path,))
  print()

# The module-level name `symbolizer` is also what symbolize_line uses for its
# lookups.
symbolizer = Symbolizer(symbolizer_path, binary_prefixes, paths_to_cut)
symbolizer.enable_logging(args.d)
if args.index:
  symbolizer.build_index()

# Main loop: every stdin line first updates the remembered access address and
# tag, is then offered to the stack history handler, and is otherwise printed
# through symbolize_line.
for line in sys.stdin:
  if sys.version_info.major < 3:
    # Python 2 yields byte strings here.
    line = line.decode('utf-8')
  save_access_address(line)
  if process_stack_history(line, symbolizer, ignore_tags=args.ignore_tags):
    continue
  symbolize_line(line, symbolizer_path)
