# -*- coding: utf-8 -*-
#
#*******************************************************************************
#
#  Copyright 2022 RIEGL Laser Measurement Systems
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
#  SPDX-License-Identifier: Apache-2.0
#
#*******************************************************************************
#
"""
Point statistics query
"""

import weakref
from ctypes import cast, byref, c_void_p, c_uint32, c_uint64, Structure, POINTER
from warnings import warn

from . import attributebuffer
from . import library
from . import utilities


class GraphNode:
    """
    Graph Node

    This class represents an index graph node. The index structure is
    used to organize the point cloud and consists of at least one node.

    The index space is given by the primary point attribute defined
    during point cloud database creation (see class CreateSettings).

    Each graph node covers a certain range of the index space and
    has a number of sub-nodes (aka. "child nodes"). All child nodes
    lie within the range of the parent node whereas they are usually
    smaller. A node without child nodes is called a leaf node. Note
    that any nodes may overlap in index space as well as all other
    point dimensions (attributes).

    This documentation uses the term "branch" for a node and all
    children and grandchildren up to and including the leaf nodes.

    Attribute statistics (minimum, maximum, revision) are not stored in
    the node; they are requested on demand from the statistics query
    that created the node. The node only keeps weak references to that
    query and to its parent node, so both must be kept alive elsewhere
    while the node is used.
    """

    # Deprecated dynamic attribute prefixes (e.g. "minimum_riegl_xyz") and
    # the name of the query method that serves each of them.
    _DEPRECATED_PREFIXES = (
        ("revision_", "revision"),
        ("minimum_", "minimum"),
        ("maximum_", "maximum"),
    )

    class Revision(int):
        """
        Branch revision number that can also be called with an attribute name

        Behaves like the plain int revision of the branch; calling it as
        revision(name) returns the revision of a single point attribute.
        """
        def __call__(self, name):
            """
            Attribute revision

            Provides the ID of the last transaction that has modified the attribute
            in any node of the branch.
            """
            # the owning node is attached by GraphNode.__setattr__()
            node = self.__dict__["_node"]()
            return node._query().revision(name, node.id)

    def __init__(self, parent, query):
        self.parent = weakref.ref(parent) if parent is not None else None
        """node's parent graph node"""

        self.id = None
        """unique node identifier (zero is invalid)"""

        self.revision = GraphNode.Revision(0)
        """ID of last transaction that modified any attribute of this branch"""

        self.children = None
        """list of child nodes (without grandchildren)"""

        self.point_count_total = None
        """total number of points in all leaf nodes of the branch"""

        self.point_count_node = None
        """number of points in this node (see notes about LOD)"""

        # stat query for internal use (weak, the query owns the node graph)
        self._query = weakref.ref(query)

    def __setattr__(self, name, value):
        if name == "revision":
            # wrap the number so that node.revision("attribute") works too;
            # the wrapper needs to know its node to query attribute revisions
            new_revision = GraphNode.Revision(value)
            new_revision.__dict__["_node"] = weakref.ref(self)
            self.__dict__[name] = new_revision
        else:
            super().__setattr__(name, value)

    def __getattr__(self, item):
        # Only invoked when regular attribute lookup failed. Serves the
        # deprecated "<statistic>_<attribute>" names where the first
        # underscore after the prefix stands for the dot of the attribute
        # name (e.g. "minimum_riegl_xyz" -> minimum("riegl.xyz")).
        for prefix, method in GraphNode._DEPRECATED_PREFIXES:
            if item.startswith(prefix):
                # strip only the leading prefix; str.replace() would also
                # remove later occurrences inside the attribute name
                name = item[len(prefix):].replace("_", ".", 1)
                warn(
                    'Consider using {0}("{1}") instead'.format(method, name),
                    DeprecationWarning,
                    stacklevel=2
                )
                return getattr(self._query(), method)(name, self.id)
        return super().__getattribute__(item)

    def minimum(self, name):
        """
        Lowest value of node

        Provides the attribute's minimum value of the branch (i.e. node and
        children). The return value is either a single value (scalar) or an
        array (vectors). The data type depends on the point attribute.
        """
        return self._query().minimum(name, self.id)

    def maximum(self, name):
        """
        Highest value of node

        Provides the attribute's maximum value of the branch (i.e. node and
        children). The return value is either a single value (scalar) or an
        array (vectors). The data type depends on the point attribute.
        """
        return self._query().maximum(name, self.id)

    def select(self, selection=None, attributes=None):
        """
        Select points of node

        Details see riegl.rdb.pointcloud.Pointcloud.select()
        """
        return self._query().pointcloud.select(selection, attributes, self.id)

    def points(self, selection=None, attributes=None):
        """
        Select points of node point by point

        Details see riegl.rdb.pointcloud.Pointcloud.points()
        """
        return self._query().pointcloud.points(selection, attributes, self.id)


