#!/usr/bin/env python
# coding: utf-8
"""
Analyzes the JSON file produced by a Preview Render with Profiling to produce
sorted data sets.
"""

import copy
from functools import reduce
import json
import optparse
import string
import sys


class ColumnData(object):
    """Static lookup tables describing the report columns.

    Each column is identified by a small integer constant. The tables below
    map that constant to the attribute/JSON key holding the value, the
    header text, the maximum printed width (0 means "no limit") and whether
    the column sorts in descending order by default.
    """

    def __init__(self):
        pass

    # Column identifiers; they double as list indices in the report code.
    (OP_ID, OP_NAME, OP_TYPE, NODE_NAME, NODE_TYPE,
     CPU_TIME, MEMORY_USED) = range(7)

    # Key of each column in the JSON document (and attribute name on a
    # parsed result). Insertion order defines the default column order.
    JSON_KEYS = dict([
        (OP_ID, "opId"),
        (OP_NAME, "opName"),
        (OP_TYPE, "opType"),
        (NODE_NAME, "nodeName"),
        (NODE_TYPE, "nodeType"),
        (CPU_TIME, "cpuTime"),
        (MEMORY_USED, "memoryUsed"),
    ])

    # Header text printed for each column.
    STRING_NAMES = dict([
        (OP_ID, "Op ID"),
        (OP_NAME, "Op Name"),
        (OP_TYPE, "Op Type"),
        (NODE_NAME, "Node Name"),
        (NODE_TYPE, "Node Type"),
        (CPU_TIME, "CPU Time"),
        (MEMORY_USED, "Memory Used"),
    ])

    # Maximum number of characters printed per cell; 0 disables truncation.
    MAX_WIDTHS = dict.fromkeys(
        (OP_ID, OP_NAME, OP_TYPE, NODE_NAME, NODE_TYPE,
         CPU_TIME, MEMORY_USED),
        0)

    # Columns whose natural sort order is descending (largest first).
    REVERSED_BY_DEFAULT = dict([
        (OP_ID, False),
        (OP_NAME, False),
        (OP_TYPE, False),
        (NODE_NAME, False),
        (NODE_TYPE, False),
        (CPU_TIME, True),
        (MEMORY_USED, True),
    ])


class ProfilingResult(object):
    """An individual profiling result.

    Attributes mirror the keys of one entry of the report's "ops" list:
    opId (int), opName, opType, nodeName, nodeType, cpuTime (float) and
    memoryUsed (int).
    """
    def __init__(self, jsonObj=None):
        """Populate the result from one JSON entry.

        Args:
            jsonObj: Mapping holding the keys "opId", "opName", "opType",
                "nodeName", "nodeType", "cpuTime" and "memoryUsed". When
                omitted (None), an empty result is created with zeroed
                numbers and empty strings.

        Raises:
            KeyError: If a mapping is given but one of the keys is missing.
            ValueError: If a numeric field cannot be converted.
        """
        if jsonObj is None:
            # The signature advertises a default, so make it usable instead
            # of failing with "'NoneType' object is not subscriptable".
            jsonObj = {
                "opId": 0,
                "opName": "",
                "opType": "",
                "nodeName": "",
                "nodeType": "",
                "cpuTime": 0.0,
                "memoryUsed": 0,
            }

        # Strict indexing is kept on purpose: a malformed entry in a real
        # document should fail loudly rather than be silently zeroed.
        self.opId = int(jsonObj["opId"])
        self.opName = jsonObj["opName"]
        self.opType = jsonObj["opType"]
        self.nodeName = jsonObj["nodeName"]
        self.nodeType = jsonObj["nodeType"]
        self.cpuTime = float(jsonObj["cpuTime"])
        self.memoryUsed = int(jsonObj["memoryUsed"])


