c_src/llama.cpp/src/models/wavtokenizer-dec.cpp

#include "models.h"

// Loads the architecture-specific hyper-parameters: asks the loader `ml` for
// the layer-norm epsilon, the group-norm epsilon and the group-norm group
// count, and stores each value in the matching `hparams` field.
void llama_model_wavtokenizer_dec::load_arch_hparams(llama_model_loader & ml) {
    auto & hp = hparams;

    // layer norm
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hp.f_norm_eps);

    // group norm
    ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS,    hp.f_norm_group_eps);
    ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hp.n_norm_groups);
}

// Creates every weight tensor of the decoder through create_tensor(), in this
// order: token embedding, input conv1d, posnet layers, token-embedding norm,
// convnext layers, output norm, output projection.
//
// The loader parameter is unnamed; the helpers used below (tn, create_tensor,
// n_vocab, n_ff) are presumably brought into scope by LLAMA_LOAD_LOCALS, which
// is defined outside this file.
void llama_model_wavtokenizer_dec::load_arch_tensors(llama_model_loader &) {
    LLAMA_LOAD_LOCALS;

    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0);

    // input convolution (kernel size 7 in the first dimension)
    conv1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0);
    conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias",   0), {1, hparams.posnet.n_embd}, 0);

    // posnet
    {
        const int64_t n_ch = hparams.posnet.n_embd;

        for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
            auto & blk = layers[il].posnet;

            // the layer kind is fixed by its position:
            //
            //   0, 1 -> resnet
            //   2    -> attn
            //   3, 4 -> resnet
            //   5    -> norm
            //
            // any other index is rejected
            const bool is_attn   = il == 2;
            const bool is_norm   = il == 5;
            const bool is_resnet = il < 5 && !is_attn;

            if (is_resnet) {
                blk.norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", il), {1, n_ch}, 0);
                blk.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   il), {1, n_ch}, 0);

                blk.conv1   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", il), {3, n_ch, n_ch}, 0);
                blk.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias",   il), {1, n_ch}, 0);

                blk.norm2   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", il), {1, n_ch}, 0);
                blk.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias",   il), {1, n_ch}, 0);

                blk.conv2   = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", il), {3, n_ch, n_ch}, 0);
                blk.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias",   il), {1, n_ch}, 0);
            } else if (is_attn) {
                blk.attn_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", il), {1, n_ch}, 0);
                blk.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   il), {1, n_ch}, 0);

                blk.attn_q      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "weight", il), {1, n_ch, n_ch}, 0);
                blk.attn_q_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q,    "bias",   il), {1, n_ch}, 0);

                blk.attn_k      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "weight", il), {1, n_ch, n_ch}, 0);
                blk.attn_k_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K,    "bias",   il), {1, n_ch}, 0);

                blk.attn_v      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "weight", il), {1, n_ch, n_ch}, 0);
                blk.attn_v_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V,    "bias",   il), {1, n_ch}, 0);

                blk.attn_o      = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "weight", il), {1, n_ch, n_ch}, 0);
                blk.attn_o_b    = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT,  "bias",   il), {1, n_ch}, 0);
            } else if (is_norm) {
                // the closing norm is looked up under the ATTN_NORM tensor id
                blk.norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", il), {1, n_ch}, 0);
                blk.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   il), {1, n_ch}, 0);
            } else {
                GGML_ABORT("unknown posnet layer");
            }
        }
    }

    // tok_norm below is sized with the posnet width, so both widths must agree
    GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd);

    tok_norm   = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0);
    tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias",   0), {hparams.posnet.n_embd}, 0);

    // convnext
    {
        const int64_t n_ch = hparams.convnext.n_embd;

        for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
            auto & blk = layers[il].convnext;

            // "dw" convolution
            blk.dw     = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "weight", il), {7, 1, n_ch}, 0);
            blk.dw_b   = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW,    "bias",   il), {1, n_ch}, 0);

            blk.norm   = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "weight", il), {n_ch}, 0);
            blk.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM,  "bias",   il), {n_ch}, 0);

            // "pw1"/"pw2" pair: n_ch -> n_ff -> n_ch
            blk.pw1    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "weight", il), {n_ch, n_ff}, 0);
            blk.pw1_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1,   "bias",   il), {n_ff}, 0);

            blk.pw2    = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "weight", il), {n_ff, n_ch}, 0);
            blk.pw2_b  = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2,   "bias",   il), {n_ch}, 0);

            blk.gamma  = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", il), {n_ch}, 0);
        }

        // output
        output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_ch}, 0);
        output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_ch}, 0);
    }

    // output projection: convnext width -> hparams.n_embd_out()
    output   = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0);
    output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"),   {hparams.n_embd_out()}, 0);
}

// Constructs the architecture's `graph` from this model and `params` and
// returns it owned through a pointer to its llm_graph_context base.
std::unique_ptr<llm_graph_context> llama_model_wavtokenizer_dec::build_arch_graph(const llm_graph_params & params) const {
    std::unique_ptr<llm_graph_context> built = std::make_unique<graph>(*this, params);
    return built;
}

