20 April, 2025

Introducing: ggml-easy

A simple and easy-to-use wrapper for working with GGML

Table of contents
  • Motivation
  • Simple demo
  • Debugging
  • Safetensors support
  • Advanced demo
  • Conclusion

In one of my previous articles, I introduced ggml and its capabilities. However, I found its API a bit cumbersome to work with, so I decided to create a simple wrapper around it.

This wrapper is called ggml-easy and is available on GitHub. It is a simple, header-only, easy-to-use wrapper for working with GGML.

Motivation

I started working with GGML because I wanted to bring more multimodal capabilities to llama.cpp. If you haven't noticed, I also wrote a blog post about my work on vision models. However, to support more modalities like audio (input/output), I needed a more efficient way to rapidly build and debug prototypes. This is where ggml-easy comes in.

The main goals of ggml-easy are:

  • Simplifying usage: Offering a straightforward API to reduce boilerplate code.
  • Facilitating debugging: Providing user-friendly debugging utilities.
  • Maintaining compatibility: Mirroring ggml's underlying operations, allowing seamless code transfer to projects like llama.cpp.

Simple demo

Going back to the simple demo of multiplying two matrices, you can see the original code without ggml-easy here.

With ggml-easy, these steps are required to perform the same operation:

  1. Create a context
  2. Build a computation graph
  3. Set the data for the input tensors
  4. Compute the graph
  5. Get the result

Here is the code:

#include "ggml.h"
#include "ggml-easy.h" // header-only

#include <cstdint> // for uint8_t
#include <vector>  // for std::vector

int main() {
    // 1. Create a context
    ggml_easy::ctx_params params;
    // use GPU or accelerator if available, fallback to CPU otherwise
    params.use_gpu = true;
    ggml_easy::ctx ctx(params);

    // initialize data of matrices to perform matrix multiplication
    const int rows_A = 4, cols_A = 2;
    float matrix_A[rows_A * cols_A] = {
        2, 8,
        5, 1,
        4, 2,
        8, 6
    };
    const int rows_B = 3, cols_B = 2;
    float matrix_B[rows_B * cols_B] = {
        10, 5,
        9, 9,
        5, 4
    };

    // 2. Build a computation graph
    ctx.build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf, auto & utils) {
        // note: in ggml, the first dimension of a shape is the innermost one,
        // so "a" is created with ne = [cols_A, rows_A]
        ggml_tensor * a = utils.new_input("a", GGML_TYPE_F32, cols_A, rows_A);
        ggml_tensor * b = utils.new_input("b", GGML_TYPE_F32, cols_B, rows_B);
        // ggml_mul_mat contracts over the shared first dimension (here, 2);
        // the result has ne = [rows_A, rows_B], i.e. the printed output is B * A^T
        ggml_tensor * result = ggml_mul_mat(ctx_gf, a, b);
        utils.mark_output(result, "result");
    });

    // 3. Set the data for the input tensors
    ctx.set_tensor_data("a", matrix_A);
    ctx.set_tensor_data("b", matrix_B);

    // 4. Compute the graph
    ggml_status status = ctx.compute();
    GGML_ASSERT(status == GGML_STATUS_SUCCESS);

    // 5. Get the result
    auto result = ctx.get_tensor_data("result");
    ggml_tensor * result_tensor        = result.first;
    std::vector<uint8_t> & result_data = result.second;

    // print result
    ggml_easy::debug::print_tensor_data(result_tensor, result_data.data());

    return 0;
}

Output:

result.data: [
     [
      [     60.0000,      55.0000,      50.0000,     110.0000],
      [     90.0000,      54.0000,      54.0000,     126.0000],
      [     42.0000,      29.0000,      28.0000,      64.0000],
     ],
    ]

Debugging

When working with a more complicated computation graph, a common need is to print the intermediate results of the computation. For example, in PyTorch, you can simply use print:

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = torch.matmul(a, b)
print(c) # intermediate result
d = c * 2
print(d) # final result

In ggml, without ggml-easy, you would have to manually call ggml_backend_tensor_get to retrieve the intermediate result, then print it. This is cumbersome and not very user-friendly.
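For reference, here is roughly what that manual retrieval looks like in plain ggml. This is only a sketch: it assumes the tensor holds F32 data, that you still have the a_mul_b pointer around, and that the graph has already been computed.

std::vector<float> buf(ggml_nelements(a_mul_b));
// copy the tensor data from the backend buffer into host memory
ggml_backend_tensor_get(a_mul_b, buf.data(), 0, ggml_nbytes(a_mul_b));
// then print it yourself, one row (ne[0] elements) per line
for (int64_t i = 0; i < ggml_nelements(a_mul_b); i++) {
    printf("%10.4f%s", buf[i], ((i + 1) % a_mul_b->ne[0] == 0) ? "\n" : " ");
}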

With ggml-easy, you can simply use the utils.debug_print function to "mark" the tensor you want to print. After the computation is done, ggml-easy will automatically print the tensor for you, with no further action needed.

For example, in the code above, you can simply add:

// create cgraph
ctx.build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf, auto & utils) {
    ggml_tensor * a = utils.new_input("a", GGML_TYPE_F32, cols_A, rows_A);
    ggml_tensor * b = utils.new_input("b", GGML_TYPE_F32, cols_B, rows_B);
    // the intermediate result
    ggml_tensor * a_mul_b = ggml_mul_mat(ctx_gf, a, b);
    // mark tensor to print
    utils.debug_print(a_mul_b, "a_mul_b");
    // the final result
    ggml_tensor * result = ggml_scale(ctx_gf, a_mul_b, 2.0f);
    utils.mark_output(result, "result");
});

Output:

a_mul_b.shape = [4, 3]
a_mul_b.data: [
     [
      [     60.0000,      55.0000,      50.0000,     110.0000],
      [     90.0000,      54.0000,      54.0000,     126.0000],
      [     42.0000,      29.0000,      28.0000,      64.0000],
     ],
    ]
result.data: [
     [
      [    120.0000,     110.0000,     100.0000,     220.0000],
      [    180.0000,     108.0000,     108.0000,     252.0000],
      [     84.0000,      58.0000,      56.0000,     128.0000],
     ],
    ]

Safetensors support

One of the most important features of ggml-easy is the native support for loading safetensors. This means you are no longer required to convert your models to GGUF format.

Currently, F32, F16 and BF16 types are supported.

You can load the safetensors file directly and use it in your code:

ggml_easy::ctx_params params;
ggml_easy::ctx ctx(params);
ctx.load_safetensors("mimi.safetensors", {
    // optionally, rename tensor to make it shorter (name length limit in ggml is 64 characters)
    {".acoustic_residual_vector_quantizer", ".acoustic_rvq"},
    {".semantic_residual_vector_quantizer", ".semantic_rvq"},
});

You can then query a tensor by its name:

// example with a simple name
ggml_tensor * inp_norm_b = ctx.get_weight("input_layernorm.bias");

// example with a formatted name (printf-style)
const char * prefix = "encoder";
int il = 0;
ggml_tensor * layer_norm_b = ctx.get_weight("%s.layer.%d.input_layernorm.bias", prefix, il);
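To give an idea of how a loaded weight is used in practice, here is a minimal sketch that applies a layernorm with weights queried by name inside a graph. The tensor names and embedding size are made up for illustration, and I assume new_input also accepts a single dimension; ggml_norm, ggml_mul and ggml_add are standard ggml ops.

const int n_embd = 512; // hypothetical embedding size
ctx.build_graph([&](ggml_context * ctx_gf, ggml_cgraph * gf, auto & utils) {
    ggml_tensor * inp  = utils.new_input("inp", GGML_TYPE_F32, n_embd);
    // weights loaded from the safetensors file (names are examples)
    ggml_tensor * ln_w = ctx.get_weight("input_layernorm.weight");
    ggml_tensor * ln_b = ctx.get_weight("input_layernorm.bias");
    // layernorm: normalize, then scale and shift
    ggml_tensor * cur = ggml_norm(ctx_gf, inp, 1e-5f);
    cur = ggml_add(ctx_gf, ggml_mul(ctx_gf, cur, ln_w), ln_b);
    utils.mark_output(cur, "out");
});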

For a complete example, please have a look at demo/safetensors.cpp, where I load both GGUF and safetensors files, then compare them.

Advanced demo

You can find more advanced demos here, including:

  • svd.cpp: Singular Value Decomposition (SVD) using GGML.
  • dyt-rms.cpp: Performance comparison of Dynamic Tanh (DyT) and RMS Norm; see the sketch after this list.
  • kyutai-mimi.cpp: The Kyutai Mimi model using GGML. It works out of the box with the safetensors file, no conversion needed, and powers the Sesame CSM implementation in llama.cpp.
  • whisper-encoder.cpp: Reimplementation of the Whisper encoder using GGML. This is a work in progress at the time of writing; the goal is to do some research on ultravox support in llama.cpp.
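To give a taste of the DyT vs RMS Norm comparison, here is a rough sketch of how the two operations can be expressed in a graph. This is not the actual demo code: x, gamma and beta are assumed to be tensors already in the graph, and DyT follows the gamma * tanh(alpha * x) + beta formulation from the Dynamic Tanh paper.

// RMS Norm: divide x by its root mean square (a per-row reduction),
// then multiply by a learned weight
ggml_tensor * rms = ggml_rms_norm(ctx_gf, x, 1e-6f);
rms = ggml_mul(ctx_gf, rms, gamma);

// Dynamic Tanh (DyT): purely element-wise, no reduction needed
const float alpha = 0.5f; // a learnable scalar in the real implementation
ggml_tensor * dyt = ggml_tanh(ctx_gf, ggml_scale(ctx_gf, x, alpha));
dyt = ggml_add(ctx_gf, ggml_mul(ctx_gf, dyt, gamma), beta);

The absence of a per-row reduction is what makes DyT interesting from a performance point of view, which is what the demo measures.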

Conclusion

In this article, I introduced ggml-easy, a simple and easy-to-use wrapper for working with GGML. It simplifies the API, provides debugging utilities, and supports safetensors natively.

I hope you find it useful and I would love to hear your feedback. You can find the code at ngxson/ggml-easy.
