// Copyright (c) 2018 The Foundry Visionmongers Ltd. All Rights Reserved.

#include "ProfilingMockRenderer/ProfilingMockRenderPlugin.h"

#include <atomic>
#include <chrono>
#include <cstdio>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <tbb/parallel_do.h>
#include <tbb/task_arena.h>

#include "FnAttribute/FnAttribute.h"

namespace
{
/**
 * \brief Snapshot of the render plug-in options, read from the
 * \c profilingMockRendererGlobalStatements.options group attribute on /root.
 * See \em ProfilingMockRendererGlobalSettings.xml for documentation.
 *
 * This is an aggregate on purpose: it is brace-initialised field by field, so
 * the member order below is significant.
 */
struct Config
{
    /// Traversal strategy name: "dfs", "bfs" or "parallel".
    std::string traversalStrategy;

    /// Locations whose depth is strictly smaller than this value are logged.
    int maxLogDepth;

    /// Thread limit for the parallel strategy, possibly
    /// \c tbb::task_arena::automatic.
    int maxThreads;
};

/**
 * \brief A unit of traversal work used by the breadth-first and parallel
 * search strategies.
 *
 * A work item does NOT hold the iterator of the location it describes.
 * Instead, it stores the iterator of the location's parent together with the
 * location's name. The location itself is resolved lazily, with
 * <tt>it.getChildByName(name, true)</tt>, when the item is processed. This
 * keeps queued items cheap until they are actually visited. The root item is
 * created with the name "." and the root iterator.
 */
struct WorkItem
{
    /// Depth of the location; the root item has depth 0.
    int depth;
    /// Name of the location, looked up relative to \c it.
    std::string name;
    /// Iterator of the parent location (the root iterator for the root item).
    FnKat::FnScenegraphIterator it;
};


/**
 * \brief Implements the depth-first search strategy.
 *
 * Locations are visited recursively in pre-order. The potential children of
 * each location are visited in the order in which getPotentialChildren()
 * reports them.
 */
class DfsTreeWalker
{
public:
    explicit DfsTreeWalker(const Config& cfg)
        : m_cfg(cfg), m_start(std::chrono::steady_clock::now()), m_count(0)
    {
    }

    /// Number of locations visited so far by traverse().
    int getNumLocationsProcessed() const { return m_count; }

    /**
     * \brief Visits \a iterator and, recursively, every potential child below
     * it.
     * \param iterator The location to visit.
     * \param depth Depth of \a iterator relative to the traversal root.
     */
    void traverse(const FnKat::FnScenegraphIterator& iterator, int depth = 0)
    {
        const int visited = ++m_count;

        // Log only the shallow part of the hierarchy.
        if (depth < m_cfg.maxLogDepth)
        {
            logLocation(iterator, visited);
        }

        const FnKat::StringAttribute childrenAttr =
            iterator.getPotentialChildren();
        for (const auto& childName : childrenAttr.getNearestSample(0.0f))
        {
            traverse(iterator.getChildByName(childName), depth + 1);
        }
    }

private:
    /// Prints elapsed time, average time per location, the running count and
    /// the full name of \a iterator to stderr.
    void logLocation(const FnKat::FnScenegraphIterator& iterator,
                     int visited) const
    {
        const auto elapsed = std::chrono::steady_clock::now() - m_start;
        const double seconds = std::chrono::duration<double>(elapsed).count();
        const double millis =
            std::chrono::duration<double, std::milli>(elapsed).count();

        fprintf(stderr,
                ">> Depth-first search: %.2fs %.2f ms/loc %10d %s\n",
                seconds, millis / visited, visited,
                iterator.getFullName().c_str());
    }

    const Config& m_cfg;
    const std::chrono::steady_clock::time_point m_start;
    int m_count;
};

/**
 * \brief Implements the breadth-first search strategy.
 *
 * Locations are visited level by level using a FIFO queue of WorkItem
 * entries. Each entry stores the parent iterator and the location name; the
 * location's own iterator is resolved only when the entry is dequeued.
 */
class BfsTreeWalker
{
public:
    explicit BfsTreeWalker(const Config& cfg)
        : m_cfg(cfg), m_count(0)
    {
    }

