/* Copyright 2013-2022 Simo Sorce <simo@samba.org>, see COPYING for license */

#include <errno.h>
#include <string.h>

#include <openssl/des.h>
#include <openssl/rc4.h>
#include <openssl/evp.h>
#include <openssl/rand.h>
#include <zlib.h>

#include "crypto.h"

/* legacy provider with openssl 3.0 */
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
#  include <openssl/provider.h>
#  include <openssl/crypto.h>
#endif

/* Fill random->data with random->length bytes from the OpenSSL CSPRNG.
 * Returns 0 on success, ERR_CRYPTO if RAND_bytes() does not report success. */
int RAND_BUFFER(struct ntlm_buffer *random)
{
    /* RAND_bytes() signals success only with the value 1 */
    if (RAND_bytes(random->data, random->length) == 1) {
        return 0;
    }
    return ERR_CRYPTO;
}

/* Compute HMAC-MD5 over the concatenation of all buffers in iov.
 *
 * key:    HMAC key.
 * iov:    iov->num buffers hashed in order.
 * result: must have length 16; receives the MAC.
 *
 * Returns 0 on success, EINVAL if result is not 16 bytes long, ERR_CRYPTO on
 * any OpenSSL failure. */
int HMAC_MD5_IOV(struct ntlm_buffer *key,
                 struct ntlm_iov *iov,
                 struct ntlm_buffer *result)
{
    EVP_PKEY *mac_key = NULL;
    EVP_MD_CTX *md_ctx = NULL;
    size_t n;
    int rc = ERR_CRYPTO;

    if (result->length != 16) return EINVAL;

    mac_key = EVP_PKEY_new_mac_key(EVP_PKEY_HMAC, NULL, key->data, key->length);
    if (mac_key == NULL) goto out;

    md_ctx = EVP_MD_CTX_new();
    if (md_ctx == NULL) goto out;

    if (EVP_DigestSignInit(md_ctx, NULL, EVP_md5(), NULL, mac_key) != 1) {
        goto out;
    }

    /* feed every vector element into the same MAC computation */
    for (n = 0; n < iov->num; n++) {
        if (EVP_DigestSignUpdate(md_ctx, iov->data[n]->data,
                                 iov->data[n]->length) != 1) {
            goto out;
        }
    }

    if (EVP_DigestSignFinal(md_ctx, result->data, &result->length) != 1) {
        goto out;
    }

    rc = 0;

out:
    /* both free functions accept NULL */
    EVP_MD_CTX_free(md_ctx);
    EVP_PKEY_free(mac_key);
    return rc;
}

int HMAC_MD5(struct ntlm_buffer *key,
             struct ntlm_buffer *payload,
             struct ntlm_buffer *result)
{
    struct ntlm_iov iov;

    iov.num = 1;
    iov.data = &payload;
    return HMAC_MD5_IOV(key, &iov, result);
}

#if OPENSSL_VERSION_NUMBER >= 0x30000000L
/* Private OpenSSL 3 library context, lazily created once per process and
 * loaded with both the "legacy" and "default" providers (MD4 lives in the
 * legacy provider in OpenSSL 3). */
typedef struct ossl3_library_context {
    OSSL_LIB_CTX *libctx;
    OSSL_PROVIDER *legacy_provider;   /* may be NULL if loading failed */
    OSSL_PROVIDER *default_provider;  /* may be NULL if loading failed */
} ossl3_context_t;

static pthread_once_t global_ossl3_ctx_init = PTHREAD_ONCE_INIT;
static ossl3_context_t *global_ossl3_ctx = NULL;

/* pthread_once() callback: build the library context and publish it in
 * global_ossl3_ctx. On allocation failure global_ossl3_ctx stays NULL. */