// Builds the forward graph of the decoder.
//
// Pipeline, as wired below:
//
//   input embeddings -> conv1d -> posnet layers -> tok_norm
//                    -> convnext layers -> output_norm -> output projection + bias
//
// The final tensor is reported through cb() as "result_embd", stored in
// res->t_embd and expanded into the graph gf.
//
// model  - source of all weight tensors used here
// params - forwarded unchanged to the llm_graph_context base
llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
    ggml_tensor * inpL = build_inp_embd(model.tok_embd);

    // the convolutions below operate on the transposed embeddings
    ggml_tensor * cur = ggml_cont(ctx0, ggml_transpose(ctx0, inpL));

    cur = ggml_conv_1d_ph(ctx0, model.conv1d, cur, 1, 1);
    cur = ggml_add(ctx0, cur, model.conv1d_b);

    // posnet
    for (uint32_t il = 0; il < hparams.posnet.n_layer; ++il) {
        const auto & layer = model.layers[il].posnet;

        // keep the layer input for the residual additions below
        inpL = cur;

        // the layer kind is fixed by its position: resnet, resnet, attn, resnet, resnet, norm
        switch (il) {
            case 0:
            case 1:
            case 3:
            case 4:
                {
                    // resnet: (group norm -> x * sigmoid(x) -> conv) twice, then residual
                    cur = build_norm(cur,
                            layer.norm1,
                            layer.norm1_b,
                            LLM_NORM_GROUP, 0);

                    cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);

                    cur = ggml_conv_1d_ph(ctx0, layer.conv1, cur, 1, 1);
                    cur = ggml_add(ctx0, cur, layer.conv1_b);

                    cur = build_norm(cur,
                            layer.norm2,
                            layer.norm2_b,
                            LLM_NORM_GROUP, 0);

                    cur = ggml_mul(ctx0, ggml_sigmoid(ctx0, cur), cur);

                    cur = ggml_conv_1d_ph(ctx0, layer.conv2, cur, 1, 1);
                    cur = ggml_add(ctx0, cur, layer.conv2_b);

                    cur = ggml_add(ctx0, cur, inpL);
                } break;
            case 2:
                {
                    // attn: group norm, then q/k/v projections (conv + bias)
                    cur = build_norm(cur,
                            layer.attn_norm,
                            layer.attn_norm_b,
                            LLM_NORM_GROUP, 0);

                    ggml_tensor * q = ggml_conv_1d_ph(ctx0, layer.attn_q, cur, 1, 1);
                    ggml_tensor * k = ggml_conv_1d_ph(ctx0, layer.attn_k, cur, 1, 1);
                    ggml_tensor * v = ggml_conv_1d_ph(ctx0, layer.attn_v, cur, 1, 1);

                    q = ggml_add(ctx0, q, layer.attn_q_b);
                    k = ggml_add(ctx0, k, layer.attn_k_b);
                    v = ggml_add(ctx0, v, layer.attn_v_b);

                    // q and k are transposed before the product; v is used as is
                    q = ggml_cont(ctx0, ggml_transpose(ctx0, q));
                    k = ggml_cont(ctx0, ggml_transpose(ctx0, k));

                    ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

                    // softmax with no mask (nullptr), scaled by 1/sqrt(posnet.n_embd)
                    kq = ggml_soft_max_ext(ctx0, kq, nullptr, 1.0f/sqrtf(float(hparams.posnet.n_embd)), 0.0f);

                    cur = ggml_mul_mat(ctx0, kq, v);

                    // output projection, then residual
                    cur = ggml_conv_1d_ph(ctx0, layer.attn_o, cur, 1, 1);
                    cur = ggml_add(ctx0, cur, layer.attn_o_b);

                    cur = ggml_add(ctx0, cur, inpL);
                } break;
            case 5:
                {
                    // closing group norm, no residual
                    cur = build_norm(cur,
                            layer.norm,
                            layer.norm_b,
                            LLM_NORM_GROUP, 0);
                } break;
            default: GGML_ABORT("unknown posnet layer");
        }
    }

    // tok_norm is applied on the transposed tensor, which is transposed back afterwards
    cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));

    // tok_norm is a model-level tensor (not part of model.layers[]), so it gets
    // the same -1 tag as output_norm below instead of being attributed to layer 0.
    // NOTE(review): the trailing build_norm argument is presumably the layer id
    // forwarded to the graph callback - confirm against build_norm.
    cur = build_norm(cur,
            model.tok_norm,
            model.tok_norm_b,
            LLM_NORM, -1);

    cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));

    inpL = cur;

    // convnext
    for (uint32_t il = 0; il < hparams.convnext.n_layer; ++il) {
        const auto & layer = model.layers[il].convnext;

        // inpL holds the running residual stream across convnext layers
        cur = inpL;

        cur = ggml_conv_1d_dw_ph(ctx0, layer.dw, cur, 1, 1);
        cur = ggml_add(ctx0, cur, layer.dw_b);

        // norm, ffn and gamma scaling run on the transposed tensor
        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));

        cur = build_norm(cur,
                layer.norm,
                layer.norm_b,
                LLM_NORM, -1);

        // pw1 -> GELU -> pw2, no gate branch
        cur = build_ffn(cur,
                layer.pw1, layer.pw1_b, nullptr,
                nullptr,   nullptr,     nullptr,
                layer.pw2, layer.pw2_b, nullptr,
                nullptr,
                LLM_FFN_GELU, LLM_FFN_SEQ, il);

        cur = ggml_mul(ctx0, cur, layer.gamma);

        // transpose back before adding to the residual stream
        cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));

        inpL = ggml_add(ctx0, cur, inpL);
    }
    cur = inpL;

    cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));

    cur = build_norm(cur,
            model.output_norm,
            model.output_norm_b,
            LLM_NORM, -1);

    // lm_head
    cur = build_lora_mm(model.output, cur);

    cur = ggml_add(ctx0, cur, model.output_b);

    cb(cur, "result_embd", -1);
    res->t_embd = cur;

    ggml_build_forward_expand(gf, cur);
}