    /**
     * \brief Visits \a root and every potential child below it in
     * breadth-first order.
     * \param root Iterator the traversal starts from. The root is enqueued
     *     with the name "." and looked up through getChildByName() like every
     *     other entry.
     */
    void traverse(const FnKat::FnScenegraphIterator& root)
    {
        const auto start(std::chrono::steady_clock::now());

        // Populate the queue with the root element.
        std::queue<WorkItem> q;
        q.push({0, ".", root});

        while (!q.empty())
        {
            // Move the next item out of the queue instead of copying it (its
            // name string and iterator handle); the slot is popped right away.
            const WorkItem elem = std::move(q.front());
            q.pop();
            ++m_count;

            // Resolve the iterator for this location from its parent.
            const FnKat::FnScenegraphIterator it =
                elem.it.getChildByName(elem.name, true);

            // Only print a log message if the maximum depth has not been
            // exceeded.
            const int depth = elem.depth;
            if (depth < m_cfg.maxLogDepth)
            {
                const auto duration = std::chrono::steady_clock::now() - start;

                fprintf(stderr,
                    ">> Breadth-first search: %.2fs %.2f ms/loc %10d %s\n",
                    std::chrono::duration<double>(duration).count(),
                    std::chrono::duration<double, std::milli>(duration)
                            .count() / m_count,
                    m_count,
                    it.getFullName().c_str());
            }

            // Enqueue all its potential children.
            const FnKat::StringAttribute pcAttr = it.getPotentialChildren();
            for (const auto& child : pcAttr.getNearestSample(.0f))
            {
                q.push({depth + 1, child, it});
            }
        }
    }

    /// Number of locations visited so far by traverse().
    int getNumLocationsProcessed() const { return m_count; }

private:
    const Config& m_cfg;
    int m_count;
};

/**
 * \brief Implements the parallel search strategy.
 *
 * The traversal runs inside a dedicated tbb::task_arena limited to
 * Config::maxThreads threads. tbb::parallel_do is seeded with the root item,
 * and every visited location feeds its potential children back through the
 * feeder. This lets any thread in the arena process them.
 */
class ParallelTreeWalker
{
public:
    explicit ParallelTreeWalker(const Config& cfg)
        : m_cfg(cfg),
          m_start(std::chrono::steady_clock::now()),
          // Thread limit as shown in the log prefix.
          m_threads(m_cfg.maxThreads == tbb::task_arena::automatic
                        ? "auto"
                        : std::to_string(m_cfg.maxThreads)),
          // Zero-pad thread indices to the width of the thread limit; no
          // padding when the limit is automatic.
          m_padding(static_cast<int>(
              m_cfg.maxThreads == tbb::task_arena::automatic
                  ? 0
                  : m_threads.size())),
          m_arena(m_cfg.maxThreads),
          m_count(0)
    {
    }

    /// Number of locations visited so far by traverse().
    int getNumLocationsProcessed() const
    {
        return m_count.load(std::memory_order_relaxed);
    }

