Conversation

@phymbert phymbert commented Mar 20, 2024

Motivation

Since we added the gguf-split CLI in #6135, it is a good time to support loading a model whose weights are stored in multiple (potentially distributed) GGUFs. For example, we can expect the Grok-1 weights not to fit easily inside a single GGUF.

This change allows loading a model regardless of whether it is bundled in a single GGUF or in multiple GGUFs generated with gguf-split.

Changes

  • each file is memory-mapped to a distinct address, so tensors are no longer contiguous in memory
  • backends that support mmap, like CPU and Metal, now have a different backend buffer for each file
  • introduce llama_split_path and llama_split_prefix to allow downstream tools to generate their own GGUF splits using the same file name convention: "%s-%05d-of-%05d.gguf" (see the sketch after this list)
  • rename the GGUF KVs general.split to split.no and general.split_count to split.count, and add split.tensors.count: splits created with the previous keys will not be loaded here; use gguf-split from d0d5de4 to merge them first, then split again with the master version
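
As a reference for downstream tools, here is a minimal sketch (not code from this PR; the prefix and split count below are made up) that derives shard file names from the convention above. llama_split_path and llama_split_prefix wrap this logic in llama.h; their exact signatures are not reproduced here.

// Minimal sketch, relying only on the stated convention "%s-%05d-of-%05d.gguf".
// The prefix and split count are hypothetical.
#include <cstdio>

int main() {
    const char * prefix      = "ggml-model-q4_0-split"; // hypothetical prefix
    const int    split_count = 6;                       // hypothetical count

    char split_path[512];
    for (int split_no = 1; split_no <= split_count; ++split_no) {
        snprintf(split_path, sizeof(split_path), "%s-%05d-of-%05d.gguf",
                 prefix, split_no, split_count);
        printf("%s\n", split_path); // e.g. ggml-model-q4_0-split-00001-of-00006.gguf
    }
    return 0;
}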

Tests

  1. Download
cd models
../scripts/hf.sh --repo ggml-org/models --file phi-2/ggml-model-q4_0.gguf
  2. Split
gguf-split --split --split-max-tensors 64 models/ggml-model-q4_0.gguf ggml-model-q4_0-split
  3. Load (a minimal C API sketch follows these steps)
main --model models/ggml-model-q4_0-split-00001-of-00006.gguf -ngl 33 --random-prompt

You will notice the new log line: llama_model_loader: additional 6 GGUFs metadata loaded.

  4. Merge it back (no longer necessary for loading)
gguf-split --merge models/ggml-model-q4_0-split-00001-of-00006.gguf models/ggml-model-q4_0-merge.gguf
  5. Confirm a single GGUF still works
main --model models/ggml-model-q4_0-merge.gguf -ngl 33 --random-prompt
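
For step 3, a minimal sketch of the equivalent programmatic load (assuming the llama.h API of the time; error handling trimmed). Only the first shard's path is given; the loader locates the other shards via the split.* metadata and the file name convention.

// Minimal sketch (assumed llama.h API at the time of this PR).
#include "llama.h"
#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.n_gpu_layers = 33; // same as -ngl 33 above

    llama_model * model = llama_load_model_from_file(
        "models/ggml-model-q4_0-split-00001-of-00006.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "main: error: unable to load model\n");
        return 1;
    }

    // ... create a llama_context and run inference as usual ...

    llama_free_model(model);
    llama_backend_free();
    return 0;
}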

Tasks

  • works on CPU backend
  • works on CUDA backend with all layers offloaded
  • works on CUDA backend with half of the layers offloaded
  • works on Metal

Special thanks to @slaren and @ngxson for supporting me in this effort.

@phymbert phymbert requested review from ggerganov, ngxson and slaren March 20, 2024 21:31
 - use only one gguf_context, for metadata only
 - store all ggml_context in a vector, as for the files and mappings
 - store all weights in a vector along with their source tensor
 - rename ctx_gguf to meta
 - rename ctx_meta to contexts (a simplified sketch of the resulting layout follows these notes)
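
A simplified sketch of the loader layout these notes describe; struct and member names other than meta and contexts are assumed for illustration, not copied from the code.

// Simplified sketch of the loader layout described in the notes above.
#include <cstddef>
#include <cstdint>
#include <vector>

struct gguf_context;   // ggml
struct ggml_context;   // ggml
struct ggml_tensor;    // ggml
struct llama_file;     // llama.cpp internals (assumed name)
struct llama_mmap;     // llama.cpp internals (assumed name)

struct tensor_weight {                      // hypothetical name
    uint16_t      idx;                      // which split file the tensor lives in
    size_t        offs;                     // data offset within that file
    ggml_tensor * tensor;                   // tensor from that split's metadata
};

struct loader_state {                       // hypothetical name
    gguf_context *              meta;       // single gguf_context, metadata only
    std::vector<llama_file *>   files;      // one entry per split
    std::vector<llama_mmap *>   mappings;   // one mmap per split
    std::vector<ggml_context *> contexts;   // one ggml_context per split
    std::vector<tensor_weight>  weights;    // all weights, with their source split
};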
@phymbert phymbert requested review from ggerganov, ngxson and slaren March 21, 2024 18:12

@ggerganov ggerganov left a comment


There is something not right when mmap is enabled and -ngl > 0 (at least with Metal that is). Using LLaMA 7B Q4_0:

llm_load_tensors: ggml ctx size =    0.22 MiB
ggml_backend_metal_buffer_from_ptr: allocated buffer, size =   933.69 MiB, (  933.75 / 147456.00)
ggml_backend_metal_buffer_from_ptr: allocated buffer, size =   933.69 MiB, ( 1867.44 / 147456.00)
ggml_backend_metal_buffer_from_ptr: allocated buffer, size =   933.69 MiB, ( 2801.12 / 147456.00)
ggml_backend_metal_buffer_from_ptr: error: failed to allocate buffer, size =   933.69 MiB
ggml_backend_metal_buffer_from_ptr: error: failed to allocate buffer, size =   933.69 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU
llm_load_tensors: offloaded 33/33 layers to GPU
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:        CPU buffer size =    70.31 MiB
llm_load_tensors:      Metal buffer size =   933.69 MiB
llm_load_tensors:      Metal buffer size =   933.69 MiB
llm_load_tensors:      Metal buffer size =   933.69 MiB
.................................................................llama_model_load: error loading model: vector
llama_load_model_from_file: failed to load model
llama_init_from_gpt_params: error: failed to load model './x-00001-of-00005.gguf'
main: error: unable to load model

phymbert and others added 2 commits March 21, 2024 20:50
…nsor optional

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
@phymbert

ggml_backend_metal_buffer_from_ptr: allocated buffer, size = 933.69 MiB, ( 933.75 / 147456.00)
ggml_backend_metal_buffer_from_ptr: allocated buffer, size = 933.69 MiB, ( 1867.44 / 147456.00)
ggml_backend_metal_buffer_from_ptr: allocated buffer, size = 933.69 MiB, ( 2801.12 / 147456.00)
ggml_backend_metal_buffer_from_ptr: error: failed to allocate buffer, size = 933.69 MiB
ggml_backend_metal_buffer_from_ptr: error: failed to allocate buffer, size = 933.69 MiB
llm_load_tensors: offloading 32 repeating layers to GPU
llm_load_tensors: offloading non-repeating layers to GPU

So it does not manage to allocate 2 out of 5 Metal buffers; I think we should stop there. Then it tries to load the mapping into a buffer which does not exist.
@slaren Do you know why ggml_backend_metal_buffer_from_ptr can fail when ggml_backend_cpu_buffer_from_ptr succeeded just before?

phymbert commented Mar 21, 2024

So it does not manage to allocate 2 out of 5 Metal buffers; I think we should stop there. Then it tries to load the mapping into a buffer which does not exist. @slaren Do you know why ggml_backend_metal_buffer_from_ptr can fail when ggml_backend_cpu_buffer_from_ptr succeeded just before?

@ggerganov Should we accept the case where we cannot allocate all n_split Metal buffers? It means the Metal backend will not have all the weights loaded; is that OK?

slaren commented Mar 21, 2024

I think there is something wrong. It should only require one CPU buffer, since there is only one tensor allocated on the CPU.

The first-last range is probably wrong, and it is causing a buffer overflow.

slaren commented Mar 21, 2024

This should fix it:

diff --git a/llama.cpp b/llama.cpp
index cd20ad7a..2b6a5e9e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3199,6 +3199,9 @@ struct llama_model_loader {
         *addr = mapping->addr;
         for (ggml_tensor * tensor = ggml_get_first_tensor(ctx); tensor; tensor = ggml_get_next_tensor(ctx, tensor)) {
             const auto & w = get_weights(ggml_get_name(tensor));
+            if (w.idx != idx) {
+                continue;
+            }
             *first = std::min(*first, w.offs);
             *last  = std::max(*last, w.offs + ggml_nbytes(tensor));
         }
@@ -5145,6 +5148,9 @@ static bool llm_load_tensors(
                 void * addr = nullptr;
                 size_t first, last;
                 ml.get_mapping_range(&first, &last, &addr, file_no, ctx);
+                if (first >= last) {
+                    continue;
+                }
                 ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr((char *)addr + first, last - first);
                 if (buf != nullptr) {
                     bufs.push_back(buf);
@@ -5167,6 +5173,9 @@ static bool llm_load_tensors(
                 void * addr = nullptr;
                 size_t first, last;
                 ml.get_mapping_range(&first, &last, &addr, file_no, ctx);
+                if (first >= last) {
+                    continue;
+                }
                 ggml_backend_buffer_t buf = ggml_backend_metal_buffer_from_ptr((char *) addr + first, last - first, max_size);
                 if (buf != nullptr) {
                     bufs.push_back(buf);

Maybe we need to add a dummy NULL buffer in this case so that it does not mess with the indices of the vector?
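
A toy illustration of the indexing concern (not code from the PR): if files that contribute no tensors to a backend are skipped entirely, bufs[file_no] no longer refers to that file's buffer, whereas pushing a placeholder keeps the indices aligned.

// Toy example: keep one (possibly null) buffer entry per split file so that
// bufs[file_no] always corresponds to file_no.
#include <cstdio>
#include <vector>

int main() {
    const int n_files = 5;
    // pretend the last two files have no tensors on this backend (empty ranges)
    const bool has_range[n_files] = { true, true, true, false, false };

    std::vector<const char *> bufs;
    for (int file_no = 0; file_no < n_files; ++file_no) {
        if (!has_range[file_no]) {
            bufs.push_back(nullptr);   // placeholder keeps indices aligned
            continue;
        }
        bufs.push_back("buffer");      // stand-in for a real backend buffer
    }

    printf("bufs.size() = %zu (one entry per file)\n", bufs.size());
    return 0;
}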

@phymbert

@ggerganov can you please pull and retry? I have applied the same logic as before: if the allocation fails, fall back to CPU only.

slaren commented Mar 21, 2024

@phymbert the logic is still wrong, it is asking Metal to map a buffer beyond its size. Please check the diff I posted above.

@ggerganov

It works with the patch. Will take an extra look at the PR tomorrow. Thank you all for helping out

@phymbert phymbert changed the title llama_model_loader: support multiple split GGUFs llama_model_loader: support multiple split/shard GGUFs Mar 22, 2024
…t dest max len.

Co-authored-by: Xuan Son Nguyen <thichthat@gmail.com>

@ggerganov ggerganov left a comment


We can merge after slaren's approval

@phymbert

We can merge after slaren's approval

Excellent, really proud to have contributed all the way up to llama.h! Thanks all for your help, guidance, and co-authoring of this feature 🧑‍🤝‍🧑👩🏽‍💻

@phymbert

@ggerganov Could we add this line to the Hot topics?

- support loading sharded model split using `gguf-split` cli #6187

@ggerganov

Yes, of course

@ngxson ngxson left a comment


LGTM! Thanks for taking the time to implement this functionality.

phymbert and others added 2 commits March 22, 2024 14:44
@phymbert phymbert merged commit dba1af6 into master Mar 22, 2024
@phymbert phymbert deleted the hp/split/load-model branch March 22, 2024 18:00
// check if dest ends with postfix
int size_prefix = str_split_path.size() - str_postfix.size();
if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) {
snprintf(dest, std::min((size_t) size_prefix, maxlen), "%s", split_path);

@phymbert phymbert Mar 23, 2024


@ngxson It must be snprintf(dest, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path);
I am fixing it in #6192
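
A toy demonstration of the off-by-one (not code from the PR): snprintf's size argument counts the terminating null, so it writes at most size - 1 characters; copying a prefix of length size_prefix therefore needs size_prefix + 1.

// Toy example of the snprintf size semantics behind the fix above.
#include <cstdio>

int main() {
    const char split_path[] = "ggml-model-q4_0-split-00001-of-00006.gguf";
    const int  size_prefix  = 21;   // length of the prefix "ggml-model-q4_0-split"
    char dest[64];

    snprintf(dest, size_prefix, "%s", split_path);      // one char short
    printf("%s\n", dest);                               // ggml-model-q4_0-spli
    snprintf(dest, size_prefix + 1, "%s", split_path);  // full prefix
    printf("%s\n", dest);                               // ggml-model-q4_0-split
    return 0;
}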

Collaborator

Yeah, sorry, I was in quite a rush this time; I will be more careful. Thanks!
