# Copyright (c) 2018 The Foundry Visionmongers Ltd. All Rights Reserved.


import logging
import os

import LookFileBakeAPI
from Katana import FnAttribute

log = logging.getLogger("UsdLookFileBakeOutputFormat")

# [USD install]/lib/python needs to be on $PYTHONPATH for this import to work.
# If "pxr" cannot be imported, the "fnpxr" package is tried instead
# (presumably a Katana-bundled build of the USD Python bindings).
# pxrImported records whether either import succeeded; it gates the
# registration of the output format further down.
pxrImported = False
try:
    from pxr import (
        Usd,
        UsdShade,
        Sdf,
    )
    pxrImported = True
except Exception:
    # Any failure while importing "pxr" falls back to "fnpxr". Catching
    # Exception (rather than using a bare except) still lets SystemExit and
    # KeyboardInterrupt propagate.
    try:
        from fnpxr import (
            Usd,
            UsdShade,
            Sdf,
        )
        pxrImported = True
    except ImportError as e:
        # Log the exception object itself: BaseException.message does not
        # exist in Python 3, so formatting "e.message" would raise an
        # AttributeError here and abort loading this plug-in.
        log.warning('Error while importing pxr module (%s). Is '
                    '"[USD install]/lib/python" in PYTHONPATH?', e)

# Base class provided by the Look File Bake API for all output formats.
BaseOutputFormat = LookFileBakeAPI.BaseLookFileBakeOutputFormat


class UsdLookFileBakeOutputFormat(BaseOutputFormat):
    """
    Class implementing a USD look file bake output format.

    Each look file pass is written to its own C{.usda} file, in which every
    baked material becomes a C{UsdShade.Material} prim. The material's Katana
    attributes are stored as custom data on that prim.
    """

    # Class Variables ---------------------------------------------------------

    DisplayName = "as USD"
    FileExtension = ""
    PassFileExtension = "usda"

    # Katana data attribute types whose values are written to USD as custom
    # data. Data attributes of any other type are silently skipped.
    _SupportedDataAttributeTypes = (
        FnAttribute.IntAttribute,
        FnAttribute.FloatAttribute,
        FnAttribute.DoubleAttribute,
        FnAttribute.StringAttribute,
    )

    # Protected Class Methods -------------------------------------------------

    @classmethod
    def _WriteUsdAttribute(cls, usdPrimitive, attributeName, attribute):
        """
        Writes the given attribute to the given USD Primitive as custom data.

        Group attributes are written recursively. The name of each child is
        appended to C{attributeName} with a C{:} separator; USD interprets
        C{:} in a custom data key as a path into nested dictionaries. Data
        attributes of a supported type (int, float, double, string) are
        written with C{Usd.Prim.SetCustomDataByKey()}. All other attributes
        are ignored.

        @type usdPrimitive: C{Usd.Prim}
        @type attributeName: C{str}
        @type attribute: C{FnAttribute}
        @param usdPrimitive: The USD Primitive for which to write attributes.
        @param attributeName: The custom data key path under which to write
            the attribute.
        @param attribute: The Katana attribute to write to the given USD
            primitive.
        """
        if isinstance(attribute, FnAttribute.DataAttribute):
            if isinstance(attribute, cls._SupportedDataAttributeTypes):
                # NOTE(review): getValue() presumably returns only the first
                # value of the attribute, so multi-value attributes (e.g.
                # colors) lose their remaining values here — confirm intended.
                usdPrimitive.SetCustomDataByKey(attributeName,
                                                attribute.getValue())

        elif isinstance(attribute, FnAttribute.GroupAttribute):
            # Recursively call this function to write the attributes for its
            # children, nesting their names below this attribute's name.
            for childAttributeName, childAttribute in attribute.childList():
                cls._WriteUsdAttribute(
                    usdPrimitive,
                    "{0}:{1}".format(attributeName, childAttributeName),
                    childAttribute)

    # Instance Methods --------------------------------------------------------

    def writeSinglePass(self, passData):
        """
        Writes a single look file pass to a new USD file.

        A material is skipped, with a warning, if its location path is not a
        valid C{Sdf.Path} string.

        @type passData: C{LookFileBakeAPI.LookFilePassData}
        @rtype: C{list} of C{str}
        @param passData: The data representing a single look file pass to be
            baked.
        @return: A list of paths to files which have been written.
        """
        # Get the file path for this pass from the given pass data
        filePath = passData.filePath

        # If the enclosing directory doesn't exist, then try to create it
        LookFileBakeAPI.CreateLookFileDirectory(os.path.dirname(filePath))

        # Create a new USD stage backed by a new layer at filePath
        stage = Usd.Stage.CreateNew(filePath)

        # Iterate over materials. Each value is a pair whose second element
        # is the material attribute; the first element is not used here.
        for materialLocationPath, (_, materialAttribute) in \
                passData.materialDict.items():
            # Discard materials if the path is not a valid SdfPath.
            if not Sdf.Path.IsValidPathString(materialLocationPath):
                log.warning('"%s" is not a valid SdfPath. Material will '
                            'be skipped.', materialLocationPath)
                continue

            # Create a material prim at the same path as the Katana location
            material = UsdShade.Material.Define(
                stage, Sdf.Path(materialLocationPath))

            # If the material has a "nodes" attribute, then it represents a
            # network material. A sensible extension to this plug-in would be
            # to create separate USD materials for each node in the material
            # network, and connect them together. Examples of this behaviour
            # are provided in the USD distro, e.g.
            #
            # [USD]/pxr/usd/lib/usdShade/testenv/testUsdShadeNodeGraphs.py
            #
            # A further extension could be to store all the pass info in
            # instance variables, and use the postProcess() method to store the
            # passes as USD shading variants. See:
            #
            # [USD]/extras/usd/tutorials/authoringVariants/authorVariants.py

            # Convert the material attribute from FnAttribute to USD custom
            # data, rooted at the "material" key.
            self._WriteUsdAttribute(material.GetPrim(), "material",
                                    materialAttribute)

        # Save the stage
        stage.GetRootLayer().Save()

        return [filePath]


# Only register the output format if the USD Python bindings (pxr or fnpxr)
# have been imported successfully.
if pxrImported:
    LookFileBakeAPI.RegisterOutputFormat(UsdLookFileBakeOutputFormat)
