// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use alloc::vec::Vec;
use core::{
    fmt::{self, Debug},
    ops::Deref,
};

use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};

use crate::{
    client::MlsError,
    group::{framing::FramedContent, MessageSignature},
    WireFormat,
};

use super::{AuthenticatedContent, ConfirmationTag};

/// Confirmed transcript hash of a group, stored as raw hash bytes.
///
/// Values are produced by `ConfirmedTranscriptHash::create`, which hashes the interim
/// transcript hash together with the encoded wire format, framed content and signature of an
/// authenticated content. They can also be built directly from bytes via `From<Vec<u8>>`.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConfirmedTranscriptHash(
    // MLS-encoded with `mls_rs_codec::byte_vec`; serde uses `mls_rs_core::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);

impl Debug for ConfirmedTranscriptHash {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show the raw hash bytes in pretty form, labelled with the type name.
        Debug::fmt(
            &mls_rs_core::debug::pretty_bytes(&self.0).named("ConfirmedTranscriptHash"),
            f,
        )
    }
}

impl Deref for ConfirmedTranscriptHash {
    type Target = Vec<u8>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Vec<u8>> for ConfirmedTranscriptHash {
    fn from(value: Vec<u8>) -> Self {
        Self(value)
    }
}

impl ConfirmedTranscriptHash {
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn create<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        interim_transcript_hash: &InterimTranscriptHash,
        content: &AuthenticatedContent,
    ) -> Result<Self, MlsError> {
        #[derive(Debug, MlsSize, MlsEncode)]
        struct ConfirmedTranscriptHashInput<'a> {
            wire_format: WireFormat,
            content: &'a FramedContent,
            signature: &'a MessageSignature,
        }

        let input = ConfirmedTranscriptHashInput {
            wire_format: content.wire_format,
            content: &content.content,
            signature: &content.auth.signature,
        };

        let hash_input = [
            interim_transcript_hash.deref(),
            input.mls_encode_to_vec()?.deref(),
        ]
        .concat();

        cipher_suite_provider
            .hash(&hash_input)
            .await
            .map(Into::into)
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
    }
}

/// Interim transcript hash of a group, stored as raw hash bytes.
///
/// Values are produced by `InterimTranscriptHash::create`, which hashes a confirmed transcript
/// hash together with an encoded confirmation tag. They can also be built directly from bytes
/// via `From<Vec<u8>>`.
///
/// `Eq` is derived alongside `PartialEq` to match `ConfirmedTranscriptHash`; equality on the
/// wrapped bytes is a total equivalence.
#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub(crate) struct InterimTranscriptHash(
    // MLS-encoded with `mls_rs_codec::byte_vec`; serde uses `mls_rs_core::vec_serde`.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);

impl Debug for InterimTranscriptHash {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Show the raw hash bytes in pretty form, labelled with the type name.
        Debug::fmt(
            &mls_rs_core::debug::pretty_bytes(&self.0).named("InterimTranscriptHash"),
            f,
        )
    }
}

impl Deref for InterimTranscriptHash {
    type Target = Vec<u8>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl From<Vec<u8>> for InterimTranscriptHash {
    fn from(value: Vec<u8>) -> Self {
        Self(value)
    }
}

impl InterimTranscriptHash {
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn create<P: CipherSuiteProvider>(
        cipher_suite_provider: &P,
        confirmed: &ConfirmedTranscriptHash,
        confirmation_tag: &ConfirmationTag,
    ) -> Result<Self, MlsError> {
        #[derive(Debug, MlsSize, MlsEncode)]
        struct InterimTranscriptHashInput<'a> {
            confirmation_tag: &'a ConfirmationTag,
        }

        let input = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;

        cipher_suite_provider
            .hash(&[confirmed.0.deref(), &input].concat())
            .await
            .map(Into::into)
            .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
    }
}

