#!/usr/bin/env python
from __future__ import print_function

# Based on the `exitsnoop` script at https://github.com/iovisor/bcc/blob/master/tools/exitsnoop.py
# Based on the `execsnoop` script at https://github.com/iovisor/bcc/blob/master/tools/execsnoop.py

import argparse
import os
import platform
import re
import signal
import sys

from bcc import BPF
from collections import defaultdict
from datetime import datetime
from time import strftime

# eBPF program (C source, handed to BCC's BPF(text=...) in main()).
#
# It emits two kinds of `struct data_t` records on the "events" perf buffer:
#   * isArgv == 1: one record per execve() string (filename first, then the
#     following argv entries), produced by syscall__execve();
#   * isArgv == 0: one record per exiting process, produced by the
#     sched:sched_process_exit tracepoint handler.
# The Python callback tells the two apart through the isArgv field only.
bpf_src = """
#include <uapi/linux/ptrace.h>
#include <linux/sched.h>

#define ARGSIZE 128
#define MAXARG  19

// Single record layout shared by argv records and exit records.
struct data_t {
    u8 isArgv;                  // 1: argv holds one execve string; 0 otherwise
    u64 start_time;             // exit records only: task->start_time
    u64 exit_time;              // exit records only: bpf_ktime_get_ns() at exit
    u32 pid;
    u32 ppid;                   // exit records only
    char comm[TASK_COMM_LEN];
    char argv[ARGSIZE];         // argv records only
};

BPF_PERF_OUTPUT(events);

// Emit one exit record per process.
TRACEPOINT_PROBE(sched, sched_process_exit)
{
    struct task_struct *task = (typeof(task))bpf_get_current_task();
    // Skip tasks whose pid differs from their tgid (non-leader threads).
    if (task->pid != task->tgid) { return 0; }

    struct data_t data = {};

    data.start_time = task->start_time;
    data.exit_time = bpf_ktime_get_ns();
    data.pid = task->pid;
    data.ppid = task->real_parent->tgid;
    bpf_get_current_comm(&data.comm, sizeof(data.comm));

    events.perf_submit(args, &data, sizeof(data));
    return 0;
}

// Copy ARGSIZE bytes from user address ptr into data->argv and submit the
// record. Always returns 1.
static int __submit_arg(struct pt_regs *ctx, void *ptr, struct data_t *data)
{
    bpf_probe_read_user(data->argv, sizeof(data->argv), ptr);
    events.perf_submit(ctx, data, sizeof(struct data_t));
    return 1;
}

// ptr is the address of one argv[] slot: read the string pointer stored
// there and submit the string. Returns 0 when the slot holds NULL (end of
// the argument list), 1 otherwise.
static int submit_arg(struct pt_regs *ctx, void *ptr, struct data_t *data)
{
    const char *argp = NULL;
    bpf_probe_read_user(&argp, sizeof(argp), ptr);
    if (argp) {
        return __submit_arg(ctx, (void *)(argp), data);
    }
    return 0;
}

// execve() entry: emit the filename, then argv[1] .. argv[MAXARG - 1], as
// individual argv records tagged with the calling process id.
int syscall__execve(struct pt_regs *ctx,
    const char __user *filename,
    const char __user *const __user *__argv,
    const char __user *const __user *__envp)
{
    struct data_t data = {};
    // The upper 32 bits of the helper's return value hold the tgid.
    data.pid = bpf_get_current_pid_tgid() >> 32;
    data.isArgv = 1;

    bpf_get_current_comm(&data.comm, sizeof(data.comm));
    __submit_arg(ctx, (void *)filename, &data);

    // skip first arg, as we submitted filename
    #pragma unroll
    for (int i = 1; i < MAXARG; i++) {
        if (submit_arg(ctx, (void *)&__argv[i], &data) == 0)
             goto out;
    }

    // handle truncated argument list
    char ellipsis[] = "...";
    __submit_arg(ctx, (void *)ellipsis, &data);
out:

    return 0;
}

// NOTE: main() does not attach this function. Its records leave isArgv at 0,
// exactly like exit records, so the Python callback could not tell them apart.
int do_ret_sys_execve(struct pt_regs *ctx)
{
    struct data_t data = {};
    data.pid = bpf_get_current_pid_tgid() >> 32;
    bpf_get_current_comm(&data.comm, sizeof(data.comm));
    events.perf_submit(ctx, &data, sizeof(data));
    return 0;
}
"""