static void init_global_ossl3_ctx(void)
{
    ossl3_context_t *ctx = OPENSSL_malloc(sizeof(*ctx));
    if (!ctx) return;

    ctx->libctx = OSSL_LIB_CTX_new();
    if (!ctx->libctx) {
        OPENSSL_free(ctx);
        return;
    }

    /* Load both legacy and default provider as both may be needed.
     * If they fail keep going: an error will be raised when we try to
     * fetch the algorithm later. */
    ctx->legacy_provider = OSSL_PROVIDER_load(ctx->libctx, "legacy");
    ctx->default_provider = OSSL_PROVIDER_load(ctx->libctx, "default");
    global_ossl3_ctx = ctx;
}

/* Return the shared library context, creating it on first use.
 * Returns NULL if pthread_once() fails or initialization failed. */
static ossl3_context_t *get_ossl3_ctx(void)
{
    if (pthread_once(&global_ossl3_ctx_init, init_global_ossl3_ctx) != 0) {
        return NULL;
    }

    return global_ossl3_ctx;
}

/* Release the shared library context at library unload / process exit. */
__attribute__((destructor))
static void free_ossl3_ctx(void)
{
    ossl3_context_t *ctx = global_ossl3_ctx;
    if (ctx == NULL) return;

    /* Clear the global before freeing so code that runs after this
     * destructor sees NULL (and fails cleanly) instead of freed memory. */
    global_ossl3_ctx = NULL;

    if (ctx->legacy_provider) OSSL_PROVIDER_unload(ctx->legacy_provider);
    if (ctx->default_provider) OSSL_PROVIDER_unload(ctx->default_provider);
    if (ctx->libctx) OSSL_LIB_CTX_free(ctx->libctx);

    OPENSSL_free(ctx);
}
#endif

/* Hash payload with the 16-byte digest algorithm `type` into result.
 *
 * result must have length 16. Returns 0 on success, EINVAL on a wrong result
 * length, ERR_CRYPTO on any OpenSSL failure. Callers pass MD4/MD5, whose
 * output is 16 bytes; a larger digest would overflow result. */
static int mdx_hash(const EVP_MD *type,
                    struct ntlm_buffer *payload,
                    struct ntlm_buffer *result)
{
    EVP_MD_CTX *ctx;
    unsigned int len;
    int ret;

    if (result->length != 16) return EINVAL;

    /* EVP_MD_CTX_new() returns an already-initialized context, so no
     * additional reset/init call is required before EVP_DigestInit_ex(). */
    ctx = EVP_MD_CTX_new();
    if (!ctx) return ERR_CRYPTO;

    /* each EVP call returns 1 on success */
    if (EVP_DigestInit_ex(ctx, type, NULL) != 1 ||
        EVP_DigestUpdate(ctx, payload->data, payload->length) != 1 ||
        EVP_DigestFinal_ex(ctx, result->data, &len) != 1) {
        ret = ERR_CRYPTO;
    } else {
        ret = 0;
    }

    EVP_MD_CTX_free(ctx);
    return ret;
}

/* MD4 digest of payload into result (16 bytes).
 *
 * With OpenSSL 3 the MD4 implementation is fetched from the private library
 * context that has the legacy provider loaded; older OpenSSL exposes it via
 * EVP_md4(). Returns 0, EINVAL (bad result length) or ERR_CRYPTO. */
int MD4_HASH(struct ntlm_buffer *payload,
             struct ntlm_buffer *result)
{
#if OPENSSL_VERSION_NUMBER >= 0x30000000L
    ossl3_context_t *ossl3_ctx;
    EVP_MD *md;
    int ret;

    ossl3_ctx = get_ossl3_ctx();
    if (ossl3_ctx == NULL) return ERR_CRYPTO;

    md = EVP_MD_fetch(ossl3_ctx->libctx, "MD4", "");
    if (md == NULL) return ERR_CRYPTO;

    ret = mdx_hash(md, payload, result);

    /* EVP_MD_fetch() hands back a reference owned by the caller; release it
     * or every call leaks one. */
    EVP_MD_free(md);
    return ret;
#else
    return mdx_hash(EVP_md4(), payload, result);
#endif
}