    /**
     * \brief Visits \a root and every potential child below it, in parallel,
     * inside this walker's task arena. Returns once all locations have been
     * processed.
     * \param root Iterator the traversal starts from (enqueued with name ".").
     */
    void traverse(const FnKat::FnScenegraphIterator& root)
    {
        m_arena.execute([&]()
        {
            const WorkItem rootElem{0, ".", root};
            tbb::parallel_do(&rootElem, &rootElem + 1,
                [&](const WorkItem& elem,
                       tbb::parallel_do_feeder<WorkItem>& feeder)
                {
                    explore_and_feed(elem, feeder);
                });
        });
    }

private:
    /**
     * \brief Visits one location and feeds its potential children back into
     * the parallel_do loop.
     * \param elem Parent iterator, name and depth of the location to visit.
     * \param feeder Receives one WorkItem per potential child.
     */
    void explore_and_feed(const WorkItem& elem,
        tbb::parallel_do_feeder<WorkItem>& feeder)
    {
        // fetch_add() returns the previous value, so adding one gives this
        // location a unique ordinal even when several threads increment
        // concurrently. A separate load() could observe other threads'
        // increments and print duplicate or skipped counts.
        const int count = m_count.fetch_add(1, std::memory_order_relaxed) + 1;

        const FnKat::FnScenegraphIterator it =
            elem.it.getChildByName(elem.name, true);

        // Only print a log message if the maximum depth has not been exceeded.
        const int depth = elem.depth;
        if (depth < m_cfg.maxLogDepth)
        {
            const auto duration = std::chrono::steady_clock::now() - m_start;

            // Note that tbb::this_task_arena::current_thread_index() and
            // tbb::task_arena::max_concurrency() are not available in
            // TBB 4.4 U6, but they could be used in newer versions of TBB.
            const int threadIndex = m_arena.current_thread_index();

            fprintf(stderr,
                ">> Parallel search (%s threads, #%0*d): %.2fs %.2f "
                "ms/loc %10d %s\n",
                m_threads.c_str(),
                m_padding,
                threadIndex,
                std::chrono::duration<double>(duration).count(),
                std::chrono::duration<double, std::milli>(duration)
                        .count() / count,
                count,
                it.getFullName().c_str());
        }

        const FnKat::StringAttribute pcAttr = it.getPotentialChildren();
        for (const auto& child : pcAttr.getNearestSample(.0f))
        {
            feeder.add({depth + 1, child, it});
        }
    }

private:
    const Config& m_cfg;
    const std::chrono::steady_clock::time_point m_start;
    const std::string m_threads;
    const int m_padding;
    tbb::task_arena m_arena;
    std::atomic<int> m_count;
};
}  // namespace