def _print_header():
    """Print the column headings of the exit-event table to stdout."""
    column_titles = ("PCOMM", "PID", "PPID", "AGE(ms)", "ARGV")
    print("%-16s %-7s %-7s %-7s %s" % column_titles)

# BPF object, assigned by main(); the perf callback uses it to decode events.
buffer = None
# pid -> list of raw (bytes) execve strings received so far for that pid.
pidToArgv = defaultdict(list)

# Only exits of these commands are printed.
_TRACKED_COMMS = frozenset(
    ["podman", "crun", "runc", "conmon", "netavark", "aardvark-dns"])

def _print_event(cpu, data, size): # callback
    """Handle one record from the "events" perf buffer.

    Records with ``isArgv`` set carry a single execve string; it is appended
    to the list kept for that pid and nothing is printed.  Every other record
    is treated as a process exit: the strings collected for the pid are
    discarded from the table and, if the command is one of the tracked ones,
    a line with command, pid, ppid, age in milliseconds and the joined
    argument strings is printed.

    Args:
        cpu: passed by the perf buffer machinery; unused.
        data: raw event, decoded with ``buffer["events"].event()``.
        size: passed by the perf buffer machinery; unused.
    """
    e = buffer["events"].event(data)

    if e.isArgv:
        pidToArgv[e.pid].append(e.argv)
        return

    # The process is gone, so its collected arguments are dropped whether or
    # not a line is printed.  pop() with a default neither raises for a pid
    # that never exec'ed nor inserts an empty entry the way indexing a
    # defaultdict would.
    collected = pidToArgv.pop(e.pid, ())

    # comm and argv are arbitrary bytes chosen by the traced program; decode
    # leniently so one invalid byte cannot abort the callback and lose the
    # line.
    comm = e.comm.decode("utf-8", "replace")
    if comm == "3":
        # For absolutely unknown reasons, 'crun' appears as '3'.
        comm = "crun"

    if comm not in _TRACKED_COMMS:
        return

    # The kernel side reports both timestamps; 1e6 converts their difference
    # to the milliseconds announced by the "AGE(ms)" column.
    age = (e.exit_time - e.start_time) / 1e6
    # Keep every event on one output line even if an argument has newlines.
    argv = b' '.join(collected).replace(b'\n', b'\\n')
    print("%-16s %-7d %-7d %-7.2f %s" %
              (comm, e.pid, e.ppid, age, argv.decode("utf-8", "replace")))

def snoop(bpf, event_handler):
    """Register *event_handler* on the "events" perf buffer and poll forever.

    The loop has no exit condition of its own; it only ends when an
    exception propagates out of ``bpf.perf_buffer_poll()``.
    """
    perf_events = bpf["events"]
    perf_events.open_perf_buffer(event_handler)
    while True:
        bpf.perf_buffer_poll()

def main():
    """Compile the eBPF program, hook execve and print events until Ctrl-C.

    The BPF object is stored in the module-level ``buffer`` so that the perf
    callback can decode events.  A KeyboardInterrupt raised anywhere during
    setup or polling prints a newline and terminates through ``sys.exit()``.

    Returns:
        0, should the polling loop ever return.
    """
    global buffer
    try:
        buffer = BPF(text=bpf_src)
        # Resolve the kernel symbol behind the execve syscall and probe it.
        buffer.attach_kprobe(
            event=buffer.get_syscall_fnname("execve"),
            fn_name="syscall__execve")
        _print_header()
        snoop(buffer, _print_event)
    except KeyboardInterrupt:
        # Finish the line holding the terminal's "^C" before leaving.
        print()
        sys.exit()

    return 0

# Script entry point: start tracing only when this file is executed directly,
# not when it is imported. The value returned by main() is not used.
if __name__ == '__main__':
    main()