# Copyright 2017 Gehirn Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from typing import (
    AbstractSet,
    Optional,
)

from .exceptions import (
    JWSDecodeError,
    JWSEncodeError,
)
from .jwa import (
    AbstractSigningAlgorithm,
    supported_signing_algorithms,
)
from .jwk import AbstractJWKBase
from .utils import (
    b64decode,
    b64encode,
)

__all__ = ["JWS"]


class JWS:
    """Encode and decode compact JWS strings.

    A compact JWS string has the form ``header.payload.signature`` where
    each segment is produced by the package's ``b64encode`` helper.  The
    set of usable algorithms is whatever ``supported_signing_algorithms()``
    returned when the instance was created.
    """

    def __init__(self) -> None:
        # Mapping of ``alg`` header value -> signing algorithm implementation.
        self._supported_algs = supported_signing_algorithms()

    def _retrieve_alg(self, alg: str) -> AbstractSigningAlgorithm:
        """Return the implementation registered for *alg*.

        Raises:
            JWSDecodeError: if *alg* is not a supported algorithm.
        """
        try:
            return self._supported_algs[alg]
        except KeyError as err:
            raise JWSDecodeError("Unsupported signing algorithm.") from err

    def encode(
        self,
        message: bytes,
        key: Optional[AbstractJWKBase] = None,
        alg: str = "HS256",
        optional_headers: Optional[dict[str, str]] = None,
    ) -> str:
        """Sign *message* and return the compact JWS string.

        Args:
            message: Payload bytes to sign.
            key: Key handed to the algorithm's ``sign`` method.
            alg: Name of the signing algorithm; also written to the
                ``alg`` header.
            optional_headers: Extra header fields.  The mapping is copied,
                never mutated; an ``alg`` entry in it is overwritten.

        Returns:
            ``header.payload.signature`` as a string.

        Raises:
            JWSEncodeError: if *alg* is not a supported algorithm.
        """
        if alg not in self._supported_algs:  # pragma: no cover
            raise JWSEncodeError(f"unsupported algorithm: {alg}")
        # Looked up directly (not via _retrieve_alg) so the encode path can
        # never surface a JWSDecodeError.
        alg_impl = self._supported_algs[alg]

        header = optional_headers.copy() if optional_headers else {}
        header["alg"] = alg

        # Compact separators keep the serialized header free of whitespace.
        header_b64 = b64encode(
            json.dumps(header, separators=(",", ":")).encode("ascii")
        )
        message_b64 = b64encode(message)
        signing_message = header_b64 + "." + message_b64

        # The signature covers the ASCII bytes of "header.payload".
        signature = alg_impl.sign(signing_message.encode("ascii"), key)
        signature_b64 = b64encode(signature)

        return signing_message + "." + signature_b64

    def _decode_segments(
        self, message: str
    ) -> tuple[dict[str, str], bytes, bytes, str]:
        """Split a compact JWS string and decode its three segments.

        Returns:
            ``(header, payload_bytes, signature_bytes, signing_message)``
            where ``signing_message`` is the still-encoded
            ``header.payload`` prefix that the signature was computed over.

        Raises:
            JWSDecodeError: if *message* does not consist of exactly three
                dot-separated segments, a segment cannot be decoded, or
                the header is not a JSON object.
        """
        try:
            signing_message, signature_b64 = message.rsplit(".", 1)
            header_b64, message_b64 = signing_message.split(".")
            # Decoding failures are expected to be ValueError subclasses
            # (UnicodeDecodeError, json.JSONDecodeError, binascii.Error).
            # NOTE(review): assumes b64decode reports bad input with a
            # ValueError -- confirm against .utils.
            header = json.loads(b64decode(header_b64).decode("ascii"))
            message_bin = b64decode(message_b64)
            signature = b64decode(signature_b64)
        except ValueError as err:
            raise JWSDecodeError("malformed JWS payload") from err

        # Valid JSON that is not an object (e.g. ``[]`` or ``"x"``) cannot
        # carry header parameters.
        if not isinstance(header, dict):
            raise JWSDecodeError("malformed JWS payload")

        return header, message_bin, signature, signing_message

    def decode(
        self,
        message: str,
        key: Optional[AbstractJWKBase] = None,
        do_verify: bool = True,
        algorithms: Optional[AbstractSet[str]] = None,
    ) -> bytes:
        """Decode a compact JWS string and return its payload bytes.

        Args:
            message: The compact JWS string.
            key: Key handed to the algorithm's ``verify`` method.
            do_verify: When false the signature is not checked; the
                ``alg`` header must still name an allowed algorithm.
            algorithms: Allowed ``alg`` values.  Defaults to every
                algorithm returned by ``supported_signing_algorithms()``.

        Returns:
            The decoded payload.

        Raises:
            JWSDecodeError: if the token is malformed, its ``alg`` header
                is missing, not a string, not allowed or not supported, or
                signature verification fails.
        """
        if algorithms is None:
            algorithms = set(supported_signing_algorithms())

        header, message_bin, signature, signing_message = (
            self._decode_segments(message)
        )

        # The header comes from untrusted input: ``alg`` may be absent or
        # hold a non-string (possibly unhashable) JSON value.
        alg_value = header.get("alg")
        if not isinstance(alg_value, str) or alg_value not in algorithms:
            raise JWSDecodeError("Unsupported signing algorithm.")

        alg_impl = self._retrieve_alg(alg_value)
        if do_verify and not alg_impl.verify(
            signing_message.encode("ascii"), key, signature
        ):
            raise JWSDecodeError("JWS passed could not be validated")

        return message_bin