int MD5_HASH(struct ntlm_buffer *payload,
             struct ntlm_buffer *result)
{
    return mdx_hash(EVP_md5(), payload, result);
}

/* Opaque RC4 cipher state handed out by RC4_INIT()/RC4_IMPORT() and released
 * with RC4_FREE(). It wraps the OpenSSL RC4_KEY, whose x, y and data[]
 * members RC4_EXPORT()/RC4_IMPORT() serialize directly. */
struct ntlm_rc4_handle {
    RC4_KEY key;
};

/* Allocate an RC4 cipher handle keyed with rc4_key.
 *
 * rc4_key: key material; must be non-empty.
 * mode:    unused; RC4 encryption and decryption are the same operation.
 * out:     receives the new handle on success; release it with RC4_FREE().
 *
 * Returns 0 on success, EINVAL for an empty key, ENOMEM on allocation
 * failure. */
int RC4_INIT(struct ntlm_buffer *rc4_key,
             enum ntlm_cipher_mode mode,
             struct ntlm_rc4_handle **out)
{
    struct ntlm_rc4_handle *handle;

    (void)mode;

    /* RC4_set_key() walks the key bytes with an index that wraps at the key
     * length; with a zero length it never wraps and reads past the buffer. */
    if (rc4_key->length == 0) return EINVAL;

    handle = malloc(sizeof(*handle));
    if (!handle) return ENOMEM;

    RC4_set_key(&handle->key, rc4_key->length, rc4_key->data);

    *out = handle;
    return 0;
}

/* Run in->length bytes through the RC4 stream in handle, writing to out.
 * out must be at least as large as in; on success out->length is set to
 * in->length. Returns 0 or EINVAL if out is too small. */
int RC4_UPDATE(struct ntlm_rc4_handle *handle,
               struct ntlm_buffer *in, struct ntlm_buffer *out)
{
    size_t n = in->length;

    if (n > out->length) return EINVAL;

    /* nothing to encrypt: leave the keystream untouched */
    if (n != 0) RC4(&handle->key, n, in->data, out->data);

    out->length = n;
    return 0;
}

/* Wipe and release an RC4 handle; accepts NULL or a pointer to NULL. The
 * key schedule is zeroed before the memory is released via safefree(). */
void RC4_FREE(struct ntlm_rc4_handle **handle)
{
    if (handle != NULL && *handle != NULL) {
        safezero((uint8_t *)&(*handle)->key, sizeof(RC4_KEY));
        safefree(*handle);
    }
}

/* Serialize the RC4 state in handle into out.
 *
 * Layout: x, y, then the 256-entry permutation table, each stored as a
 * native-endian RC4_INT (258 * sizeof(RC4_INT) bytes total). out must be at
 * least that large; on success out->length is set to the serialized size.
 * Returns 0 or EINVAL. */
int RC4_EXPORT(struct ntlm_rc4_handle *handle, struct ntlm_buffer *out)
{
    const size_t len = 258 * sizeof(RC4_INT);
    uint8_t *p;

    if (out->length < len) return EINVAL;

    /* out->data is a byte buffer with no alignment guarantee, so copy bytes
     * instead of storing through an RC4_INT pointer (misaligned access and
     * strict-aliasing UB). The resulting byte layout is unchanged. */
    p = out->data;
    memcpy(p, &handle->key.x, sizeof(RC4_INT));
    p += sizeof(RC4_INT);
    memcpy(p, &handle->key.y, sizeof(RC4_INT));
    p += sizeof(RC4_INT);
    memcpy(p, handle->key.data, sizeof(RC4_INT) * 256);

    out->length = len;
    return 0;
}

/* Rebuild an RC4 handle from the serialized form produced by RC4_EXPORT().
 *
 * in must be exactly 258 * sizeof(RC4_INT) bytes: x, y, then the 256-entry
 * table, native-endian. On success *_handle receives a new handle to be
 * released with RC4_FREE(). Returns 0, EINVAL (wrong length) or ENOMEM. */
