#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# rule_engine/ast/binary/arithmetic.py
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of the project nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
import datetime
import functools
import operator
from typing import Any, Callable
from ... import errors
from ...types import DataType, coerce_value
from ...types import _DataTypeDef
from ..base import _assert_is_bytes, _assert_is_natural_number, _assert_is_numeric, _assert_is_string, _assert_not_nullable, _is_reduced
from .base import BinaryExpressionBase
[docs]
class AddExpression(BinaryExpressionBase):
"""A class for representing addition expressions from the grammar text."""
compatible_types: tuple[_DataTypeDef, ...] = (DataType.BYTES, DataType.FLOAT, DataType.STRING, DataType.DATETIME, DataType.TIMEDELTA)
result_type: _DataTypeDef = DataType.UNDEFINED
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(AddExpression, self).__init__(*args, **kwargs)
_assert_not_nullable(self.left.result_type, role="left operand of '+'")
_assert_not_nullable(self.right.result_type, role="right operand of '+'")
if self.left.result_type != DataType.UNDEFINED and self.right.result_type != DataType.UNDEFINED:
if self.left.result_type == DataType.DATETIME:
if self.right.result_type != DataType.TIMEDELTA:
raise errors.EvaluationError('data type mismatch')
self.result_type = self.left.result_type
elif self.left.result_type == DataType.TIMEDELTA:
if self.right.result_type not in (DataType.DATETIME, DataType.TIMEDELTA):
raise errors.EvaluationError('data type mismatch')
self.result_type = self.right.result_type
elif self.left.result_type != self.right.result_type:
raise errors.EvaluationError('data type mismatch')
else:
self.result_type = self.left.result_type
def _op_add(self, thing: Any) -> Any:
left_value = self.left.evaluate(thing)
right_value = self.right.evaluate(thing)
if isinstance(left_value, datetime.datetime):
if not isinstance(right_value, datetime.timedelta):
raise errors.EvaluationError('data type mismatch (not a timedelta value)')
elif isinstance(left_value, datetime.timedelta):
if not isinstance(right_value, (datetime.timedelta, datetime.datetime)):
raise errors.EvaluationError('data type mismatch (not a datetime or timedelta value)')
elif isinstance(left_value, bytes) or isinstance(right_value, bytes):
_assert_is_bytes(left_value, right_value)
elif isinstance(left_value, str) or isinstance(right_value, str):
_assert_is_string(left_value, right_value)
else:
_assert_is_numeric(left_value, right_value)
return operator.add(left_value, right_value)
[docs]
class SubtractExpression(BinaryExpressionBase):
"""
A class for representing subtraction expressions from the grammar text.
.. versionadded:: 3.5.0
"""
compatible_types: tuple[_DataTypeDef, ...] = (DataType.FLOAT, DataType.DATETIME, DataType.TIMEDELTA)
result_type: _DataTypeDef = DataType.UNDEFINED
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(SubtractExpression, self).__init__(*args, **kwargs)
_assert_not_nullable(self.left.result_type, role="left operand of '-'")
_assert_not_nullable(self.right.result_type, role="right operand of '-'")
if self.left.result_type != DataType.UNDEFINED and self.right.result_type != DataType.UNDEFINED:
if self.left.result_type == DataType.DATETIME:
if self.right.result_type == DataType.DATETIME:
self.result_type = DataType.TIMEDELTA
elif self.right.result_type == DataType.TIMEDELTA:
self.result_type = DataType.DATETIME
else:
raise errors.EvaluationError('data type mismatch')
elif self.left.result_type == DataType.TIMEDELTA:
if self.right.result_type != DataType.TIMEDELTA:
raise errors.EvaluationError('data type mismatch')
self.result_type = self.left.result_type
elif self.left.result_type != self.right.result_type:
raise errors.EvaluationError('data type mismatch')
else:
self.result_type = self.left.result_type
def _op_sub(self, thing: Any) -> Any:
left_value = self.left.evaluate(thing)
right_value = self.right.evaluate(thing)
if isinstance(left_value, datetime.datetime):
if not isinstance(right_value, (datetime.datetime, datetime.timedelta)):
raise errors.EvaluationError('data type mismatch (not a datetime or timedelta value)')
elif isinstance(left_value, datetime.timedelta):
if not isinstance(right_value, datetime.timedelta):
raise errors.EvaluationError('data type mismatch (not a timedelta value)')
else:
_assert_is_numeric(left_value, right_value)
return operator.sub(left_value, right_value)
[docs]
class ArithmeticExpression(BinaryExpressionBase):
"""A class for representing arithmetic expressions from the grammar text such as multiplication and division."""
compatible_types: tuple[_DataTypeDef, ...] = (DataType.FLOAT,)
result_type: _DataTypeDef = DataType.FLOAT
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(ArithmeticExpression, self).__init__(*args, **kwargs)
_assert_not_nullable(self.left.result_type, role='left arithmetic operand')
_assert_not_nullable(self.right.result_type, role='right arithmetic operand')
def __op_arithmetic(self, op: Callable[[Any, Any], Any], thing: Any) -> Any:
left_value = self.left.evaluate(thing)
_assert_is_numeric(left_value)
right_value = self.right.evaluate(thing)
_assert_is_numeric(right_value)
try:
result = op(left_value, right_value)
except ZeroDivisionError:
raise errors.ArithmeticError('arithmetic error: division by zero') from None
except ArithmeticError:
raise errors.ArithmeticError('arithmetic error') from None
return result
_op_fdiv = functools.partialmethod(__op_arithmetic, operator.floordiv)
_op_tdiv = functools.partialmethod(__op_arithmetic, operator.truediv)
_op_mod = functools.partialmethod(__op_arithmetic, operator.mod)
_op_mul = functools.partialmethod(__op_arithmetic, operator.mul)
_op_pow = functools.partialmethod(__op_arithmetic, operator.pow)
[docs]
class BitwiseExpression(BinaryExpressionBase):
"""
A class for representing bitwise arithmetic expressions from the grammar text such as XOR and shifting operations.
"""
compatible_types: tuple[_DataTypeDef, ...] = (DataType.FLOAT, DataType.SET)
result_type: _DataTypeDef = DataType.UNDEFINED
def __init__(self, *args: Any, **kwargs: Any) -> None:
super(BitwiseExpression, self).__init__(*args, **kwargs)
_assert_not_nullable(self.left.result_type, role='left bitwise operand')
_assert_not_nullable(self.right.result_type, role='right bitwise operand')
# don't use DataType.is_compatible, because for sets the member type isn't important
if self.left.result_type != DataType.UNDEFINED and self.right.result_type != DataType.UNDEFINED:
if self.left.result_type.__class__ != self.right.result_type.__class__:
raise errors.EvaluationError('data type mismatch')
if self.left.result_type == DataType.FLOAT:
if _is_reduced(self.left):
_assert_is_natural_number(self.left.evaluate(None))
self.result_type = DataType.FLOAT
if self.right.result_type == DataType.FLOAT:
if _is_reduced(self.right):
_assert_is_natural_number(self.right.evaluate(None))
self.result_type = DataType.FLOAT
if DataType.is_type(self.left.result_type, DataType.SET) or DataType.is_type(self.right.result_type, DataType.SET):
self.result_type = DataType.SET # this discards the member type info
def _op_bitwise(self, op: Callable[[Any, Any], Any], thing: Any) -> Any:
left = self.left.evaluate(thing)
if DataType.from_value(left) == DataType.FLOAT:
return self._op_bitwise_float(op, thing, left)
elif DataType.is_type(DataType.from_value(left), DataType.SET):
return self._op_bitwise_set(op, thing, left)
raise errors.EvaluationError('data type mismatch')
def _op_bitwise_float(self, op: Callable[[Any, Any], Any], thing: Any, left: Any) -> Any:
_assert_is_natural_number(left)
right = self.right.evaluate(thing)
_assert_is_natural_number(right)
return coerce_value(op(int(left), int(right)))
def _op_bitwise_set(self, op: Callable[[Any, Any], Any], thing: Any, left: Any) -> Any:
right = self.right.evaluate(thing)
if not DataType.is_compatible(DataType.from_value(right), DataType.SET):
raise errors.EvaluationError('data type mismatch')
return op(left, right)
_op_bwand = functools.partialmethod(_op_bitwise, operator.and_)
_op_bwor = functools.partialmethod(_op_bitwise, operator.or_)
_op_bwxor = functools.partialmethod(_op_bitwise, operator.xor)
[docs]
class BitwiseShiftExpression(BitwiseExpression):
compatible_types: tuple[_DataTypeDef, ...] = (DataType.FLOAT,)
result_type: _DataTypeDef = DataType.FLOAT
def _op_bitwise_shift(self, *args: Any, **kwargs: Any) -> Any:
return self._op_bitwise(*args, **kwargs)
_op_bwlsh = functools.partialmethod(_op_bitwise_shift, operator.lshift)
_op_bwrsh = functools.partialmethod(_op_bitwise_shift, operator.rshift)