# -*- coding: utf-8 -*-
#
#*******************************************************************************
#
#  Copyright 2022 RIEGL Laser Measurement Systems
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
#  SPDX-License-Identifier: Apache-2.0
#
#*******************************************************************************
#
"""
Point attribute buffer class
"""

from . import datatypes

RDB_HAVE_NUMPY = True
try:
    # noinspection PyUnresolvedReferences
    import numpy
except ImportError:
    RDB_HAVE_NUMPY = False
if not RDB_HAVE_NUMPY:
    import array
    import ctypes


class AttributeBuffer:
    """
    Point attribute buffer class

    Holds the values of one point attribute for ``point_count`` points in a
    zero-initialized, typed buffer. With NumPy available the buffer is a
    NumPy array of shape ``(point_count,)`` or ``(point_count, length)``;
    otherwise it is a flat standard-library ``array.array`` with
    ``point_count * length`` elements, stored point after point.
    """

    def __init__(self, point_attribute, point_count):
        """
        Create new point attribute buffer

        Args:
            point_attribute: of type riegl.rdb.pointattribute.PointAttribute
            point_count: size of buffer in terms of number of points

        Raises:
            RuntimeError: if the attribute's data type is not supported or
                the allocated item size does not match the data type.
        """

        # basic initialization
        self.point_count = point_count
        self.point_attribute = point_attribute
        self.total_length = point_attribute.length * point_count
        self.data_type = self._get_data_type()
        self.type_code = self._get_type_code()
        self.item_size = self._get_item_size()
        self.data = self._create_buffer()

        # Item access mode is decided per instance (not patched onto the
        # class), so buffers with different attribute lengths cannot
        # interfere. Only multi-dimensional attributes stored in a flat
        # 'array' need each point's values gathered from consecutive items.
        self._values_per_point = point_attribute.length
        self._flat_vectors = (
            self._values_per_point != 1 and not RDB_HAVE_NUMPY
        )

        # The 'array' module may have allocated wider elements than
        # requested; adopt the data type that was actually allocated.
        if (not RDB_HAVE_NUMPY) and (self.data.itemsize > self.item_size):
            self._match_array_item_size()

        # Explicit check instead of 'assert' (stripped by -O): the buffer is
        # exposed as raw memory via raw_data(), so the element size must
        # agree with the declared data type.
        if self.data.itemsize != self.item_size:
            raise RuntimeError(
                "Attribute buffer item size does not match data type."
            )

    def _match_array_item_size(self):
        """
        Re-derive the data type from the item size 'array' actually used.

        The 'array' module may allocate elements larger than requested (for
        example 'L'/'l' map to C long, which is 8 bytes on many 64-bit
        platforms). Only integer types are adjusted; the type code and item
        size are then recomputed from the (possibly) updated data type.
        """
        dt = datatypes.DataType
        if datatypes.is_signed_integer(self.data_type):
            by_size = {1: dt.INT8, 2: dt.INT16, 4: dt.INT32, 8: dt.INT64}
        elif datatypes.is_unsigned_integer(self.data_type):
            by_size = {1: dt.UINT8, 2: dt.UINT16, 4: dt.UINT32, 8: dt.UINT64}
        else:
            by_size = {}
        self.data_type = by_size.get(self.data.itemsize, self.data_type)
        self.type_code = self._get_type_code()
        self.item_size = self._get_item_size()

    def _get_data_type(self):
        """Determine best fitting data type for attribute"""
        return self.point_attribute.suggest_data_type()

    def _lookup_data_type(self, table):
        """
        Return the value paired with ``self.data_type`` in *table*.

        Args:
            table: sequence of ``(DataType, value)`` pairs, compared with ``==``

        Raises:
            RuntimeError: if ``self.data_type`` is not listed in *table*.
        """
        for data_type, value in table:
            if self.data_type == data_type:
                return value
        raise RuntimeError("Attribute data type not supported.")

    def _get_type_code(self):
        """Translate RDB data type to array/numpy data type"""
        dt = datatypes.DataType
        if RDB_HAVE_NUMPY:
            table = (
                (dt.UINT8, numpy.uint8),
                (dt.INT8, numpy.int8),
                (dt.UINT16, numpy.uint16),
                (dt.INT16, numpy.int16),
                (dt.UINT32, numpy.uint32),
                (dt.INT32, numpy.int32),
                (dt.UINT64, numpy.uint64),
                (dt.INT64, numpy.int64),
                (dt.SINGLE, numpy.float32),
                (dt.DOUBLE, numpy.float64),
            )
        else:  # use standard array
            # NOTE: 'L'/'l' (C long) may be wider than 4 bytes; the
            # constructor detects this and widens the data type to match.
            table = (
                (dt.UINT8, "B"),
                (dt.INT8, "b"),
                (dt.UINT16, "H"),
                (dt.INT16, "h"),
                (dt.UINT32, "L"),
                (dt.INT32, "l"),
                (dt.UINT64, "Q"),
                (dt.INT64, "q"),
                (dt.SINGLE, "f"),
                (dt.DOUBLE, "d"),
            )
        return self._lookup_data_type(table)

    def _get_item_size(self):
        """Determine expected item size (in bytes) of the data type"""
        dt = datatypes.DataType
        return self._lookup_data_type((
            (dt.UINT8, 1),
            (dt.INT8, 1),
            (dt.UINT16, 2),
            (dt.INT16, 2),
            (dt.UINT32, 4),
            (dt.INT32, 4),
            (dt.UINT64, 8),
            (dt.INT64, 8),
            (dt.SINGLE, 4),
            (dt.DOUBLE, 8),
        ))

    def _create_buffer(self):
        """Create actual, zero-initialized data buffer"""
        if RDB_HAVE_NUMPY:
            n = self.point_attribute.length
            shape = self.point_count if n == 1 else (self.point_count, n)
            return numpy.zeros(shape, self.type_code)
        # Repeat a one-element array instead of building a Python list first.
        return array.array(self.type_code, [0]) * self.total_length

    def resize(self, point_count):
        """
        Change size of buffer

        Existing point values are kept (up to the new size); points added by
        growing the buffer are zero-initialized in both the NumPy and the
        'array' implementation.

        Note: If the buffer was previously bound to a query, make sure that
              you rebind the buffer, since a resize can move the buffer to
              another memory location.

        Raises:
            ValueError: if point_count is negative.
        """
        if point_count < 0:
            raise ValueError("point_count must not be negative")
        if self.point_count == point_count:
            return
        n = self.point_attribute.length
        self.point_count = point_count
        self.total_length = n * point_count
        if RDB_HAVE_NUMPY:
            shape = point_count if n == 1 else (point_count, n)
            # refcheck=False skips NumPy's reference-count check; outstanding
            # views/pointers become invalid, hence the rebind note above.
            self.data.resize(shape, refcheck=False)
        else:
            # Slicing keeps at most total_length items; when growing, the
            # shortfall must be appended explicitly or the buffer would stay
            # smaller than point_count says.
            data = array.array(self.type_code, self.data[0:self.total_length])
            shortfall = self.total_length - len(data)
            if shortfall > 0:
                data.extend(array.array(self.type_code, [0]) * shortfall)
            self.data = data

    def raw_data(self):
        """
        Return an object exposing the buffer's address to ctypes

        This is NumPy's ``ndarray.ctypes`` helper, or a ``ctypes.c_void_p``
        holding the address of the 'array' buffer.
        """
        if RDB_HAVE_NUMPY:
            return self.data.ctypes
        return ctypes.c_void_p(self.data.buffer_info()[0])

    def _point_span(self, index):
        """
        Return the ``(start, stop)`` bounds of point *index* in a flat array.

        Negative indices count from the end, as for sequences.

        Raises:
            IndexError: if *index* is outside ``[-point_count, point_count)``.
        """
        if index < 0:
            index += self.point_count
        if not 0 <= index < self.point_count:
            raise IndexError("point index out of range")
        start = index * self._values_per_point
        return start, start + self._values_per_point

    def __getitem__(self, index):
        """
        Return the value(s) of point *index*.

        For NumPy buffers and one-dimensional attributes this is plain
        ``self.data[index]`` (a row of the 2-D NumPy array for vector
        attributes). For vector attributes in a flat 'array' it is a tuple
        of that point's values.
        """
        if not self._flat_vectors:
            return self.data[index]
        start, stop = self._point_span(index)
        return tuple(self.data[start:stop])

    def __setitem__(self, index, value):
        """
        Set the value(s) of point *index*.

        For vector attributes in a flat 'array', *value* must be indexable
        and provide at least one item per attribute dimension.
        """
        if not self._flat_vectors:
            self.data[index] = value
            return
        start, stop = self._point_span(index)
        for offset in range(stop - start):
            self.data[start + offset] = value[offset]

    def __len__(self):
        return self.point_count

    def __repr__(self):
        return repr(self.data)

    def __str__(self):
        return str(self.data)

    def __iter__(self):
        """Yield the value(s) of each point in order (see __getitem__)."""
        index = 0
        # point_count is re-read on every step, so a resize during
        # iteration is honored.
        while index < self.point_count:
            yield self[index]
            index += 1