class ProfilingResults(object):
    """A collection of (maybe grouped and sorted) profiling results.

    Attributes:
        jsonDoc: The "ops" list of the loaded report document.
        results: The current list of result objects, possibly grouped
            and/or sorted.
        sortByKey: Attribute name the results were last sorted by, or None.
        reverse: Whether the last sort was descending.
        groupByKey: Attribute name the results are grouped by, or None.
    """
    def __init__(self, jsonDoc):
        """Parse the "ops" entries of the given report document.

        Raises:
            KeyError: If the document has no "ops" key.
        """
        self.jsonDoc = jsonDoc["ops"]
        self.results = self.parseJsonDoc()
        self.sortByKey = None
        self.reverse = False
        self.groupByKey = None

    def parseJsonDoc(self):
        """Build a fresh list of result objects from the stored entries."""
        return [ProfilingResult(e) for e in self.jsonDoc]

    def sortBy(self, key, reverse):
        """Sort the results in place by the given attribute name.

        String values are compared case-insensitively. The key and the
        direction are remembered so a later groupBy() can restore the order.
        """
        self.sortByKey = key
        self.reverse = reverse

        def sortFunc(e):
            """Return the sort value of one result."""
            value = getattr(e, self.sortByKey)
            if isinstance(value, str):
                value = value.lower()
            return value

        self.results.sort(key=sortFunc, reverse=reverse)

    def groupBy(self, key):
        """Group the results by the given attribute name.

        The results are always rebuilt from the original document first, so
        successive calls do not compound. Rows sharing the same value for
        `key` are merged: cpuTime and memoryUsed are summed, and any
        descriptive text column whose values disagree within the group is
        replaced by "(multiple values)". Passing None removes the grouping.
        In every case the previously requested sort order is re-applied.
        """
        self.results = self.parseJsonDoc()
        self.groupByKey = key

        if key is not None:
            placeholder = "(multiple values)"
            # opId is deliberately excluded: it must remain numeric so that
            # sorting by it keeps working, so a merged row keeps the id of
            # the first member of its group.
            textColumns = ("opName", "opType", "nodeName", "nodeType")

            # Accumulate results with the same key.
            accumulated = {}
            for e in self.results:
                value = getattr(e, key)
                merged = accumulated.get(value)
                if merged is None:
                    accumulated[value] = copy.deepcopy(e)
                    continue

                for attr in textColumns:
                    if getattr(merged, attr) != getattr(e, attr):
                        setattr(merged, attr, placeholder)
                merged.cpuTime += e.cpuTime
                merged.memoryUsed += e.memoryUsed

            self.results = list(accumulated.values())

        # Re-parsing discarded any previous ordering; restore it so that
        # sortByKey/reverse keep describing the actual state of results,
        # including when the grouping is being removed.
        if self.sortByKey:
            self.sortBy(self.sortByKey, self.reverse)


def PrettyPrintDuration(secs, _):
    """Return `secs` (a duration in seconds) formatted in milliseconds.

    The second argument is ignored; it exists so that this function can be
    called with the same arguments as PrettyPrintMemory.
    """
    milliseconds = secs * 1000.0
    return "%10.5f ms" % (milliseconds,)

def PrettyPrintMemory(totalBytes, humanReadable):
    """Format a byte count for display.

    When `humanReadable` is false the raw number of bytes is printed.
    Otherwise the value is scaled by powers of 1024 to the largest unit
    (up to PiB) in which it is still at least 1, printed with two decimals.
    """
    if not humanReadable:
        return "%10d B" % totalBytes

    units = ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]
    amount = totalBytes
    for unit in units[:-1]:
        # Stop at the first unit in which the amount drops below 1024.
        if not (amount >= 1024):
            break
        amount /= 1024.0
    else:
        # Every smaller unit overflowed; fall back to the largest one.
        unit = units[-1]

    return "%10.2f %s" % (amount, unit)


class ProfilingReport(object):
    """Renders profiling results as a fixed-width text table."""

    def __init__(self, results, columns, humanReadable, limit):
        """Store the report inputs.

        Args:
            results: Object exposing `results` (the rows), `sortByKey`,
                `reverse` and `groupByKey`.
            columns: JSON key names of the columns to print, in order.
            humanReadable: Flag forwarded to the per-column formatters.
            limit: Maximum number of rows to print; 0 prints all rows.
        """
        self.results = results
        self.columns = columns
        self.humanReadable = humanReadable
        self.limit = limit

    def __repr__(self):
        """Return the table as a multi-line string.

        Raises:
            Exception: If `columns` contains an unknown column name.
        """
        numColumns = len(ColumnData.JSON_KEYS)
        cells = [None] * numColumns
        widths = [0] * len(ColumnData.STRING_NAMES)

        # Columns without an entry here are printed with str().
        formatters = [None] * numColumns
        formatters[ColumnData.CPU_TIME] = PrettyPrintDuration
        formatters[ColumnData.MEMORY_USED] = PrettyPrintMemory

        rowCap = len(self.results.results)
        if self.limit != 0:
            rowCap = min(rowCap, self.limit)

        # Translate the requested JSON key names into column indices.
        indexByKey = dict(
            (jsonKey, idx) for idx, jsonKey in ColumnData.JSON_KEYS.items())
        shown = []
        for requested in self.columns:
            if requested not in indexByKey:
                raise Exception("Unknown column name '%s'." % requested)
            shown.append(indexByKey[requested])

        # Format every visible cell and track the widest text per column.
        for rowIdx, row in enumerate(self.results.results):
            if rowIdx >= rowCap:
                break

            for col in shown:
                raw = getattr(row, ColumnData.JSON_KEYS[col])
                formatter = formatters[col]
                if formatter:
                    text = formatter(raw, self.humanReadable)
                else:
                    text = str(raw)

                cap = ColumnData.MAX_WIDTHS[col]
                if 0 < cap < len(text):
                    text = text[:cap - 3] + "..."

                if cells[col] is None:
                    cells[col] = [None] * rowCap
                cells[col][rowIdx] = text
                widths[col] = max(widths[col], len(text))

        lastShown = shown[-1] if shown else None

        def separatorAfter(col):
            """Return the text printed after a cell of the given column."""
            return "" if col == lastShown else " | "

        # Header cells carry a sort-direction arrow and a grouping marker.
        headerCells = []
        for col in shown:
            jsonKey = ColumnData.JSON_KEYS[col]
            title = ColumnData.STRING_NAMES[col]
            if jsonKey == self.results.sortByKey:
                title += " /\\" if self.results.reverse else " \\/"
            if jsonKey == self.results.groupByKey:
                title += " (g)"
            widths[col] = max(widths[col], len(title))
            headerCells.append(
                "%*s%s" % (widths[col], title, separatorAfter(col)))
        header = "".join(headerCells)

        # The divider must span all columns plus spacing between.
        divider = "-" * (3 * (len(self.columns) - 1) + sum(widths))

        # Number of rows actually extracted (0 when nothing was stored).
        rowCount = next((len(column) for column in cells if column), 0)

        lines = ["", divider, header, divider]
        for rowIdx in range(rowCount):
            lines.append("".join(
                "%*s%s" % (widths[col], cells[col][rowIdx],
                           separatorAfter(col))
                for col in shown))
        lines.append(divider)

        return "\n".join(lines) + "\n"


