#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#


import operator

import pandas as pd
import numpy as np

from pyspark import pandas as ps
from pyspark.testing.utils import is_ansi_mode_test
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.pandas.tests.data_type_ops.testing_utils import OpsTestBase


class NumMulDivTestsMixin:
    """Tests for ``*``, ``/``, ``//`` and ``%`` on numeric pandas-on-Spark Series.

    Each test either compares a pandas-on-Spark result with the equivalent
    pandas result, or asserts that an unsupported operand combination raises
    ``TypeError``. The host test class must provide ``pdf`` / ``psdf``
    (matching pandas and pandas-on-Spark frames), ``numeric_df_cols`` and
    ``non_numeric_df_cols`` (column-name lists) and ``assert_eq``.
    """

    @property
    def float_pser(self):
        """A three-element float64 pandas Series: ``[1.0, 2.0, 3.0]``."""
        return pd.Series([1, 2, 3], dtype=float)

    @property
    def float_psser(self):
        """The pandas-on-Spark counterpart of ``float_pser``."""
        return ps.from_pandas(self.float_pser)

    def _assert_decimal_float_ops_raise(self, psdf, op):
        """Assert that ``op`` rejects mixing decimal and float operands in ANSI mode.

        ``op`` is a binary callable such as ``operator.mul``. It is applied to
        the ``"decimal"`` column paired with the ``"float"`` column, the
        ``"float32"`` column and the scalar ``0.1``, with decimal on each side
        in turn; every call must raise ``TypeError``. When the suite is not
        running in ANSI mode this helper checks nothing.
        """
        if not is_ansi_mode_test:
            return
        for float_col in ("float", "float32"):
            # assertRaises invokes each lambda immediately, so binding
            # ``float_col`` late inside the loop is safe here.
            self.assertRaises(TypeError, lambda: op(psdf["decimal"], psdf[float_col]))
            self.assertRaises(TypeError, lambda: op(psdf[float_col], psdf["decimal"]))
        self.assertRaises(TypeError, lambda: op(psdf["decimal"], 0.1))
        self.assertRaises(TypeError, lambda: op(0.1, psdf["decimal"]))

    def test_mul(self):
        """Multiply numeric columns by numeric, boolean and (for ints) string operands."""
        pdf, psdf = self.pdf, self.psdf
        for col in self.numeric_df_cols:
            pser, psser = pdf[col], psdf[col]
            self.assert_eq(pser * pser, psser * psser, check_exact=False)
            self.assert_eq(pser * pser.astype(bool), psser * psser.astype(bool), check_exact=False)
            self.assert_eq(pser * True, psser * True, check_exact=False)
            self.assert_eq(pser * False, psser * False, check_exact=False)

            # Only integer columns may be multiplied by a string column
            # (repetition); every other numeric dtype must raise.
            if psser.dtype in [int, np.int32]:
                self.assert_eq(pser * pdf["string"], psser * psdf["string"])
            else:
                self.assertRaises(TypeError, lambda: psser * psdf["string"])

            self.assert_eq(pser * pdf["bool"], psser * psdf["bool"], check_exact=False)

            self.assertRaises(TypeError, lambda: psser * psdf["datetime"])
            self.assertRaises(TypeError, lambda: psser * psdf["date"])
            self.assertRaises(TypeError, lambda: psser * psdf["categorical"])

        self._assert_decimal_float_ops_raise(psdf, operator.mul)

    def test_truediv(self):
        """True-divide numeric columns; only boolean non-numeric operands are allowed."""
        pdf, psdf = self.pdf, self.psdf
        for col in self.numeric_df_cols:
            pser, psser = pdf[col], psdf[col]
            # Value comparisons are restricted to the dtypes listed here.
            if psser.dtype in [float, int, np.int32]:
                self.assert_eq(pser / pser, psser / psser)
                self.assert_eq(pser / pser.astype(bool), psser / psser.astype(bool))
                self.assert_eq(pser / True, psser / True)
                self.assert_eq(pser / False, psser / False)

            for n_col in self.non_numeric_df_cols:
                if n_col != "bool":
                    self.assertRaises(TypeError, lambda: psser / psdf[n_col])

        # The boolean-column comparison does not depend on ``col``; run it once
        # instead of repeating the identical assertion for every numeric column.
        if "bool" in self.non_numeric_df_cols:
            self.assert_eq(pdf["float"] / pdf["bool"], psdf["float"] / psdf["bool"])

        self._assert_decimal_float_ops_raise(psdf, operator.truediv)

    def test_floordiv(self):
        """Floor-divide the float column; only boolean non-numeric operands are allowed."""
        pdf, psdf = self.pdf, self.psdf
        pser, psser = pdf["float"], psdf["float"]
        self.assert_eq(pser // pser, psser // psser)
        self.assert_eq(pser // pser.astype(bool), psser // psser.astype(bool))
        self.assert_eq(pser // True, psser // True)
        self.assert_eq(pser // False, psser // False)

        for n_col in self.non_numeric_df_cols:
            if n_col == "bool":
                self.assert_eq(pdf["float"] // pdf["bool"], psdf["float"] // psdf["bool"])
            else:
                for col in self.numeric_df_cols:
                    psser = psdf[col]
                    self.assertRaises(TypeError, lambda: psser // psdf[n_col])

        self._assert_decimal_float_ops_raise(psdf, operator.floordiv)

    def test_mod(self):
        """Take the modulo of numeric columns; only boolean non-numeric operands are allowed."""
        pdf, psdf = self.pdf, self.psdf

        # element-wise modulo for numeric columns
        for col in self.numeric_df_cols:
            pser, psser = pdf[col], psdf[col]

            if psser.dtype in [float, int, np.int32]:
                self.assert_eq(pser % pser, psser % psser)
                self.assert_eq(pser % pser.astype(bool), psser % psser.astype(bool))
                self.assert_eq(pser % True, psser % True)
                # TODO: decide if to follow pser % False
                self.assert_eq(pser % 0, psser % False)

            # modulo with non-numeric, non-boolean columns must be rejected
            for n_col in self.non_numeric_df_cols:
                if n_col != "bool":
                    self.assertRaises(TypeError, lambda: psser % psdf[n_col])

        # The boolean-column comparison does not depend on ``col``; run it once
        # instead of repeating the identical assertion for every numeric column.
        if "bool" in self.non_numeric_df_cols:
            self.assert_eq(pdf["float"] % pdf["bool"], psdf["float"] % psdf["bool"])

        self._assert_decimal_float_ops_raise(psdf, operator.mod)


class NumMulDivTests(NumMulDivTestsMixin, OpsTestBase, PandasOnSparkTestCase):
    """Concrete test case combining NumMulDivTestsMixin with the OpsTestBase and
    PandasOnSparkTestCase bases, so the mixin's tests are collected and run."""


if __name__ == "__main__":
    # Running this module directly hands control to pyspark.testing.main.
    from pyspark.testing import main as run_pyspark_tests

    run_pyspark_tests()