int RC4_IMPORT(struct ntlm_rc4_handle **_handle, struct ntlm_buffer *in)
{
    struct ntlm_rc4_handle *handle;
    const size_t len = 258 * sizeof(RC4_INT);
    const uint8_t *p;

    if (in->length != len) return EINVAL;

    handle = malloc(sizeof(*handle));
    if (!handle) return ENOMEM;

    /* in->data may come from an arbitrary (unaligned) buffer; copy bytes
     * rather than dereferencing it as RC4_INT *. */
    p = in->data;
    memcpy(&handle->key.x, p, sizeof(RC4_INT));
    p += sizeof(RC4_INT);
    memcpy(&handle->key.y, p, sizeof(RC4_INT));
    p += sizeof(RC4_INT);
    memcpy(handle->key.data, p, sizeof(RC4_INT) * 256);

    *_handle = handle;
    return 0;
}

int RC4K(struct ntlm_buffer *key,
         enum ntlm_cipher_mode mode,
         struct ntlm_buffer *payload,
         struct ntlm_buffer *result)
{
    struct ntlm_rc4_handle *handle;
    int ret;

    if (result->length < payload->length) return EINVAL;

    ret = RC4_INIT(key, mode, &handle);
    if (ret) return ret;

    ret = RC4_UPDATE(handle, payload, result);

    RC4_FREE(&handle);
    return ret;
}

/* Single-block DES-ECB encryption of an 8-byte payload with a 7-byte key.
 *
 * The 56 key bits are spread over 8 DES key bytes, 7 bits per byte in the
 * high bits; the low (parity) bit is not set correctly, hence the unchecked
 * key setup. Returns 0, or EINVAL if any buffer has the wrong length. */
int WEAK_DES(struct ntlm_buffer *key,
             struct ntlm_buffer *payload,
             struct ntlm_buffer *result)
{
    DES_key_schedule ks;
    DES_cblock des_key;
    const uint8_t *k;
    int i;

    if (key->length != 7 || payload->length != 8 || result->length != 8) {
        return EINVAL;
    }

    /* Undocumented shuffle needed before calling DES_set_key_unchecked:
     * byte i takes the low i bits of k[i-1] followed by the top 8-i bits of
     * k[i] (truncated to 8 bits on store). */
    k = key->data;
    des_key[0] = k[0];
    for (i = 1; i < 7; i++) {
        des_key[i] = (k[i - 1] << (8 - i)) | (k[i] >> i);
    }
    des_key[7] = k[6] << 1;

    DES_set_key_unchecked(&des_key, &ks);
    DES_ecb_encrypt((DES_cblock *)payload->data,
                    (DES_cblock *)result->data, &ks, 1);
    return 0;
}

int DESL(struct ntlm_buffer *key,
         struct ntlm_buffer *payload,
         struct ntlm_buffer *result)
{
    uint8_t buf7[7];
    struct ntlm_buffer key7;
    struct ntlm_buffer res8;

    if ((key->length != 16) ||
        (payload->length != 8) ||
        (result->length != 24)) {
        return EINVAL;
    }

    /* part 1 */
    key7.data = key->data;
    key7.length = 7;
    res8.data = result->data;
    res8.length = 8;
    WEAK_DES(&key7, payload, &res8);
    /* part 2 */
    key7.data = &key->data[7];
    key7.length = 7;
    res8.data = &result->data[8];
    res8.length = 8;
    WEAK_DES(&key7, payload, &res8);
    /* part 3 */
    memcpy(buf7, &key->data[14], 2);
    memset(&buf7[2], 0, 5);
    key7.data = buf7;
    key7.length = 7;
    res8.data = &result->data[16];
    res8.length = 8;
    WEAK_DES(&key7, payload, &res8);

    return 0;
}

uint32_t CRC32(uint32_t crc, struct ntlm_buffer *payload)
{
    return crc32(crc, payload->data, payload->length);
}