def buildParser(validColumns):
    """Construct the command-line option parser and return it.

    Args:
        validColumns: Text listing the accepted column names; it is
            embedded in the help of the column-related options.

    Returns:
        An optparse.OptionParser. The automatic help option is disabled
        because "-h" is used for --human-readable; "--help" is registered
        as a plain flag that the caller has to act upon.
    """
    # __doc__ is None when docstrings are stripped (python -OO); fall back
    # to an empty description instead of crashing on None.strip().
    description = (__doc__ or "").strip()

    parser = optparse.OptionParser(
        usage="%prog JSON_DOCUMENT [options]",
        description=description,
        add_help_option=False)

    parser.add_option(
        "--help",
        action="store_true",
        help="Show this help message and exit")
    parser.add_option(
        "--sort-by",
        help="Column to sort by, from [%s]" % validColumns,
        default=None)
    parser.add_option(
        "-r", "--reverse",
        action="store_true",
        help="Results are sorted in reverse order",
        default=False)
    parser.add_option(
        "--group-by",
        help="Column to group by, from [%s]" % validColumns,
        default=None)
    parser.add_option(
        "-h", "--human-readable",
        action="store_true",
        help="Print memory totals in human-readable units",
        default=False)
    # Typed as int so that a non-numeric value is rejected by the parser
    # with a usage message rather than failing later during conversion.
    parser.add_option(
        "-l", "--limit",
        type="int",
        help="Limit the maximum number of results printed",
        default=0)
    parser.add_option(
        "--columns",
        help="Comma-separated list of column names to show, from [%s]" % (
            validColumns),
        default=None)

    return parser


def main():
    """Parse the command line, load the JSON document and print the report.

    Raises:
        Exception: If the document path is missing or an invalid column
            name is given to --columns, --sort-by or --group-by.
    """
    columnNames = list(ColumnData.JSON_KEYS.values())
    # str.join works on every Python version; string.join() no longer
    # exists in Python 3.
    validColumns = "|".join(columnNames)
    parser = buildParser(validColumns)
    opts, args = parser.parse_args()
    if opts.help:
        parser.print_help()
        sys.exit(0)

    if len(args) != 1:
        raise Exception("Expected path to JSON document as first argument")

    columns = []
    if opts.columns:
        for c in opts.columns.split(","):
            if c not in columnNames:
                fmt = "Invalid value '%s' given to --columns; options are [%s]"
                raise Exception(fmt % (c, validColumns))
            columns.append(c)
    else:
        columns = columnNames

    if opts.sort_by is None:
        opts.sort_by = columns[0]
        print('No sort criterion set; defaulting to "%s"' % opts.sort_by)
    elif opts.sort_by not in columnNames:
        fmt = "Invalid value '%s' given for --sort-by; options are [%s]"
        raise Exception(fmt % (opts.sort_by, validColumns))

    if (opts.group_by is not None) and (opts.group_by not in columnNames):
        fmt = "Invalid value '%s' given for --group-by; options are [%s]"
        raise Exception(fmt % (opts.group_by, validColumns))

    # Determine whether the sort column is reversed by default; if so,
    # --reverse flips it back to ascending order.
    for k in ColumnData.JSON_KEYS:
        if opts.sort_by == ColumnData.JSON_KEYS[k]:
            if ColumnData.REVERSED_BY_DEFAULT[k]:
                opts.reverse = not opts.reverse
            break

    # Close the document as soon as it has been parsed.
    with open(args[0]) as jsonFile:
        jsonDoc = json.load(jsonFile)

    results = ProfilingResults(jsonDoc)
    results.groupBy(opts.group_by)
    results.sortBy(opts.sort_by, opts.reverse)

    report = ProfilingReport(results, columns, opts.human_readable,
                             int(opts.limit))
    print(report)

# Run the report only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