// Test vectors come from the MLS interop repository and contain a proposal by reference, so
// decoding them requires the `by_ref_proposal` feature. (Vectors generated locally by
// `generate_test_vector` use a proposal by value.)
#[cfg(feature = "by_ref_proposal")]
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;

    use mls_rs_codec::MlsDecode;

    use crate::{
        crypto::test_utils::try_test_cipher_suite_provider,
        group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
    };

    #[cfg(not(mls_build_async))]
    use alloc::{boxed::Box, vec};

    #[cfg(not(mls_build_async))]
    use crate::{
        crypto::test_utils::test_cipher_suite_provider,
        group::{
            confirmation_tag::ConfirmationTag,
            framing::Content,
            proposal::{Proposal, ProposalOrRef, RemoveProposal},
            test_utils::get_test_group_context,
            Commit, LeafIndex, Sender,
        },
        mls_rs_codec::MlsEncode,
        CipherSuite, CipherSuiteProvider, WireFormat,
    };

    #[cfg(not(mls_build_async))]
    use super::{ConfirmedTranscriptHash, InterimTranscriptHash};

    /// One transcript-hash test vector. Byte fields are hex-encoded in the serialized form.
    #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
    struct TestCase {
        pub cipher_suite: u16,

        #[serde(with = "hex::serde")]
        pub confirmation_key: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub authenticated_content: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_before: Vec<u8>,

        #[serde(with = "hex::serde")]
        pub confirmed_transcript_hash_after: Vec<u8>,
        #[serde(with = "hex::serde")]
        pub interim_transcript_hash_after: Vec<u8>,
    }

    /// For every test vector with an available cipher suite:
    /// - checks that the commit's confirmation tag verifies against the expected confirmed
    ///   transcript hash;
    /// - checks that recomputing both transcript hashes reproduces the expected values.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn transcript_hash() {
        let test_cases: Vec<TestCase> =
            load_test_case_json!(interop_transcript_hashes, generate_test_vector());

        for test_case in test_cases {
            // Skip test cases for which no test cipher suite provider is available.
            let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
                continue;
            };

            let auth_content =
                AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap();

            assert!(auth_content.content.content_type() == ContentType::Commit);

            let conf_key = &test_case.confirmation_key;
            let conf_hash_after = test_case.confirmed_transcript_hash_after.into();
            let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap();

            let matches = conf_tag
                .matches(conf_key, &conf_hash_after, &cs)
                .await
                .unwrap();

            assert!(matches);

            // Recompute both hashes from the "before" interim hash and compare them with the
            // expected values recorded in the test vector.
            let (computed_interim, computed_conf) = transcript_hashes(
                &cs,
                &test_case.interim_transcript_hash_before.into(),
                &auth_content,
            )
            .await
            .unwrap();

            assert_eq!(*computed_interim, test_case.interim_transcript_hash_after);
            assert_eq!(computed_conf, conf_hash_after);
        }
    }

    /// Builds one test vector per cipher suite. Each vector contains:
    /// - a signed commit holding a single by-value remove proposal;
    /// - a random "before" interim hash;
    /// - the confirmed hash, confirmation tag and "after" interim hash derived from them.
    #[cfg(not(mls_build_async))]
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_test_vector() -> Vec<TestCase> {
        CipherSuite::all()
            .map(|cs| {
                let cs = test_cipher_suite_provider(cs);

                let context = get_test_group_context(0x3456, cs.cipher_suite());

                let proposal = Proposal::Remove(RemoveProposal {
                    to_remove: LeafIndex(1),
                });

                let commit = Commit {
                    proposals: vec![ProposalOrRef::Proposal(Box::new(proposal))],
                    path: None,
                };

                let signer = cs.signature_key_generate().unwrap().0;

                let mut auth_content = AuthenticatedContent::new_signed(
                    &cs,
                    &context,
                    Sender::Member(0),
                    Content::Commit(Box::new(commit)),
                    &signer,
                    WireFormat::PublicMessage,
                    vec![],
                )
                .unwrap();

                let interim_hash_before =
                    cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();

                let conf_hash_after =
                    ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content)
                        .unwrap();

                let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
                let conf_tag =
                    ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();

                let interim_hash_after =
                    InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();

                // The confirmed transcript hash input (wire format, content, signature) does not
                // cover the confirmation tag, so the tag is attached only after hashing.
                auth_content.auth.confirmation_tag = Some(conf_tag);

                TestCase {
                    cipher_suite: cs.cipher_suite().into(),

                    confirmation_key: conf_key,
                    authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
                    interim_transcript_hash_before: interim_hash_before.0,

                    confirmed_transcript_hash_after: conf_hash_after.0,
                    interim_transcript_hash_after: interim_hash_after.0,
                }
            })
            .collect()
    }

    #[cfg(mls_build_async)]
    fn generate_test_vector() -> Vec<TestCase> {
        panic!("Tests cannot be generated in async mode");
    }
}