class QueryStat:
    """
    Point statistics query

    This query provides point attribute statistics like minimum and
    maximum value.

    Note: You either must delete the query object or call close()
          __before__ the parent Pointcloud instance is closed/deleted!
    """

    class _GraphNodeStructure(Structure):
        """C-API helper class (one node of the index graph)"""
        # Defined once at class level: ctypes' POINTER() caches every
        # structure type it sees, so re-creating this class on every
        # _read_graph() call would leak type objects.
        _fields_ = [
            ("id", c_uint32),
            ("revision", c_uint32),
            ("children", c_void_p),
            ("child_count", c_uint32),
            ("point_count_total", c_uint64),
            ("point_count_node", c_uint64)
        ]
        _pack_ = 1

    def __init__(self, pointcloud):
        """
        Constructor

        Creates a query prepared for reading stats.
        """
        self.context = pointcloud.context
        self.pointcloud = pointcloud
        self.handle = c_void_p(None)
        self.context.check(
            library.handle.rdb_pointcloud_query_stat_new(
                self.context.handle,
                self.pointcloud.handle,
                byref(self.handle)
            )
        )
        self._root_node = None  # set by 'get_index()' using 'self._read_graph()'

    def __del__(self):
        # __init__ may have failed before the handle attribute was created;
        # checking the instance dict avoids a failing __getattr__ lookup here
        if "handle" in self.__dict__:
            self.close()

    def __enter__(self):
        return self

    # noinspection PyUnusedLocal
    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def __getattr__(self, item):
        # Only invoked when regular attribute lookup failed.
        # Deprecated dynamic attributes (e.g. "minimum_riegl_xyz") are
        # forwarded to the root node of a freshly read index graph.
        if item.startswith(("revision_", "minimum_", "maximum_")):
            return getattr(self.index, item)
        if item == "index":
            return self.get_index()
        if item == "leaves":
            return self.get_leaves()
        return super().__getattribute__(item)

    @property
    def point_count_total(self):
        """Total number of points in point cloud"""
        return self.index.point_count_total

    def _read_graph(self, selection=None):
        """
        Read the index graph from the C-API and convert it to GraphNode objects

        The optional selection is passed as filter expression to the C-API
        (see get_index()). The temporary C-API graph is deleted again before
        this function returns or raises. Returns the root GraphNode.
        """
        node_pointer = POINTER(self._GraphNodeStructure)
        child_array = POINTER(node_pointer)

        def read_node(node, parent, query):
            # read basic graph node attributes
            assert node
            contents = node.contents
            result = GraphNode(parent, query)
            result.id = contents.id
            result.revision = contents.revision
            result.point_count_node = contents.point_count_node
            result.point_count_total = contents.point_count_total

            # read graph node children
            result.children = list()
            if contents.child_count > 0:
                children = cast(contents.children, child_array)
                assert children
                for i in range(contents.child_count):
                    # the parent of a child is the Python node 'result',
                    # not the temporary ctypes pointer 'node'
                    result.children.append(read_node(children[i], result, query))

            # finally return read graph node
            return result

        root = node_pointer()
        self.context.check(
            library.handle.rdb_pointcloud_graph_node_new(
                self.context.handle,
                byref(root)
            )
        )
        try:
            self.context.check(
                library.handle.rdb_pointcloud_query_stat_index_filter(
                    self.context.handle,
                    self.handle,
                    root,
                    utilities.to_rdb_string(selection)
                )
            )
            return read_node(root, None, self)
        finally:
            self.context.check(
                library.handle.rdb_pointcloud_graph_node_delete(
                    self.context.handle,
                    byref(root)
                )
            )

    @property
    def valid(self):
        """
        Check if query is not null

        A null query cannot be used to read stats.

        Returns True if the query is not null
        """
        # ctypes instances compare by identity (comparing against a new
        # c_void_p(None) is always "not equal"), so test the pointer value;
        # it is None for a null pointer
        return self.handle.value is not None

    def close(self):
        """
        Finish query

        Call this function when done with reading stats. Calling it again
        (or on a null query) does nothing.
        """
        if self.valid:
            library.handle.rdb_pointcloud_query_stat_delete(
                self.context.handle,
                byref(self.handle)
            )
            self.handle = c_void_p(None)

    def get_index(self, selection=None):
        """
        Get index graph

        This function returns a simple directed graph which represents
        the index structure that is used to organize the point cloud.

        The optional filter expression can be used to select particular
        graph nodes - if no filter is given, all nodes will be returned.
        Filter expression syntax see riegl::rdb::Pointcloud::select().

        Note: The reported point counts and attribute extents are not
        affected by the filter expressions given here or defined in the
        meta data item riegl.stored_filters.

        The returned root node is also kept by this query so that the
        weak parent references of its nodes stay valid.

        Details see description of class GraphNode.
        """
        self._root_node = self._read_graph(selection)
        return self._root_node

    def get_leaves(self, selection=None):
        """
        Get index graph leaves

        This function is similar to get_index() but instead of returning
        the graph root node, it returns a list of graph leaf nodes.
        """

        def scan_leaves(node, result):
            if len(node.children) > 0:  # intermediate node
                for child in node.children:
                    scan_leaves(child, result)
            else:  # leaf node
                result.append(node)
            return result

        return scan_leaves(self.get_index(selection), list())

    def _read_extremum(self, function, point_attribute_name, node_id):
        """
        Shared implementation of minimum() and maximum()

        Calls the given C-API statistics function for one point attribute
        and returns the single value it wrote into a one-element buffer.
        A node_id of None selects the root node of a freshly read index.
        """
        if node_id is None:
            node_id = self.index.id
        point_attribute = self.pointcloud.point_attributes[point_attribute_name]
        buffer = attributebuffer.AttributeBuffer(point_attribute, 1)
        self.context.check(
            function(
                self.context.handle,
                self.handle,
                c_uint32(node_id),
                utilities.to_rdb_string(point_attribute_name),
                c_uint32(buffer.data_type.value),
                buffer.raw_data()
            )
        )
        return buffer[0]

    def minimum(self, point_attribute_name, node_id=None, cleaned=False):
        """
        Lowest value of node

        Provides the attribute's minimum value of a branch (i.e. node and
        children). If the node ID is zero, then the minimum value of all
        points is returned. The return value is either a single value
        (scalar) or an array (vectors). The data type depends on the
        point attribute.

        If node_id is None, the root node of the index graph is used
        (the graph is re-read for that). If cleaned is True, the C-API
        function rdb_pointcloud_query_stat_minimum_cleaned is used instead
        of rdb_pointcloud_query_stat_minimum.
        NOTE(review): the exact meaning of "cleaned" statistics is defined
        by the RDB C-API - confirm there.
        """
        function = (
            library.handle.rdb_pointcloud_query_stat_minimum_cleaned
            if cleaned else
            library.handle.rdb_pointcloud_query_stat_minimum
        )
        return self._read_extremum(function, point_attribute_name, node_id)

    def maximum(self, point_attribute_name, node_id=None, cleaned=False):
        """
        Highest value of node

        Provides the attribute's maximum value of a branch (i.e. node and
        children). If the node ID is zero, then the maximum value of all
        points is returned. The return value is either a single value
        (scalar) or an array (vectors). The data type depends on the
        point attribute.

        If node_id is None, the root node of the index graph is used
        (the graph is re-read for that). If cleaned is True, the C-API
        function rdb_pointcloud_query_stat_maximum_cleaned is used instead
        of rdb_pointcloud_query_stat_maximum.
        NOTE(review): the exact meaning of "cleaned" statistics is defined
        by the RDB C-API - confirm there.
        """
        function = (
            library.handle.rdb_pointcloud_query_stat_maximum_cleaned
            if cleaned else
            library.handle.rdb_pointcloud_query_stat_maximum
        )
        return self._read_extremum(function, point_attribute_name, node_id)

    def revision(self, point_attribute_name, node_id=None):
        """
        Attribute revision

        Provides the ID of the last transaction that has modified the attribute
        in any node of the given branch. If node_id is None, the root node of
        the index graph is used (the graph is re-read for that).
        """
        if node_id is None:
            node_id = self.index.id
        result = c_uint32(0)
        self.context.check(
            library.handle.rdb_pointcloud_query_stat_revision(
                self.context.handle,
                self.handle,
                c_uint32(node_id),
                utilities.to_rdb_string(point_attribute_name),
                byref(result)
            )
        )
        return result.value