namespace ProfilingMockRenderer
{
/**
 * \brief Constructs the plug-in by forwarding the scene graph root iterator
 * and the render arguments, unchanged, to RenderBase.
 */
RenderPlugin::RenderPlugin(FnKat::FnScenegraphIterator rootIterator,
                           FnKat::GroupAttribute arguments)
    : RenderBase(rootIterator, arguments)
{
    // RenderBase holds everything this plug-in needs; nothing else to set up.
}

/// The plug-in owns no resources of its own beyond what RenderBase manages.
RenderPlugin::~RenderPlugin() = default;

/**
 * \brief Factory for the plug-in.
 * \return A heap-allocated RenderPlugin; this function keeps no reference to
 *     it, so ownership passes to the caller.
 */
FnKat::Render::RenderBase* RenderPlugin::create(
    FnKat::FnScenegraphIterator rootIterator,
    FnAttribute::GroupAttribute arguments)
{
    RenderPlugin* const plugin = new RenderPlugin(rootIterator, arguments);
    return plugin;
}

/**
 * \brief Reads the profiling options from /root, traverses the scene graph
 * with the configured strategy and prints a timing summary to stderr.
 *
 * Options are read from the \c profilingMockRendererGlobalStatements.options
 * group attribute on /root; missing values fall back to the defaults below.
 *
 * \return 0 once the traversal has completed.
 * \throws std::runtime_error If the traversal strategy is not one of "dfs",
 *     "bfs" or "parallel".
 */
int RenderPlugin::start()
{
    // These default values match the default values specified in the XML.
    static constexpr char kTraversalStrategyDefault[] = "dfs";
    constexpr int kMaxLogDepthDefault = 5;

    // Get the configured options from /root.
    const FnKat::FnScenegraphIterator rootIt = getRootIterator();
    const FnKat::GroupAttribute optionsAttr =
        rootIt.getAttribute("profilingMockRendererGlobalStatements.options");

    // Populate the config structure based on the given options.
    const Config cfg{
        // traversalStrategy
        FnKat::StringAttribute(optionsAttr.getChildByName("traversalStrategy"))
            .getValue(kTraversalStrategyDefault, false),

        // maxLogDepth
        FnKat::IntAttribute(optionsAttr.getChildByName("maxLogDepth"))
            .getValue(kMaxLogDepthDefault, false),

        // maxThreads:
        //  - values above the hardware concurrency are clamped to it;
        //  - 0 (also the fallback when unset) or negative values mean "use the
        //    hardware concurrency";
        //  - if std::thread::hardware_concurrency() cannot tell (returns 0),
        //    non-positive values fall back to tbb::task_arena::automatic.
        [&]() {
            const FnKat::IntAttribute maxThreadsAttr =
                optionsAttr.getChildByName("maxThreads");
            int value = maxThreadsAttr.getValue(0, false);

            const int hardwareConcurrency =
                static_cast<int>(std::thread::hardware_concurrency());

            // Clamp if greater than the maximum number of threads that can be
            // used.
            if (hardwareConcurrency > 0 && hardwareConcurrency < value)
            {
                value = hardwareConcurrency;
            }
            // If not valid, try to autodetermine it, or set to automatic.
            else if (value <= 0)
            {
                if (hardwareConcurrency > 0)
                {
                    value = hardwareConcurrency;
                }
                else
                {
                    value = tbb::task_arena::automatic;
                }
            }

            return value;
        }(),
    };

    const auto begin = std::chrono::steady_clock::now();
    int numLocations = 0;

    // Traverse the scene following the strategy as defined.
    if (cfg.traversalStrategy == "dfs")
    {
        DfsTreeWalker treeWalker(cfg);
        treeWalker.traverse(rootIt);
        numLocations = treeWalker.getNumLocationsProcessed();
    }
    else if (cfg.traversalStrategy == "bfs")
    {
        BfsTreeWalker treeWalker(cfg);
        treeWalker.traverse(rootIt);
        numLocations = treeWalker.getNumLocationsProcessed();
    }
    else if (cfg.traversalStrategy == "parallel")
    {
        ParallelTreeWalker treeWalker(cfg);
        treeWalker.traverse(rootIt);
        numLocations = treeWalker.getNumLocationsProcessed();
    }
    else
    {
        // Name the offending value so misconfigured scenes are easy to fix.
        throw std::runtime_error("Invalid traversal strategy: \"" +
                                 cfg.traversalStrategy + "\"");
    }

    const auto elapsed = std::chrono::steady_clock::now() - begin;
    const double dElapsed = std::chrono::duration<double>(elapsed).count();
    // A coarse clock can report a zero duration for tiny scenes; avoid
    // printing "inf"/"nan" in that case.
    const double locsPerSecond =
        dElapsed > 0.0 ? numLocations / dElapsed : 0.0;

    fprintf(stderr,
            "ProfilingMockRenderer: Processed %d locations in %.2fs. Avg. %.2f"
            " locs/second.\n", numLocations, dElapsed, locsPerSecond);

    // Let the Runtime deallocate resources and cease further computation.
    // Resource deallocation can take time and so elements of it are
    // completed asynchronously.
    //
    // Typically renderer plug-ins want to invoke this function if they will
    // need lots of memory available for rendering purposes. If the renderer
    // determines that not so much memory is needed for the current scene, there
    // is no need to finalize the Runtime.
    rootIt.finalizeRuntime();

    return 0;
}

/**
 * \brief Stop request handler.
 *
 * start() performs the whole traversal before returning, so there is no
 * in-flight work to interrupt here.
 * \return Always 0.
 */
int RenderPlugin::stop()
{
    return 0;
}

/**
 * \brief Configures the disk render output process.
 *
 * Installs a RenderAction that targets \a outputPath, with loading of the
 * output in the Monitor disabled. \a outputName, \a renderMethodName and
 * \a frameTime are not used by this plug-in.
 */
void RenderPlugin::configureDiskRenderOutputProcess(
    FnKat::Render::DiskRenderOutputProcess& diskRenderOutputProcess,
    const std::string& outputName,
    const std::string& outputPath,
    const std::string& renderMethodName,
    const float& frameTime) const
{
    std::unique_ptr<FnKat::Render::RenderAction> action{
        new FnKat::Render::RenderAction(outputPath)};
    action->setLoadOutputInMonitor(false);

    // Transfer ownership of the configured action to the output process.
    diskRenderOutputProcess.setRenderAction(std::move(action));
}
}  // namespace ProfilingMockRenderer
