mirror of https://github.com/leafspark/AutoGGUF
chore: update for new year and improve compliance
- updated copyright year in LICENSE file to 2025
- bundled llama.cpp licensing text in About menu to maintain MIT compliance
- updated the bundled llama.cpp conversion scripts and gguf Python library
- adjusted monitoring intervals from 0.2s to 0.5s
- updated Python requirements to latest compatible versions
- added new HF to GGUF conversion types: `tq1_0` and `tq2_0`
Happy New Year 🎉!
parent ddbf96c8e9 · commit 102e3a14fd
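The two user-facing changes are small: the system/GPU monitoring timers now fire every 0.5 s instead of 0.2 s, and the HF-to-GGUF output-type selector gains the ternary formats. A minimal PySide6 sketch of both changes (widget wiring simplified for illustration; the real slots, constants, and layouts live in AutoGGUF's main window):

    # sketch only: illustrates the new timer interval and output-type list
    from PySide6.QtCore import QTimer
    from PySide6.QtWidgets import QApplication, QComboBox, QWidget

    app = QApplication([])
    window = QWidget()

    # monitoring now refreshes every 500 ms (was 200 ms), reducing polling overhead
    timer = QTimer(window)
    timer.timeout.connect(lambda: None)  # stand-in for the real update_system_info slot
    timer.start(500)

    # the HF -> GGUF output-type dropdown now includes tq1_0 and tq2_0
    outtype = QComboBox(window)
    outtype.addItems(["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"])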
LICENSE

@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.
 
-   Copyright 2024 leafspark
+   Copyright (c) 2024-2025 leafspark
 
    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.

@@ -1,13 +1,13 @@
 PyYAML~=6.0.2
-psutil~=6.1.0
+psutil~=6.1.1
 pynvml~=12.0.0
 PySide6~=6.8.1
-safetensors~=0.4.5
+safetensors~=0.5.0
 numpy<2.0.0
 torch~=2.5.1
 sentencepiece~=0.2.0
-setuptools~=75.5.0
-huggingface-hub~=0.26.5
-transformers~=4.47.0
+setuptools~=75.6.0
+huggingface-hub~=0.27.0
+transformers~=4.47.1
 fastapi~=0.115.6
 uvicorn~=0.34.0

@@ -500,7 +500,7 @@ def __init__(self, args: List[str]) -> None:
         # Timer for updating system info
         self.timer = QTimer()
         self.timer.timeout.connect(self.update_system_info)
-        self.timer.start(200)
+        self.timer.start(500)
 
         # Backend selection
         backend_layout = QHBoxLayout()

@@ -1023,7 +1023,9 @@ def __init__(self, args: List[str]) -> None:
         hf_to_gguf_layout.addRow(OUTPUT_FILE, hf_outfile_layout)
 
         self.hf_outtype = QComboBox()
-        self.hf_outtype.addItems(["f32", "f16", "bf16", "q8_0", "auto"])
+        self.hf_outtype.addItems(
+            ["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"]
+        )
         hf_to_gguf_layout.addRow(OUTPUT_TYPE, self.hf_outtype)
 
         self.hf_vocab_only = QCheckBox(VOCAB_ONLY)

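For context, tq1_0 and tq2_0 are llama.cpp's ternary quantization types (roughly 1.69 and 2.06 bits per weight), intended for BitNet-style ternary models. The dropdown value is presumably forwarded to the bundled convert_hf_to_gguf.py as its --outtype flag, so the equivalent command-line conversion would look something like this (paths are illustrative):

    python convert_hf_to_gguf.py /path/to/hf-model --outtype tq1_0 --outfile model-tq1_0.gguf
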
@@ -95,7 +95,7 @@ def __init__(self, parent=None) -> None:
 
         self.timer = QTimer(self)
         self.timer.timeout.connect(self.update_gpu_info)
-        self.timer.start(200)  # Update every 0.2 seconds
+        self.timer.start(500)  # Update every 0.5 seconds
 
         self.gpu_data = []
         self.vram_data = []

@@ -192,7 +192,7 @@ def update_graph_data() -> None:
 
         timer = QTimer(dialog)
         timer.timeout.connect(update_graph_data)
-        timer.start(200)  # Update every 0.2 seconds
+        timer.start(500)  # Update every 0.5 seconds
 
         dialog.exec()
 

@@ -227,7 +227,7 @@ def update_graph_data() -> None:
 
         timer = QTimer(dialog)
         timer.timeout.connect(update_graph_data)
-        timer.start(200)  # Update every 0.2 seconds
+        timer.start(500)  # Update every 0.5 seconds
 
         tab_widget.addTab(gpu_graph, GPU_USAGE_OVER_TIME)
         tab_widget.addTab(vram_graph, VRAM_USAGE_OVER_TIME)

(File diff suppressed because it is too large.)

@@ -18,14 +18,16 @@
SupportsIndex,
cast,
)
from transformers import AutoConfig

import torch

if TYPE_CHECKING:
from torch import Tensor

from gguf.constants import *
import gguf

# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, Model

logger = logging.getLogger("lora-to-gguf")

@@ -37,9 +39,10 @@ class PartialLoraTensor:
B: Tensor | None = None


# magic to support tensor shape modifications and splitting
class LoraTorchTensor:
_lora_A: Tensor
_lora_B: Tensor
_lora_A: Tensor  # (n_rank, row_size)
_lora_B: Tensor  # (col_size, n_rank)
_rank: int

def __init__(self, A: Tensor, B: Tensor):

@@ -57,14 +60,20 @@ def get_lora_A_B(self) -> tuple[Tensor, Tensor]:

def __getitem__(
self,
indices: SupportsIndex | slice | tuple[SupportsIndex | slice | Tensor, ...],
indices: (
SupportsIndex
| slice
| tuple[
SupportsIndex | slice | Tensor, ...
] # TODO: add ellipsis in the type signature
),
) -> LoraTorchTensor:
shape = self.shape
if isinstance(indices, SupportsIndex):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
else:
raise NotImplementedError
raise NotImplementedError # can't return a vector
elif isinstance(indices, slice):
if len(shape) > 2:
return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])

@@ -74,7 +83,7 @@ def __getitem__(
assert len(indices) > 0
if indices[-1] is Ellipsis:
return self[indices[:-1]]

# expand ellipsis
indices = tuple(
u
for v in (

@@ -94,6 +103,7 @@ def __getitem__(
*(slice(None, None) for _ in range(len(indices), len(shape))),
)

# TODO: make sure this is correct
indices_A = (
*(
(

@@ -109,7 +119,7 @@ def __getitem__(
indices_B = indices[:-1]
return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
else:
raise NotImplementedError
raise NotImplementedError # unknown indice type

@property
def dtype(self) -> torch.dtype:

@@ -132,8 +142,9 @@ def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
new_shape = cast(tuple[int, ...], shape)
orig_shape = self.shape
if len(new_shape) < 2:
raise NotImplementedError
raise NotImplementedError # can't become a vector

# expand -1 in the shape
if any(dim == -1 for dim in new_shape):
n_elems = prod(orig_shape)
n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)

@@ -143,7 +154,7 @@ def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
)

if new_shape[-1] != orig_shape[-1]:
raise NotImplementedError
raise NotImplementedError # can't reshape the row size trivially

shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
shape_B = (*new_shape[:-1], self._rank)

@@ -162,7 +173,7 @@ def permute(self, *dims: int) -> LoraTorchTensor:
shape = self.shape
dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
if dims[-1] == -1:

# TODO: support higher dimensional A shapes bigger than 1
assert all(dim == 1 for dim in self._lora_A.shape[:-2])
return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:

@@ -170,7 +181,7 @@ def permute(self, *dims: int) -> LoraTorchTensor:
self._lora_B.permute(*dims), self._lora_A.permute(*dims)
)
else:

# TODO: compose the above two
raise NotImplementedError

def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:

@@ -189,7 +200,7 @@ def to(self, *args, **kwargs):

@classmethod
def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
del types
del types # unused

if kwargs is None:
kwargs = {}

@@ -230,28 +241,73 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
base_name = lora_tensor_name.replace("base_model.model.", "")
base_name = base_name.replace(".lora_A.weight", ".weight")
base_name = base_name.replace(".lora_B.weight", ".weight")
# models produced by mergekit-extract-lora have token embeddings in the adapter
base_name = base_name.replace(".lora_embedding_A", ".weight")
base_name = base_name.replace(".lora_embedding_B", ".weight")
return base_name


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("--outfile", type=Path)
parser = argparse.ArgumentParser(
description="Convert a Hugging Face PEFT LoRA adapter to a GGUF file"
)
parser.add_argument(
"--outfile",
type=Path,
help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
)
parser.add_argument(
"--outtype",
type=str,
choices=["f32", "f16", "bf16", "q8_0", "auto"],
default="f16",
help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
)
parser.add_argument(
"--bigendian",
action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"--no-lazy",
action="store_true",
help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
)
parser.add_argument(
"--verbose",
action="store_true",
help="increase output verbosity",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="only print out what will be done, without writing any new files",
)
parser.add_argument(
"--base",
type=Path,
help="directory containing Hugging Face model config files (config.json, tokenizer.json) for the base model that the adapter is based on - only config is needed, actual model weights are not required. If base model is unspecified, it will be loaded from Hugging Face hub based on the adapter config",
)
parser.add_argument(
"--base-model-id",
type=str,
help="the model ID of the base model, if it is not available locally or in the adapter config. If specified, it will ignore --base and load the base model config from the Hugging Face hub (Example: 'meta-llama/Llama-3.2-1B-Instruct')",
)
parser.add_argument(
"lora_path",
type=Path,
help="directory containing Hugging Face PEFT LoRA config (adapter_model.json) and weights (adapter_model.safetensors or adapter_model.bin)",
)
parser.add_argument("--bigendian", action="store_true")
parser.add_argument("--no-lazy", action="store_true")
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--dry-run", action="store_true")
parser.add_argument("--base", type=Path, required=True)
parser.add_argument("lora_path", type=Path)

return parser.parse_args()


def load_hparams_from_hf(hf_model_id: str) -> dict[str, Any]:
# normally, adapter does not come with base model config, we need to load it from AutoConfig
config = AutoConfig.from_pretrained(hf_model_id)
return config.to_dict()


if __name__ == "__main__":
args = parse_args()
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)

@ -266,19 +322,20 @@ def parse_args() -> argparse.Namespace:
|
|||
|
||||
ftype = ftype_map[args.outtype]
|
||||
|
||||
dir_base_model: Path = args.base
|
||||
dir_base_model: Path | None = args.base
|
||||
dir_lora: Path = args.lora_path
|
||||
base_model_id: str | None = args.base_model_id
|
||||
lora_config = dir_lora / "adapter_config.json"
|
||||
input_model = dir_lora / "adapter_model.safetensors"
|
||||
|
||||
if args.outfile is not None:
|
||||
fname_out = args.outfile
|
||||
else:
|
||||
|
||||
# output in the same directory as the model by default
|
||||
fname_out = dir_lora
|
||||
|
||||
if os.path.exists(input_model):
|
||||
|
||||
# lazy import load_file only if lora is in safetensors format.
|
||||
from safetensors.torch import load_file
|
||||
|
||||
lora_model = load_file(input_model, device="cpu")
|
||||
|
@ -286,8 +343,38 @@ def parse_args() -> argparse.Namespace:
|
|||
input_model = os.path.join(dir_lora, "adapter_model.bin")
|
||||
lora_model = torch.load(input_model, map_location="cpu", weights_only=True)
|
||||
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
# load LoRA config
|
||||
with open(lora_config, "r") as f:
|
||||
lparams: dict[str, Any] = json.load(f)
|
||||
|
||||
# load base model
|
||||
if base_model_id is not None:
|
||||
logger.info(f"Loading base model from Hugging Face: {base_model_id}")
|
||||
hparams = load_hparams_from_hf(base_model_id)
|
||||
elif dir_base_model is None:
|
||||
if "base_model_name_or_path" in lparams:
|
||||
model_id = lparams["base_model_name_or_path"]
|
||||
logger.info(f"Loading base model from Hugging Face: {model_id}")
|
||||
try:
|
||||
hparams = load_hparams_from_hf(model_id)
|
||||
except OSError as e:
|
||||
logger.error(f"Failed to load base model config: {e}")
|
||||
logger.error(
|
||||
"Please try downloading the base model and add its path to --base"
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.error(
|
||||
"'base_model_name_or_path' is not found in adapter_config.json"
|
||||
)
|
||||
logger.error(
|
||||
"Base model config is required. Please download the base model and add its path to --base"
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
logger.info(f"Loading base model: {dir_base_model.name}")
|
||||
hparams = Model.load_hparams(dir_base_model)
|
||||
|
||||
with torch.inference_mode():
|
||||
try:
|
||||
model_class = Model.from_model_architecture(hparams["architectures"][0])
|
||||
|
@ -309,6 +396,9 @@ def __init__(
|
|||
self.dir_model_card = dir_lora_model
|
||||
self.lora_alpha = float(lora_alpha)
|
||||
|
||||
def set_vocab(self):
|
||||
pass
|
||||
|
||||
def set_type(self):
|
||||
self.gguf_writer.add_type(gguf.GGUFType.ADAPTER)
|
||||
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
|
||||
|
@ -317,7 +407,10 @@ def set_gguf_parameters(self):
|
|||
self.gguf_writer.add_float32(
|
||||
gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha
|
||||
)
|
||||
super().set_gguf_parameters()
|
||||
|
||||
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
|
||||
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters
|
||||
return ()
|
||||
|
||||
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
||||
tensor_map: dict[str, PartialLoraTensor] = {}
|
||||
|
@ -326,14 +419,26 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
|||
if self.lazy:
|
||||
tensor = LazyTorchTensor.from_eager(tensor)
|
||||
base_name = get_base_tensor_name(name)
|
||||
is_lora_a = ".lora_A.weight" in name
|
||||
is_lora_b = ".lora_B.weight" in name
|
||||
# note: mergekit-extract-lora also adds token embeddings to the adapter
|
||||
is_lora_a = ".lora_A.weight" in name or ".lora_embedding_A" in name
|
||||
is_lora_b = ".lora_B.weight" in name or ".lora_embedding_B" in name
|
||||
if not is_lora_a and not is_lora_b:
|
||||
if ".base_layer.weight" in name:
|
||||
continue
|
||||
# mergekit-extract-lora add these layernorm to the adapter, we need to keep them
|
||||
if "_layernorm" in name or ".norm" in name:
|
||||
yield (base_name, tensor)
|
||||
continue
|
||||
logger.error(
|
||||
f"Unexpected name '{name}': Not a lora_A or lora_B tensor"
|
||||
)
|
||||
if ".embed_tokens.weight" in name or ".lm_head.weight" in name:
|
||||
logger.error(
|
||||
"Embeddings is present in the adapter. This can be due to new tokens added during fine tuning"
|
||||
)
|
||||
logger.error(
|
||||
"Please refer to https://github.com/ggerganov/llama.cpp/pull/9948"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
if base_name in tensor_map:
|
||||
|
@ -358,17 +463,34 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
|
|||
def modify_tensors(
|
||||
self, data_torch: Tensor, name: str, bid: int | None
|
||||
) -> Iterable[tuple[str, Tensor]]:
|
||||
dest = super().modify_tensors(data_torch, name, bid)
|
||||
dest = list(super().modify_tensors(data_torch, name, bid))
|
||||
# some archs may have the same tensor for lm_head and output (tie word embeddings)
|
||||
# in this case, adapters targeting lm_head will fail when using llama-export-lora
|
||||
# therefore, we ignore them for now
|
||||
# see: https://github.com/ggerganov/llama.cpp/issues/9065
|
||||
if name == "lm_head.weight" and len(dest) == 0:
|
||||
raise ValueError(
|
||||
"lm_head is present in adapter, but is ignored in base model"
|
||||
)
|
||||
for dest_name, dest_data in dest:
|
||||
# mergekit-extract-lora add these layernorm to the adapter
|
||||
if "_norm" in dest_name:
|
||||
assert dest_data.dim() == 1
|
||||
yield (dest_name, dest_data)
|
||||
continue
|
||||
|
||||
# otherwise, we must get the lora_A and lora_B tensors
|
||||
assert isinstance(dest_data, LoraTorchTensor)
|
||||
lora_a, lora_b = dest_data.get_lora_A_B()
|
||||
|
||||
# note: mergekit-extract-lora flip and transpose A and B
|
||||
# here we only need to transpose token_embd.lora_a, see llm_build_inp_embd()
|
||||
if "token_embd.weight" in dest_name:
|
||||
lora_a = lora_a.T
|
||||
|
||||
yield (dest_name + ".lora_a", lora_a)
|
||||
yield (dest_name + ".lora_b", lora_b)
|
||||
|
||||
with open(lora_config, "r") as f:
|
||||
lparams: dict[str, Any] = json.load(f)
|
||||
|
||||
alpha: float = lparams["lora_alpha"]
|
||||
|
||||
model_instance = LoraModel(
|
||||
|
@ -381,7 +503,7 @@ def modify_tensors(
|
|||
dry_run=args.dry_run,
|
||||
dir_lora_model=dir_lora,
|
||||
lora_alpha=alpha,
|
||||
is_lora=True,
|
||||
hparams=hparams,
|
||||
)
|
||||
|
||||
logger.info("Exporting model...")
|
||||
|
|
|
@ -3,10 +3,18 @@
|
|||
from enum import Enum, IntEnum, auto
|
||||
from typing import Any
|
||||
|
||||
GGUF_MAGIC = 0x46554747
|
||||
#
|
||||
# constants
|
||||
#
|
||||
|
||||
GGUF_MAGIC = 0x46554747 # "GGUF"
|
||||
GGUF_VERSION = 3
|
||||
GGUF_DEFAULT_ALIGNMENT = 32
|
||||
GGML_QUANT_VERSION = 2
|
||||
GGML_QUANT_VERSION = 2 # GGML_QNT_VERSION from ggml.h
|
||||
|
||||
#
|
||||
# metadata keys
|
||||
#
|
||||
|
||||
|
||||
class Keys:
|
||||
|
@ -17,6 +25,7 @@ class General:
|
|||
ALIGNMENT = "general.alignment"
|
||||
FILE_TYPE = "general.file_type"
|
||||
|
||||
# Authorship Metadata
|
||||
NAME = "general.name"
|
||||
AUTHOR = "general.author"
|
||||
VERSION = "general.version"
|
||||
|
@ -30,38 +39,62 @@ class General:
|
|||
|
||||
SIZE_LABEL = "general.size_label"
|
||||
|
||||
# Licensing details
|
||||
LICENSE = "general.license"
|
||||
LICENSE_NAME = "general.license.name"
|
||||
LICENSE_LINK = "general.license.link"
|
||||
|
||||
URL = "general.url"
|
||||
# Typically represents the converted GGUF repo (Unless native)
|
||||
URL = "general.url" # Model Website/Paper
|
||||
DOI = "general.doi"
|
||||
UUID = "general.uuid"
|
||||
REPO_URL = "general.repo_url"
|
||||
REPO_URL = "general.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
SOURCE_URL = "general.source.url"
|
||||
# Model Source during conversion
|
||||
SOURCE_URL = "general.source.url" # Model Website/Paper
|
||||
SOURCE_DOI = "general.source.doi"
|
||||
SOURCE_UUID = "general.source.uuid"
|
||||
SOURCE_REPO_URL = "general.source.repo_url"
|
||||
SOURCE_REPO_URL = (
|
||||
"general.source.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
)
|
||||
|
||||
# Base Model Source. There can be more than one source if it's a merged
|
||||
# model like with 'Mistral-7B-Merge-14-v0.1'. This will assist in
|
||||
# tracing linage of models as it is finetuned or merged over time.
|
||||
BASE_MODEL_COUNT = "general.base_model.count"
|
||||
BASE_MODEL_NAME = "general.base_model.{id}.name"
|
||||
BASE_MODEL_AUTHOR = "general.base_model.{id}.author"
|
||||
BASE_MODEL_VERSION = "general.base_model.{id}.version"
|
||||
BASE_MODEL_ORGANIZATION = "general.base_model.{id}.organization"
|
||||
BASE_MODEL_URL = "general.base_model.{id}.url"
|
||||
BASE_MODEL_DESCRIPTION = "general.base_model.{id}.description"
|
||||
BASE_MODEL_URL = "general.base_model.{id}.url" # Model Website/Paper
|
||||
BASE_MODEL_DOI = "general.base_model.{id}.doi"
|
||||
BASE_MODEL_UUID = "general.base_model.{id}.uuid"
|
||||
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url"
|
||||
BASE_MODEL_REPO_URL = "general.base_model.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
|
||||
# Dataset Source
|
||||
DATASET_COUNT = "general.dataset.count"
|
||||
DATASET_NAME = "general.dataset.{id}.name"
|
||||
DATASET_AUTHOR = "general.dataset.{id}.author"
|
||||
DATASET_VERSION = "general.dataset.{id}.version"
|
||||
DATASET_ORGANIZATION = "general.dataset.{id}.organization"
|
||||
DATASET_DESCRIPTION = "general.dataset.{id}.description"
|
||||
DATASET_URL = "general.dataset.{id}.url" # Model Website/Paper
|
||||
DATASET_DOI = "general.dataset.{id}.doi"
|
||||
DATASET_UUID = "general.dataset.{id}.uuid"
|
||||
DATASET_REPO_URL = (
|
||||
"general.dataset.{id}.repo_url" # Model Source Repository (git/svn/etc...)
|
||||
)
|
||||
|
||||
# Array based KV stores
|
||||
TAGS = "general.tags"
|
||||
LANGUAGES = "general.languages"
|
||||
DATASETS = "general.datasets"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
FEATURES_LENGTH = "{arch}.features_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
|
||||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
|
@ -73,11 +106,14 @@ class LLM:
|
|||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
|
||||
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
|
||||
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
|
||||
SWIN_NORM = "{arch}.swin_norm"
|
||||
RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
|
||||
TIME_MIX_EXTRA_DIM = "{arch}.time_mix_extra_dim"
|
||||
TIME_DECAY_EXTRA_DIM = "{arch}.time_decay_extra_dim"
|
||||
|
@ -93,6 +129,8 @@ class Attention:
|
|||
VALUE_LENGTH = "{arch}.attention.value_length"
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
GROUPNORM_EPS = "{arch}.attention.group_norm_epsilon"
|
||||
GROUPNORM_GROUPS = "{arch}.attention.group_norm_groups"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
|
@ -102,6 +140,7 @@ class Attention:
|
|||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
|
||||
FREQ_BASE = "{arch}.rope.freq_base"
|
||||
SCALING_TYPE = "{arch}.rope.scaling.type"
|
||||
SCALING_FACTOR = "{arch}.rope.scaling.factor"
|
||||
|
@ -125,16 +164,28 @@ class SSM:
|
|||
class WKV:
|
||||
HEAD_SIZE = "{arch}.wkv.head_size"
|
||||
|
||||
class PosNet:
|
||||
EMBEDDING_LENGTH = "{arch}.posnet.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.posnet.block_count"
|
||||
|
||||
class ConvNext:
|
||||
EMBEDDING_LENGTH = "{arch}.convnext.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.convnext.block_count"
|
||||
|
||||
class Tokenizer:
|
||||
MODEL = "tokenizer.ggml.model"
|
||||
PRE = "tokenizer.ggml.pre"
|
||||
LIST = "tokenizer.ggml.tokens"
|
||||
TOKEN_TYPE = "tokenizer.ggml.token_type"
|
||||
TOKEN_TYPE_COUNT = "tokenizer.ggml.token_type_count"
|
||||
TOKEN_TYPE_COUNT = (
|
||||
"tokenizer.ggml.token_type_count" # for BERT-style token types
|
||||
)
|
||||
SCORES = "tokenizer.ggml.scores"
|
||||
MERGES = "tokenizer.ggml.merges"
|
||||
BOS_ID = "tokenizer.ggml.bos_token_id"
|
||||
EOS_ID = "tokenizer.ggml.eos_token_id"
|
||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||
EOM_ID = "tokenizer.ggml.eom_token_id"
|
||||
UNK_ID = "tokenizer.ggml.unknown_token_id"
|
||||
SEP_ID = "tokenizer.ggml.seperator_token_id"
|
||||
PAD_ID = "tokenizer.ggml.padding_token_id"
|
||||
|
@ -150,18 +201,28 @@ class Tokenizer:
|
|||
CHAT_TEMPLATE = "tokenizer.chat_template"
|
||||
CHAT_TEMPLATE_N = "tokenizer.chat_template.{name}"
|
||||
CHAT_TEMPLATES = "tokenizer.chat_templates"
|
||||
|
||||
# FIM/Infill special tokens constants
|
||||
FIM_PRE_ID = "tokenizer.ggml.fim_pre_token_id"
|
||||
FIM_SUF_ID = "tokenizer.ggml.fim_suf_token_id"
|
||||
FIM_MID_ID = "tokenizer.ggml.fim_mid_token_id"
|
||||
FIM_PAD_ID = "tokenizer.ggml.fim_pad_token_id"
|
||||
FIM_REP_ID = "tokenizer.ggml.fim_rep_token_id"
|
||||
FIM_SEP_ID = "tokenizer.ggml.fim_sep_token_id"
|
||||
# deprecated:
|
||||
PREFIX_ID = "tokenizer.ggml.prefix_token_id"
|
||||
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
|
||||
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
|
||||
EOT_ID = "tokenizer.ggml.eot_token_id"
|
||||
EOM_ID = "tokenizer.ggml.eom_token_id"
|
||||
|
||||
class Adapter:
|
||||
TYPE = "adapter.type"
|
||||
LORA_ALPHA = "adapter.lora.alpha"
|
||||
|
||||
|
||||
#
|
||||
# recommended mapping of model tensor names for storage in gguf
|
||||
#
|
||||
|
||||
|
||||
class GGUFType:
|
||||
MODEL = "model"
|
||||
ADAPTER = "adapter"
|
||||
|
@ -169,6 +230,7 @@ class GGUFType:
|
|||
|
||||
class MODEL_ARCH(IntEnum):
|
||||
LLAMA = auto()
|
||||
DECI = auto()
|
||||
FALCON = auto()
|
||||
BAICHUAN = auto()
|
||||
GROK = auto()
|
||||
|
@ -186,6 +248,7 @@ class MODEL_ARCH(IntEnum):
|
|||
QWEN = auto()
|
||||
QWEN2 = auto()
|
||||
QWEN2MOE = auto()
|
||||
QWEN2VL = auto()
|
||||
PHI2 = auto()
|
||||
PHI3 = auto()
|
||||
PLAMO = auto()
|
||||
|
@ -199,14 +262,16 @@ class MODEL_ARCH(IntEnum):
|
|||
STARCODER2 = auto()
|
||||
RWKV6 = auto()
|
||||
MAMBA = auto()
|
||||
JAMBA = auto()
|
||||
XVERSE = auto()
|
||||
COMMAND_R = auto()
|
||||
COHERE2 = auto()
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
OLMO2 = auto()
|
||||
OLMOE = auto()
|
||||
OPENELM = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
CHATGLM = auto()
|
||||
BITNET = auto()
|
||||
|
@ -216,6 +281,9 @@ class MODEL_ARCH(IntEnum):
|
|||
NEMOTRON = auto()
|
||||
EXAONE = auto()
|
||||
GRANITE = auto()
|
||||
GRANITE_MOE = auto()
|
||||
CHAMELEON = auto()
|
||||
WAVTOKENIZER_DEC = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
|
@ -254,6 +322,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
FFN_EXP_PROBS_B = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
LAYER_OUT_NORM = auto()
|
||||
|
@ -261,10 +330,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
SSM_CONV1D = auto()
|
||||
SSM_X = auto()
|
||||
SSM_DT = auto()
|
||||
SSM_DT_NORM = auto()
|
||||
SSM_A = auto()
|
||||
SSM_B_NORM = auto()
|
||||
SSM_C_NORM = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
TIME_MIX_W1 = auto()
|
||||
|
@ -326,10 +392,29 @@ class MODEL_TENSOR(IntEnum):
|
|||
ENC_FFN_DOWN = auto()
|
||||
ENC_FFN_UP = auto()
|
||||
ENC_OUTPUT_NORM = auto()
|
||||
CLS = auto() # classifier
|
||||
CLS_OUT = auto() # classifier output projection
|
||||
CONV1D = auto()
|
||||
CONVNEXT_DW = auto()
|
||||
CONVNEXT_NORM = auto()
|
||||
CONVNEXT_PW1 = auto()
|
||||
CONVNEXT_PW2 = auto()
|
||||
CONVNEXT_GAMMA = auto()
|
||||
POSNET_CONV1 = auto()
|
||||
POSNET_CONV2 = auto()
|
||||
POSNET_NORM = auto()
|
||||
POSNET_NORM1 = auto()
|
||||
POSNET_NORM2 = auto()
|
||||
POSNET_ATTN_NORM = auto()
|
||||
POSNET_ATTN_Q = auto()
|
||||
POSNET_ATTN_K = auto()
|
||||
POSNET_ATTN_V = auto()
|
||||
POSNET_ATTN_OUT = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.LLAMA: "llama",
|
||||
MODEL_ARCH.DECI: "deci",
|
||||
MODEL_ARCH.FALCON: "falcon",
|
||||
MODEL_ARCH.BAICHUAN: "baichuan",
|
||||
MODEL_ARCH.GROK: "grok",
|
||||
|
@ -347,6 +432,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_ARCH.QWEN: "qwen",
|
||||
MODEL_ARCH.QWEN2: "qwen2",
|
||||
MODEL_ARCH.QWEN2MOE: "qwen2moe",
|
||||
MODEL_ARCH.QWEN2VL: "qwen2vl",
|
||||
MODEL_ARCH.PHI2: "phi2",
|
||||
MODEL_ARCH.PHI3: "phi3",
|
||||
MODEL_ARCH.PLAMO: "plamo",
|
||||
|
@ -360,14 +446,16 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
MODEL_ARCH.RWKV6: "rwkv6",
|
||||
MODEL_ARCH.MAMBA: "mamba",
|
||||
MODEL_ARCH.JAMBA: "jamba",
|
||||
MODEL_ARCH.XVERSE: "xverse",
|
||||
MODEL_ARCH.COMMAND_R: "command-r",
|
||||
MODEL_ARCH.COHERE2: "cohere2",
|
||||
MODEL_ARCH.DBRX: "dbrx",
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.OLMO2: "olmo2",
|
||||
MODEL_ARCH.OLMOE: "olmoe",
|
||||
MODEL_ARCH.OPENELM: "openelm",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK: "deepseek",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
MODEL_ARCH.CHATGLM: "chatglm",
|
||||
MODEL_ARCH.BITNET: "bitnet",
|
||||
|
@ -377,6 +465,9 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||
MODEL_ARCH.EXAONE: "exaone",
|
||||
MODEL_ARCH.GRANITE: "granite",
|
||||
MODEL_ARCH.GRANITE_MOE: "granitemoe",
|
||||
MODEL_ARCH.CHAMELEON: "chameleon",
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
|
@ -417,15 +508,13 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps",
|
||||
MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps",
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b",
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm",
|
||||
MODEL_TENSOR.SSM_IN: "blk.{bid}.ssm_in",
|
||||
MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d",
|
||||
MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x",
|
||||
MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt",
|
||||
MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm",
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm",
|
||||
MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
|
||||
|
@ -487,6 +576,24 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down",
|
||||
MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up",
|
||||
MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm",
|
||||
MODEL_TENSOR.CLS: "cls",
|
||||
MODEL_TENSOR.CLS_OUT: "cls.output",
|
||||
MODEL_TENSOR.CONV1D: "conv1d",
|
||||
MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw",
|
||||
MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm",
|
||||
MODEL_TENSOR.CONVNEXT_PW1: "convnext.{bid}.pw1",
|
||||
MODEL_TENSOR.CONVNEXT_PW2: "convnext.{bid}.pw2",
|
||||
MODEL_TENSOR.CONVNEXT_GAMMA: "convnext.{bid}.gamma",
|
||||
MODEL_TENSOR.POSNET_CONV1: "posnet.{bid}.conv1",
|
||||
MODEL_TENSOR.POSNET_CONV2: "posnet.{bid}.conv2",
|
||||
MODEL_TENSOR.POSNET_NORM: "posnet.{bid}.norm",
|
||||
MODEL_TENSOR.POSNET_NORM1: "posnet.{bid}.norm1",
|
||||
MODEL_TENSOR.POSNET_NORM2: "posnet.{bid}.norm2",
|
||||
MODEL_TENSOR.POSNET_ATTN_NORM: "posnet.{bid}.attn_norm",
|
||||
MODEL_TENSOR.POSNET_ATTN_Q: "posnet.{bid}.attn_q",
|
||||
MODEL_TENSOR.POSNET_ATTN_K: "posnet.{bid}.attn_k",
|
||||
MODEL_TENSOR.POSNET_ATTN_V: "posnet.{bid}.attn_v",
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
|
@ -510,6 +617,26 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.DECI: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.GROK: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -596,6 +723,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
MODEL_TENSOR.CLS_OUT,
|
||||
],
|
||||
MODEL_ARCH.NOMIC_BERT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
@ -627,6 +756,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.LAYER_OUT_NORM,
|
||||
MODEL_TENSOR.CLS,
|
||||
],
|
||||
MODEL_ARCH.MPT: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
@ -713,6 +843,21 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.QWEN2VL: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
|
@ -790,6 +935,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_QKV,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
|
@ -849,6 +996,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
|
@ -868,6 +1017,8 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG,
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
|
@ -968,34 +1119,6 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
],
|
||||
MODEL_ARCH.JAMBA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_X,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_DT_NORM,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_B_NORM,
|
||||
MODEL_TENSOR.SSM_C_NORM,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.XVERSE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -1026,6 +1149,18 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
],
|
||||
MODEL_ARCH.COHERE2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.DBRX: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -1050,6 +1185,22 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.OLMO2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_POST_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.FFN_POST_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.OLMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -1101,6 +1252,29 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
@ -1127,6 +1301,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B,
|
||||
],
|
||||
MODEL_ARCH.CHATGLM: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
|
@ -1248,6 +1423,7 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_ARCH.GRANITE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
|
@ -1258,13 +1434,72 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GRANITE_MOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.CHAMELEON: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.WAVTOKENIZER_DEC: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM,
|
||||
MODEL_TENSOR.CONV1D,
|
||||
MODEL_TENSOR.CONVNEXT_DW,
|
||||
MODEL_TENSOR.CONVNEXT_NORM,
|
||||
MODEL_TENSOR.CONVNEXT_PW1,
|
||||
MODEL_TENSOR.CONVNEXT_PW2,
|
||||
MODEL_TENSOR.CONVNEXT_GAMMA,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.POSNET_CONV1,
|
||||
MODEL_TENSOR.POSNET_CONV2,
|
||||
MODEL_TENSOR.POSNET_NORM,
|
||||
MODEL_TENSOR.POSNET_NORM1,
|
||||
MODEL_TENSOR.POSNET_NORM2,
|
||||
MODEL_TENSOR.POSNET_ATTN_NORM,
|
||||
MODEL_TENSOR.POSNET_ATTN_Q,
|
||||
MODEL_TENSOR.POSNET_ATTN_K,
|
||||
MODEL_TENSOR.POSNET_ATTN_V,
|
||||
MODEL_TENSOR.POSNET_ATTN_OUT,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
# tensors that will not be serialized
|
||||
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_ARCH.LLAMA: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DECI: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.BAICHUAN: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
|
@ -1289,6 +1524,10 @@ class MODEL_TENSOR(IntEnum):
|
|||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
|
@ -1302,6 +1541,10 @@ class MODEL_TENSOR(IntEnum):
|
|||
],
|
||||
}
|
||||
|
||||
#
|
||||
# types
|
||||
#
|
||||
|
||||
|
||||
class TokenType(IntEnum):
|
||||
NORMAL = 1
|
||||
|
@ -1316,6 +1559,7 @@ class RopeScalingType(Enum):
|
|||
NONE = "none"
|
||||
LINEAR = "linear"
|
||||
YARN = "yarn"
|
||||
LONGROPE = "longrope"
|
||||
|
||||
|
||||
class PoolingType(IntEnum):
|
||||
|
@ -1354,52 +1598,61 @@ class GGMLQuantizationType(IntEnum):
|
|||
F64 = 28
|
||||
IQ1_M = 29
|
||||
BF16 = 30
|
||||
Q4_0_4_4 = 31
|
||||
Q4_0_4_8 = 32
|
||||
Q4_0_8_8 = 33
|
||||
TQ1_0 = 34
|
||||
TQ2_0 = 35
|
||||
|
||||
|
||||
class ExpertGatingFuncType(IntEnum):
|
||||
SOFTMAX = 1
|
||||
SIGMOID = 2
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
|
||||
|
||||
# from llama_ftype in llama.h
|
||||
# ALL VALUES SHOULD BE THE SAME HERE AS THEY ARE OVER THERE.
|
||||
class LlamaFileType(IntEnum):
|
||||
ALL_F32 = 0
|
||||
MOSTLY_F16 = 1
|
||||
MOSTLY_Q4_0 = 2
|
||||
MOSTLY_Q4_1 = 3
|
||||
MOSTLY_F16 = 1 # except 1d tensors
|
||||
MOSTLY_Q4_0 = 2 # except 1d tensors
|
||||
MOSTLY_Q4_1 = 3 # except 1d tensors
|
||||
# MOSTLY_Q4_1_SOME_F16 = 4 # tok_embeddings.weight and output.weight are F16
|
||||
# MOSTLY_Q4_2 = 5 # support has been removed
|
||||
# MOSTLY_Q4_3 = 6 # support has been removed
|
||||
MOSTLY_Q8_0 = 7 # except 1d tensors
|
||||
MOSTLY_Q5_0 = 8 # except 1d tensors
|
||||
MOSTLY_Q5_1 = 9 # except 1d tensors
|
||||
MOSTLY_Q2_K = 10 # except 1d tensors
|
||||
MOSTLY_Q3_K_S = 11 # except 1d tensors
|
||||
MOSTLY_Q3_K_M = 12 # except 1d tensors
|
||||
MOSTLY_Q3_K_L = 13 # except 1d tensors
|
||||
MOSTLY_Q4_K_S = 14 # except 1d tensors
|
||||
MOSTLY_Q4_K_M = 15 # except 1d tensors
|
||||
MOSTLY_Q5_K_S = 16 # except 1d tensors
|
||||
MOSTLY_Q5_K_M = 17 # except 1d tensors
|
||||
MOSTLY_Q6_K = 18 # except 1d tensors
|
||||
MOSTLY_IQ2_XXS = 19 # except 1d tensors
|
||||
MOSTLY_IQ2_XS = 20 # except 1d tensors
|
||||
MOSTLY_Q2_K_S = 21 # except 1d tensors
|
||||
MOSTLY_IQ3_XS = 22 # except 1d tensors
|
||||
MOSTLY_IQ3_XXS = 23 # except 1d tensors
|
||||
MOSTLY_IQ1_S = 24 # except 1d tensors
|
||||
MOSTLY_IQ4_NL = 25 # except 1d tensors
|
||||
MOSTLY_IQ3_S = 26 # except 1d tensors
|
||||
MOSTLY_IQ3_M = 27 # except 1d tensors
|
||||
MOSTLY_IQ2_S = 28 # except 1d tensors
|
||||
MOSTLY_IQ2_M = 29 # except 1d tensors
|
||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||
MOSTLY_BF16 = 32 # except 1d tensors
|
||||
# MOSTLY_Q4_0_4_4 = 33 # removed from gguf files, use Q4_0 and runtime repack
|
||||
# MOSTLY_Q4_0_4_8 = 34 # removed from gguf files, use Q4_0 and runtime repack
|
||||
# MOSTLY_Q4_0_8_8 = 35 # removed from gguf files, use Q4_0 and runtime repack
|
||||
MOSTLY_TQ1_0 = 36 # except 1d tensors
|
||||
MOSTLY_TQ2_0 = 37 # except 1d tensors
|
||||
|
||||
MOSTLY_Q8_0 = 7
|
||||
MOSTLY_Q5_0 = 8
|
||||
MOSTLY_Q5_1 = 9
|
||||
MOSTLY_Q2_K = 10
|
||||
MOSTLY_Q3_K_S = 11
|
||||
MOSTLY_Q3_K_M = 12
|
||||
MOSTLY_Q3_K_L = 13
|
||||
MOSTLY_Q4_K_S = 14
|
||||
MOSTLY_Q4_K_M = 15
|
||||
MOSTLY_Q5_K_S = 16
|
||||
MOSTLY_Q5_K_M = 17
|
||||
MOSTLY_Q6_K = 18
|
||||
MOSTLY_IQ2_XXS = 19
|
||||
MOSTLY_IQ2_XS = 20
|
||||
MOSTLY_Q2_K_S = 21
|
||||
MOSTLY_IQ3_XS = 22
|
||||
MOSTLY_IQ3_XXS = 23
|
||||
MOSTLY_IQ1_S = 24
|
||||
MOSTLY_IQ4_NL = 25
|
||||
MOSTLY_IQ3_S = 26
|
||||
MOSTLY_IQ3_M = 27
|
||||
MOSTLY_IQ2_S = 28
|
||||
MOSTLY_IQ2_M = 29
|
||||
MOSTLY_IQ4_XS = 30
|
||||
MOSTLY_IQ1_M = 31
|
||||
MOSTLY_BF16 = 32
|
||||
MOSTLY_Q4_0_4_4 = 33
|
||||
MOSTLY_Q4_0_4_8 = 34
|
||||
MOSTLY_Q4_0_8_8 = 35
|
||||
MOSTLY_TQ1_0 = 36
|
||||
MOSTLY_TQ2_0 = 37
|
||||
|
||||
GUESSED = 1024
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
|
||||
class GGUFEndian(IntEnum):
|
||||
|
@ -1434,11 +1687,12 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
return GGUFValueType.BOOL
|
||||
elif isinstance(val, int):
|
||||
return GGUFValueType.INT32
|
||||
|
||||
# TODO: need help with 64-bit types in Python
|
||||
else:
|
||||
raise ValueError(f"Unknown type: {type(val)}")
|
||||
|
||||
|
||||
# Items here are (block size, type size)
|
||||
QK_K = 256
|
||||
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
GGMLQuantizationType.F32: (1, 4),
|
||||
|
@ -1470,13 +1724,14 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q4_0_4_4: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_4_8: (32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_8_8: (32, 2 + 16),
|
||||
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
|
||||
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
|
||||
}
|
||||
|
||||
|
||||
# Aliases for backward compatibility.
|
||||
|
||||
# general
|
||||
KEY_GENERAL_ARCHITECTURE = Keys.General.ARCHITECTURE
|
||||
KEY_GENERAL_QUANTIZATION_VERSION = Keys.General.QUANTIZATION_VERSION
|
||||
KEY_GENERAL_ALIGNMENT = Keys.General.ALIGNMENT
|
||||
|
@ -1488,6 +1743,7 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_GENERAL_SOURCE_URL = Keys.General.SOURCE_URL
|
||||
KEY_GENERAL_FILE_TYPE = Keys.General.FILE_TYPE
|
||||
|
||||
# LLM
|
||||
KEY_VOCAB_SIZE = Keys.LLM.VOCAB_SIZE
|
||||
KEY_CONTEXT_LENGTH = Keys.LLM.CONTEXT_LENGTH
|
||||
KEY_EMBEDDING_LENGTH = Keys.LLM.EMBEDDING_LENGTH
|
||||
|
@ -1496,6 +1752,7 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_USE_PARALLEL_RESIDUAL = Keys.LLM.USE_PARALLEL_RESIDUAL
|
||||
KEY_TENSOR_DATA_LAYOUT = Keys.LLM.TENSOR_DATA_LAYOUT
|
||||
|
||||
# attention
|
||||
KEY_ATTENTION_HEAD_COUNT = Keys.Attention.HEAD_COUNT
|
||||
KEY_ATTENTION_HEAD_COUNT_KV = Keys.Attention.HEAD_COUNT_KV
|
||||
KEY_ATTENTION_MAX_ALIBI_BIAS = Keys.Attention.MAX_ALIBI_BIAS
|
||||
|
@ -1503,6 +1760,7 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
|
||||
KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
|
||||
|
||||
# RoPE
|
||||
KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
|
||||
KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
|
||||
KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
|
||||
|
@ -1510,12 +1768,14 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
|
||||
KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
|
||||
|
||||
# SSM
|
||||
KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL
|
||||
KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE
|
||||
KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE
|
||||
KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK
|
||||
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS
|
||||
|
||||
# tokenization
|
||||
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
|
||||
KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
|
||||
KEY_TOKENIZER_LIST = Keys.Tokenizer.LIST
|
||||
|
@ -1524,6 +1784,8 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
|
||||
KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
|
||||
KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
|
||||
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
|
||||
KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
|
||||
KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
|
||||
KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
|
||||
KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID
|
||||
|
@ -1531,8 +1793,15 @@ def get_type(val: Any) -> GGUFValueType:
|
|||
KEY_TOKENIZER_MASK_ID = Keys.Tokenizer.MASK_ID
|
||||
KEY_TOKENIZER_HF_JSON = Keys.Tokenizer.HF_JSON
|
||||
KEY_TOKENIZER_RWKV = Keys.Tokenizer.RWKV
|
||||
KEY_TOKENIZER_PRIFIX_ID = Keys.Tokenizer.PREFIX_ID
|
||||
|
||||
KEY_TOKENIZER_FIM_PRE_ID = Keys.Tokenizer.FIM_PRE_ID
|
||||
KEY_TOKENIZER_FIM_SUF_ID = Keys.Tokenizer.FIM_SUF_ID
|
||||
KEY_TOKENIZER_FIM_MID_ID = Keys.Tokenizer.FIM_MID_ID
|
||||
KEY_TOKENIZER_FIM_PAD_ID = Keys.Tokenizer.FIM_PAD_ID
|
||||
KEY_TOKENIZER_FIM_REP_ID = Keys.Tokenizer.FIM_REP_ID
|
||||
KEY_TOKENIZER_FIM_SEP_ID = Keys.Tokenizer.FIM_SEP_ID
|
||||
|
||||
# deprecated
|
||||
KEY_TOKENIZER_PREFIX_ID = Keys.Tokenizer.PREFIX_ID
|
||||
KEY_TOKENIZER_SUFFIX_ID = Keys.Tokenizer.SUFFIX_ID
|
||||
KEY_TOKENIZER_MIDDLE_ID = Keys.Tokenizer.MIDDLE_ID
|
||||
KEY_TOKENIZER_EOT_ID = Keys.Tokenizer.EOT_ID
|
||||
KEY_TOKENIZER_EOM_ID = Keys.Tokenizer.EOM_ID
|
||||
|
|
|
@@ -169,11 +169,10 @@ def _get(
         count = int(count)
         itemsize = int(np.empty([], dtype=dtype).itemsize)
         end_offs = offset + itemsize * count
-        return (
-            self.data[offset:end_offs]
-            .view(dtype=dtype)[:count]
-            .newbyteorder(override_order or self.byte_order)
-        )
+        arr = self.data[offset:end_offs].view(dtype=dtype)[:count]
+        if override_order is None:
+            return arr
+        return arr.view(arr.dtype.newbyteorder(override_order))
 
     def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int:
         if field.name in self.fields:

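The reader change above replaces the chained ndarray.newbyteorder() call, which no longer exists as an array method in NumPy 2.x, with a byte-order swap applied through the dtype. A small standalone illustration of the portable pattern (assuming only that NumPy is installed):

    import numpy as np

    data = np.arange(4, dtype=np.uint32)
    # NumPy 1.x only: data.newbyteorder(">")  -- the array method was removed in 2.0
    swapped = data.view(data.dtype.newbyteorder(">"))  # works on both 1.x and 2.x
    print(swapped)
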
@ -26,12 +26,14 @@
|
|||
RopeScalingType,
|
||||
PoolingType,
|
||||
TokenType,
|
||||
ExpertGatingFuncType,
|
||||
)
|
||||
|
||||
from .quants import quant_shape_from_byte_shape
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
SHARD_NAME_FORMAT = "{:s}-{:05d}-of-{:05d}.gguf"
|
||||
|
||||
|
||||
|
@ -135,7 +137,7 @@ def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
|||
continue
|
||||
elif name.endswith(".lora_b"):
|
||||
if last_lora_a is None or last_lora_a[0] != name[:-1] + "a":
|
||||
|
||||
# Bail when the LoRA pair can't be found trivially
|
||||
logger.warning(
|
||||
"can't measure LoRA size correctly, tensor order is unusual"
|
||||
)
|
||||
|
@ -154,11 +156,14 @@ def get_total_parameter_count(self) -> tuple[int, int, int, int]:
|
|||
|
||||
total_params += size
|
||||
|
||||
# Hopefully this should work even for variable-expert-count models
|
||||
expert_count = (expert_sum // n_expert_tensors) if n_expert_tensors > 0 else 0
|
||||
|
||||
# Negate the total to signal it's likely not exact
|
||||
if last_lora_a is not None:
|
||||
total_params = -total_params
|
||||
|
||||
# NOTE: keep the output in the same order as accepted by 'size_label' in gguf-py/gguf/utility.py
|
||||
return total_params, shared_params, expert_params, expert_count
|
||||
|
||||
def format_shard_names(self, path: Path) -> list[Path]:
|
||||
|
@ -177,7 +182,7 @@ def open_output_file(self, path: Path | None = None) -> None:
|
|||
and self.fout is not None
|
||||
and (path is None or path == self.path)
|
||||
):
|
||||
|
||||
# allow calling this multiple times as long as the path is the same
|
||||
return
|
||||
|
||||
if self.state is not WriterState.NO_FILE:
|
||||
|
@ -206,7 +211,7 @@ def print_plan(self) -> list[Path]:
|
|||
if self.dry_run:
|
||||
logger.info("Dry run, not writing files")
|
||||
for name in filenames:
|
||||
print(name)
|
||||
print(name) # noqa: NP100
|
||||
exit()
|
||||
|
||||
return filenames
|
||||
|
@ -390,11 +395,12 @@ def add_tensor_info(
|
|||
if tensor_dtype == np.uint8:
|
||||
tensor_shape = quant_shape_from_byte_shape(tensor_shape, raw_dtype)
|
||||
|
||||
# make sure there is at least one tensor before splitting
|
||||
if len(self.tensors[-1]) > 0:
|
||||
if (
|
||||
if ( # split when over tensor limit
|
||||
self.split_max_tensors != 0
|
||||
and len(self.tensors[-1]) >= self.split_max_tensors
|
||||
) or (
|
||||
) or ( # split when over size limit
|
||||
self.split_max_size != 0
|
||||
and sum(ti.nbytes for ti in self.tensors[-1].values()) + tensor_nbytes
|
||||
> self.split_max_size
|
||||
|
@ -460,6 +466,8 @@ def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
|
|||
|
||||
fout = self.fout[file_id]
|
||||
|
||||
# pop the first tensor info
|
||||
# TODO: cleaner way to get the first key
|
||||
first_tensor_name = [
|
||||
name for name, _ in zip(self.tensors[file_id].keys(), range(1))
|
||||
][0]
|
||||
|
@ -506,8 +514,11 @@ def write_tensors_to_file(self, *, progress: bool = False) -> None:
|
|||
total = sum(ti.nbytes for ti in tensors.values())
|
||||
shard_bar.reset(total=(total if total > 0 else None))
|
||||
|
||||
# relying on the fact that Python dicts preserve insertion order (since 3.7)
|
||||
for ti in tensors.values():
|
||||
assert ti.tensor is not None
|
||||
assert (
|
||||
ti.tensor is not None
|
||||
) # can only iterate once over the tensors
|
||||
assert ti.tensor.nbytes == ti.nbytes
|
||||
ti.tensor.tofile(fout)
|
||||
if shard_bar is not None:
|
||||
|
@ -631,6 +642,11 @@ def add_base_model_organization(self, source_id: int, organization: str) -> None
|
|||
Keys.General.BASE_MODEL_ORGANIZATION.format(id=source_id), organization
|
||||
)
|
||||
|
||||
def add_base_model_description(self, source_id: int, description: str) -> None:
|
||||
self.add_string(
|
||||
Keys.General.BASE_MODEL_DESCRIPTION.format(id=source_id), description
|
||||
)
|
||||
|
||||
def add_base_model_url(self, source_id: int, url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_URL.format(id=source_id), url)
|
||||
|
||||
|
@ -643,15 +659,46 @@ def add_base_model_uuid(self, source_id: int, uuid: str) -> None:
|
|||
def add_base_model_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.BASE_MODEL_REPO_URL.format(id=source_id), repo_url)
|
||||
|
||||
def add_dataset_count(self, source_count: int) -> None:
|
||||
self.add_uint32(Keys.General.DATASET_COUNT, source_count)
|
||||
|
||||
def add_dataset_name(self, source_id: int, name: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_NAME.format(id=source_id), name)
|
||||
|
||||
def add_dataset_author(self, source_id: int, author: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_AUTHOR.format(id=source_id), author)
|
||||
|
||||
def add_dataset_version(self, source_id: int, version: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_VERSION.format(id=source_id), version)
|
||||
|
||||
def add_dataset_organization(self, source_id: int, organization: str) -> None:
|
||||
self.add_string(
|
||||
Keys.General.DATASET_ORGANIZATION.format(id=source_id), organization
|
||||
)
|
||||
|
||||
def add_dataset_description(self, source_id: int, description: str) -> None:
|
||||
self.add_string(
|
||||
Keys.General.DATASET_DESCRIPTION.format(id=source_id), description
|
||||
)
|
||||
|
||||
def add_dataset_url(self, source_id: int, url: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_URL.format(id=source_id), url)
|
||||
|
||||
def add_dataset_doi(self, source_id: int, doi: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_DOI.format(id=source_id), doi)
|
||||
|
||||
def add_dataset_uuid(self, source_id: int, uuid: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_UUID.format(id=source_id), uuid)
|
||||
|
||||
def add_dataset_repo_url(self, source_id: int, repo_url: str) -> None:
|
||||
self.add_string(Keys.General.DATASET_REPO_URL.format(id=source_id), repo_url)
|
||||
|
||||
def add_tags(self, tags: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.TAGS, tags)
|
||||
|
||||
def add_languages(self, languages: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.LANGUAGES, languages)
|
||||
|
||||
def add_datasets(self, datasets: Sequence[str]) -> None:
|
||||
self.add_array(Keys.General.DATASETS, datasets)
|
||||
|
||||
def add_tensor_data_layout(self, layout: str) -> None:
|
||||
self.add_string(Keys.LLM.TENSOR_DATA_LAYOUT.format(arch=self.arch), layout)
|
||||
|
||||
|
@ -664,6 +711,21 @@ def add_context_length(self, length: int) -> None:
def add_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length)

def add_features_length(self, length: int) -> None:
self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length)

def add_posnet_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.PosNet.EMBEDDING_LENGTH.format(arch=self.arch), length)

def add_posnet_block_count(self, length: int) -> None:
self.add_uint32(Keys.PosNet.BLOCK_COUNT.format(arch=self.arch), length)

def add_convnext_embedding_length(self, length: int) -> None:
self.add_uint32(Keys.ConvNext.EMBEDDING_LENGTH.format(arch=self.arch), length)

def add_convnext_block_count(self, length: int) -> None:
self.add_uint32(Keys.ConvNext.BLOCK_COUNT.format(arch=self.arch), length)

def add_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
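These new setters carry the WavTokenizer-style decoder hyper-parameters (a features length plus PosNet/ConvNext embedding sizes and block counts). Continuing the sketch above, a converter might fill them in like this; the numbers are made up and would normally come from the source model's config:

# Illustrative values only
writer.add_features_length(512)
writer.add_posnet_embedding_length(768)
writer.add_posnet_block_count(6)
writer.add_convnext_embedding_length(768)
writer.add_convnext_block_count(12)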
@ -739,6 +801,15 @@ def add_expert_shared_count(self, count: int) -> None:
|
|||
def add_expert_weights_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_expert_weights_norm(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.EXPERT_WEIGHTS_NORM.format(arch=self.arch), value)
|
||||
|
||||
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
|
||||
|
||||
def add_swin_norm(self, value: bool) -> None:
|
||||
self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
|
||||
|
||||
def add_rescale_every_n_layers(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.RESCALE_EVERY_N_LAYERS.format(arch=self.arch), count)
|
||||
|
||||
|
@ -763,6 +834,12 @@ def add_layer_norm_eps(self, value: float) -> None:
|
|||
def add_layer_norm_rms_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.LAYERNORM_RMS_EPS.format(arch=self.arch), value)
|
||||
|
||||
def add_group_norm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.GROUPNORM_EPS.format(arch=self.arch), value)
|
||||
|
||||
def add_group_norm_groups(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.GROUPNORM_GROUPS.format(arch=self.arch), value)
|
||||
|
||||
def add_causal_attention(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
|
@ -787,6 +864,9 @@ def add_pooling_type(self, value: PoolingType) -> None:
|
|||
def add_rope_dimension_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.Rope.DIMENSION_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_rope_dimension_sections(self, dims: Sequence[int]) -> None:
|
||||
self.add_array(Keys.Rope.DIMENSION_SECTIONS.format(arch=self.arch), dims)
|
||||
|
||||
def add_rope_freq_base(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
|
@ -893,6 +973,7 @@ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
|||
name = choice.get("name", "")
|
||||
template = choice.get("template")
|
||||
|
||||
# Allowing non-alphanumerical characters in template name is probably not a good idea, so filter it
|
||||
name = "".join(
|
||||
(c if c in ascii_letters + digits else "_" for c in name)
|
||||
)
|
||||
|
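The filter above keeps only ASCII letters and digits in a named chat template and maps every other character to an underscore. A quick standalone illustration, with an invented input string:

from string import ascii_letters, digits

name = "rag/chat-template v2"
name = "".join(c if c in ascii_letters + digits else "_" for c in name)
# name is now "rag_chat_template_v2"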
@ -916,15 +997,6 @@ def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
|
|||
|
||||
self.add_string(Keys.Tokenizer.CHAT_TEMPLATE, value)
|
||||
|
||||
def add_prefix_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.PREFIX_ID, id)
|
||||
|
||||
def add_suffix_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.SUFFIX_ID, id)
|
||||
|
||||
def add_middle_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.MIDDLE_ID, id)
|
||||
|
||||
def add_eot_token_id(self, id: int) -> None:
|
||||
self.add_uint32(Keys.Tokenizer.EOT_ID, id)
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
|
||||
|
||||
class LazyMeta(ABCMeta):
|
||||
|
||||
def __new__(
|
||||
cls, name: str, bases: tuple[type, ...], namespace: dict[str, Any], **kwargs
|
||||
):
|
||||
|
@ -34,7 +35,7 @@ def __getattr__(self, name: str) -> Any:
|
|||
|
||||
# need to make a builder for the wrapped wrapper to copy the name,
|
||||
# or else it fails with very cryptic error messages,
|
||||
# because somehow the same string would end up in every closure
|
||||
# because somehow the same string would end up in every closures
|
||||
def mk_wrap(op_name: str, *, meta_noop: bool = False):
|
||||
# need to wrap the wrapper to get self
|
||||
def wrapped_special_op(self, *args, **kwargs):
|
||||
|
@ -254,6 +255,8 @@ def from_eager(cls, t: Any) -> Any:
|
|||
class LazyNumpyTensor(LazyBase):
|
||||
_tensor_type = np.ndarray
|
||||
|
||||
shape: tuple[int, ...] # Makes the type checker happy in quants.py
|
||||
|
||||
@classmethod
|
||||
def meta_with_dtype_and_shape(
|
||||
cls, dtype: DTypeLike, shape: tuple[int, ...]
|
||||
|
|
|
@ -41,7 +41,7 @@ class Metadata:
base_models: Optional[list[dict]] = None
tags: Optional[list[str]] = None
languages: Optional[list[str]] = None
datasets: Optional[list[str]] = None
datasets: Optional[list[dict]] = None

@staticmethod
def load(
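With this change, Metadata.datasets carries structured entries instead of bare strings, matching the new per-dataset writer keys. A plausible entry might look like the following; the field names follow the setters added above and the values are invented:

metadata.datasets = [
    {
        "name": "OpenOrca",
        "organization": "Open-Orca",
        "repo_url": "https://huggingface.co/datasets/Open-Orca/OpenOrca",
    }
]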
@ -50,7 +50,7 @@ def load(
|
|||
model_name: Optional[str] = None,
|
||||
total_params: int = 0,
|
||||
) -> Metadata:
|
||||
# This grabs as much contextual authorship metadata as possible from the model repository
|
||||
# This grabs as many contextual authorship metadata as possible from the model repository
|
||||
# making any conversion as required to match the gguf kv store metadata format
|
||||
# as well as giving users the ability to override any authorship metadata that may be incorrect
|
||||
|
||||
|
@ -126,13 +126,13 @@ def load(
|
|||
"general.base_models", metadata.base_models
|
||||
)
|
||||
|
||||
# Datasets is received here as an array of datasets
|
||||
metadata.datasets = metadata_override.get("general.datasets", metadata.datasets)
|
||||
|
||||
metadata.tags = metadata_override.get(Keys.General.TAGS, metadata.tags)
|
||||
metadata.languages = metadata_override.get(
|
||||
Keys.General.LANGUAGES, metadata.languages
|
||||
)
|
||||
metadata.datasets = metadata_override.get(
|
||||
Keys.General.DATASETS, metadata.datasets
|
||||
)
|
||||
|
||||
# Direct Metadata Override (via direct cli argument)
|
||||
if model_name is not None:
|
||||
|
@ -228,7 +228,11 @@ def get_model_id_components(
|
|||
org_component, model_full_name_component = None, model_id
|
||||
|
||||
# Check if we erroneously matched against './' or '../' etc...
|
||||
if org_component is not None and org_component[0] == ".":
|
||||
if (
|
||||
org_component is not None
|
||||
and len(org_component) > 0
|
||||
and org_component[0] == "."
|
||||
):
|
||||
org_component = None
|
||||
|
||||
name_parts: list[str] = model_full_name_component.split("-")
|
||||
|
@ -387,27 +391,86 @@ def apply_metadata_heuristic(
|
|||
########################
|
||||
if model_card is not None:
|
||||
|
||||
if "model_name" in model_card and metadata.name is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.name = model_card.get("model_name")
|
||||
def use_model_card_metadata(metadata_key: str, model_card_key: str):
|
||||
if (
|
||||
model_card_key in model_card
|
||||
and getattr(metadata, metadata_key, None) is None
|
||||
):
|
||||
setattr(metadata, metadata_key, model_card.get(model_card_key))
|
||||
|
||||
if "model_creator" in model_card and metadata.author is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.author = model_card.get("model_creator")
|
||||
def use_array_model_card_metadata(metadata_key: str, model_card_key: str):
|
||||
# Note: Will append rather than replace if already exist
|
||||
tags_value = model_card.get(model_card_key, None)
|
||||
if tags_value is None:
|
||||
return
|
||||
|
||||
if "model_type" in model_card and metadata.basename is None:
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
metadata.basename = model_card.get("model_type")
|
||||
current_value = getattr(metadata, metadata_key, None)
|
||||
if current_value is None:
|
||||
current_value = []
|
||||
|
||||
if "base_model" in model_card:
|
||||
if isinstance(tags_value, str):
|
||||
current_value.append(tags_value)
|
||||
elif isinstance(tags_value, list):
|
||||
current_value.extend(tags_value)
|
||||
|
||||
setattr(metadata, metadata_key, current_value)
|
||||
|
||||
# LLAMA.cpp's direct internal convention
|
||||
# (Definitely not part of hugging face formal/informal standard)
|
||||
#########################################
|
||||
use_model_card_metadata("name", "name")
|
||||
use_model_card_metadata("author", "author")
|
||||
use_model_card_metadata("version", "version")
|
||||
use_model_card_metadata("organization", "organization")
|
||||
use_model_card_metadata("description", "description")
|
||||
use_model_card_metadata("finetune", "finetune")
|
||||
use_model_card_metadata("basename", "basename")
|
||||
use_model_card_metadata("size_label", "size_label")
|
||||
use_model_card_metadata("source_url", "url")
|
||||
use_model_card_metadata("source_doi", "doi")
|
||||
use_model_card_metadata("source_uuid", "uuid")
|
||||
use_model_card_metadata("source_repo_url", "repo_url")
|
||||
|
||||
# LLAMA.cpp's huggingface style convention
|
||||
# (Definitely not part of hugging face formal/informal standard... but with model_ appended to match their style)
|
||||
###########################################
|
||||
use_model_card_metadata("name", "model_name")
|
||||
use_model_card_metadata("author", "model_author")
|
||||
use_model_card_metadata("version", "model_version")
|
||||
use_model_card_metadata("organization", "model_organization")
|
||||
use_model_card_metadata("description", "model_description")
|
||||
use_model_card_metadata("finetune", "model_finetune")
|
||||
use_model_card_metadata("basename", "model_basename")
|
||||
use_model_card_metadata("size_label", "model_size_label")
|
||||
use_model_card_metadata("source_url", "model_url")
|
||||
use_model_card_metadata("source_doi", "model_doi")
|
||||
use_model_card_metadata("source_uuid", "model_uuid")
|
||||
use_model_card_metadata("source_repo_url", "model_repo_url")
|
||||
|
||||
# Hugging Face Direct Convention
|
||||
#################################
|
||||
|
||||
# Not part of huggingface model card standard but notice some model creator using it
|
||||
# such as TheBloke in 'TheBloke/Mistral-7B-Instruct-v0.2-GGUF'
|
||||
use_model_card_metadata("name", "model_name")
|
||||
use_model_card_metadata("author", "model_creator")
|
||||
use_model_card_metadata("basename", "model_type")
|
||||
|
||||
if (
|
||||
"base_model" in model_card
|
||||
or "base_models" in model_card
|
||||
or "base_model_sources" in model_card
|
||||
):
|
||||
# This represents the parent models that this is based on
|
||||
# Example: stabilityai/stable-diffusion-xl-base-1.0. Can also be a list (for merges)
|
||||
# Example of merges: https://huggingface.co/EmbeddedLLM/Mistral-7B-Merge-14-v0.1/blob/main/README.md
|
||||
metadata_base_models = []
|
||||
base_model_value = model_card.get("base_model", None)
|
||||
base_model_value = model_card.get(
|
||||
"base_model",
|
||||
model_card.get(
|
||||
"base_models", model_card.get("base_model_sources", None)
|
||||
),
|
||||
)
|
||||
|
||||
if base_model_value is not None:
|
||||
if isinstance(base_model_value, str):
|
||||
|
@ -420,86 +483,195 @@ def apply_metadata_heuristic(
|
|||
|
||||
for model_id in metadata_base_models:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
(
|
||||
model_full_name_component,
|
||||
org_component,
|
||||
basename,
|
||||
finetune,
|
||||
version,
|
||||
size_label,
|
||||
) = Metadata.get_model_id_components(model_id, total_params)
|
||||
base_model = {}
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(
|
||||
model_full_name_component
|
||||
)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(org_component)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if (
|
||||
org_component is not None
|
||||
and model_full_name_component is not None
|
||||
):
|
||||
base_model["repo_url"] = (
|
||||
f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
if isinstance(model_id, str):
|
||||
if (
|
||||
model_id.startswith("http://")
|
||||
or model_id.startswith("https://")
|
||||
or model_id.startswith("ssh://")
|
||||
):
|
||||
base_model["repo_url"] = model_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in model_id:
|
||||
match = re.match(
|
||||
r"https?://huggingface.co/([^/]+/[^/]+)$", model_id
|
||||
)
|
||||
if match:
|
||||
model_id_component = match.group(1)
|
||||
(
|
||||
model_full_name_component,
|
||||
org_component,
|
||||
basename,
|
||||
finetune,
|
||||
version,
|
||||
size_label,
|
||||
) = Metadata.get_model_id_components(
|
||||
model_id_component, total_params
|
||||
)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(
|
||||
model_full_name_component
|
||||
)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = (
|
||||
Metadata.id_to_title(org_component)
|
||||
)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
(
|
||||
model_full_name_component,
|
||||
org_component,
|
||||
basename,
|
||||
finetune,
|
||||
version,
|
||||
size_label,
|
||||
) = Metadata.get_model_id_components(model_id, total_params)
|
||||
|
||||
# Populate model dictionary with extracted components
|
||||
if model_full_name_component is not None:
|
||||
base_model["name"] = Metadata.id_to_title(
|
||||
model_full_name_component
|
||||
)
|
||||
if org_component is not None:
|
||||
base_model["organization"] = Metadata.id_to_title(
|
||||
org_component
|
||||
)
|
||||
if version is not None:
|
||||
base_model["version"] = version
|
||||
if (
|
||||
org_component is not None
|
||||
and model_full_name_component is not None
|
||||
):
|
||||
base_model["repo_url"] = (
|
||||
f"https://huggingface.co/{org_component}/{model_full_name_component}"
|
||||
)
|
||||
|
||||
elif isinstance(model_id, dict):
|
||||
base_model = model_id
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
f"base model entry '{str(model_id)}' not in a known format"
|
||||
)
|
||||
|
||||
metadata.base_models.append(base_model)
|
||||
|
||||
if "license" in model_card and metadata.license is None:
|
||||
metadata.license = model_card.get("license")
|
||||
if (
|
||||
"datasets" in model_card
|
||||
or "dataset" in model_card
|
||||
or "dataset_sources" in model_card
|
||||
):
|
||||
# This represents the datasets that this was trained from
|
||||
metadata_datasets = []
|
||||
dataset_value = model_card.get(
|
||||
"datasets",
|
||||
model_card.get("dataset", model_card.get("dataset_sources", None)),
|
||||
)
|
||||
|
||||
if "license_name" in model_card and metadata.license_name is None:
|
||||
metadata.license_name = model_card.get("license_name")
|
||||
|
||||
if "license_link" in model_card and metadata.license_link is None:
|
||||
metadata.license_link = model_card.get("license_link")
|
||||
|
||||
tags_value = model_card.get("tags", None)
|
||||
if tags_value is not None:
|
||||
|
||||
if metadata.tags is None:
|
||||
metadata.tags = []
|
||||
|
||||
if isinstance(tags_value, str):
|
||||
metadata.tags.append(tags_value)
|
||||
elif isinstance(tags_value, list):
|
||||
metadata.tags.extend(tags_value)
|
||||
|
||||
pipeline_tags_value = model_card.get("pipeline_tag", None)
|
||||
if pipeline_tags_value is not None:
|
||||
|
||||
if metadata.tags is None:
|
||||
metadata.tags = []
|
||||
|
||||
if isinstance(pipeline_tags_value, str):
|
||||
metadata.tags.append(pipeline_tags_value)
|
||||
elif isinstance(pipeline_tags_value, list):
|
||||
metadata.tags.extend(pipeline_tags_value)
|
||||
|
||||
language_value = model_card.get(
|
||||
"languages", model_card.get("language", None)
|
||||
)
|
||||
if language_value is not None:
|
||||
|
||||
if metadata.languages is None:
|
||||
metadata.languages = []
|
||||
|
||||
if isinstance(language_value, str):
|
||||
metadata.languages.append(language_value)
|
||||
elif isinstance(language_value, list):
|
||||
metadata.languages.extend(language_value)
|
||||
|
||||
dataset_value = model_card.get("datasets", model_card.get("dataset", None))
|
||||
if dataset_value is not None:
|
||||
if dataset_value is not None:
|
||||
if isinstance(dataset_value, str):
|
||||
metadata_datasets.append(dataset_value)
|
||||
elif isinstance(dataset_value, list):
|
||||
metadata_datasets.extend(dataset_value)
|
||||
|
||||
if metadata.datasets is None:
|
||||
metadata.datasets = []
|
||||
|
||||
if isinstance(dataset_value, str):
|
||||
metadata.datasets.append(dataset_value)
|
||||
elif isinstance(dataset_value, list):
|
||||
metadata.datasets.extend(dataset_value)
|
||||
for dataset_id in metadata_datasets:
|
||||
# NOTE: model size of base model is assumed to be similar to the size of the current model
|
||||
dataset = {}
|
||||
if isinstance(dataset_id, str):
|
||||
if dataset_id.startswith(("http://", "https://", "ssh://")):
|
||||
dataset["repo_url"] = dataset_id
|
||||
|
||||
# Check if Hugging Face ID is present in URL
|
||||
if "huggingface.co" in dataset_id:
|
||||
match = re.match(
|
||||
r"https?://huggingface.co/([^/]+/[^/]+)$",
|
||||
dataset_id,
|
||||
)
|
||||
if match:
|
||||
dataset_id_component = match.group(1)
|
||||
(
|
||||
dataset_name_component,
|
||||
org_component,
|
||||
basename,
|
||||
finetune,
|
||||
version,
|
||||
size_label,
|
||||
) = Metadata.get_model_id_components(
|
||||
dataset_id_component, total_params
|
||||
)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(
|
||||
dataset_name_component
|
||||
)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(
|
||||
org_component
|
||||
)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
|
||||
else:
|
||||
# Likely a Hugging Face ID
|
||||
(
|
||||
dataset_name_component,
|
||||
org_component,
|
||||
basename,
|
||||
finetune,
|
||||
version,
|
||||
size_label,
|
||||
) = Metadata.get_model_id_components(
|
||||
dataset_id, total_params
|
||||
)
|
||||
|
||||
# Populate dataset dictionary with extracted components
|
||||
if dataset_name_component is not None:
|
||||
dataset["name"] = Metadata.id_to_title(
|
||||
dataset_name_component
|
||||
)
|
||||
if org_component is not None:
|
||||
dataset["organization"] = Metadata.id_to_title(
|
||||
org_component
|
||||
)
|
||||
if version is not None:
|
||||
dataset["version"] = version
|
||||
if (
|
||||
org_component is not None
|
||||
and dataset_name_component is not None
|
||||
):
|
||||
dataset["repo_url"] = (
|
||||
f"https://huggingface.co/{org_component}/{dataset_name_component}"
|
||||
)
|
||||
|
||||
elif isinstance(dataset_id, dict):
|
||||
dataset = dataset_id
|
||||
|
||||
else:
|
||||
logger.error(
|
||||
f"dataset entry '{str(dataset_id)}' not in a known format"
|
||||
)
|
||||
|
||||
metadata.datasets.append(dataset)
use_model_card_metadata("license", "license")
use_model_card_metadata("license_name", "license_name")
use_model_card_metadata("license_link", "license_link")

use_array_model_card_metadata("tags", "tags")
use_array_model_card_metadata("tags", "pipeline_tag")

use_array_model_card_metadata("languages", "languages")
use_array_model_card_metadata("languages", "language")

# Hugging Face Parameter Heuristics
####################################
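Earlier in this hunk, base model (and dataset) entries that arrive as huggingface.co URLs are reduced back to an org/name ID before Metadata.get_model_id_components() is applied. A small standalone sketch of that branch, using an invented model URL:

import re

model_id = "https://huggingface.co/mistralai/Mistral-7B-v0.1"
match = re.match(r"https?://huggingface.co/([^/]+/[^/]+)$", model_id)
if match:
    # "mistralai/Mistral-7B-v0.1" is then split into org and model name parts
    print(match.group(1))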
@ -508,7 +680,7 @@ def apply_metadata_heuristic(
|
|||
|
||||
hf_name_or_path = hf_params.get("_name_or_path")
|
||||
if hf_name_or_path is not None and hf_name_or_path.count("/") <= 1:
|
||||
# Use _name_or_path only if it's actually a model name and not some computer path
|
||||
# Use _name_or_path only if its actually a model name and not some computer path
|
||||
# e.g. 'meta-llama/Llama-2-7b-hf'
|
||||
model_id = hf_name_or_path
|
||||
(
|
||||
|
@ -584,7 +756,10 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
|||
gguf_writer.add_size_label(self.size_label)
|
||||
|
||||
if self.license is not None:
|
||||
gguf_writer.add_license(self.license)
|
||||
if isinstance(self.license, list):
|
||||
gguf_writer.add_license(",".join(self.license))
|
||||
else:
|
||||
gguf_writer.add_license(self.license)
|
||||
if self.license_name is not None:
|
||||
gguf_writer.add_license_name(self.license_name)
|
||||
if self.license_link is not None:
|
||||
|
@ -621,6 +796,10 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
|
|||
gguf_writer.add_base_model_organization(
|
||||
key, base_model_entry["organization"]
|
||||
)
|
||||
if "description" in base_model_entry:
|
||||
gguf_writer.add_base_model_description(
|
||||
key, base_model_entry["description"]
|
||||
)
|
||||
if "url" in base_model_entry:
|
||||
gguf_writer.add_base_model_url(key, base_model_entry["url"])
|
||||
if "doi" in base_model_entry:
|
||||
|
@ -632,9 +811,33 @@ def set_gguf_meta_model(self, gguf_writer: gguf.GGUFWriter):
key, base_model_entry["repo_url"]
)

if self.datasets is not None:
gguf_writer.add_dataset_count(len(self.datasets))
for key, dataset_entry in enumerate(self.datasets):
if "name" in dataset_entry:
gguf_writer.add_dataset_name(key, dataset_entry["name"])
if "author" in dataset_entry:
gguf_writer.add_dataset_author(key, dataset_entry["author"])
if "version" in dataset_entry:
gguf_writer.add_dataset_version(key, dataset_entry["version"])
if "organization" in dataset_entry:
gguf_writer.add_dataset_organization(
key, dataset_entry["organization"]
)
if "description" in dataset_entry:
gguf_writer.add_dataset_description(
key, dataset_entry["description"]
)
if "url" in dataset_entry:
gguf_writer.add_dataset_url(key, dataset_entry["url"])
if "doi" in dataset_entry:
gguf_writer.add_dataset_doi(key, dataset_entry["doi"])
if "uuid" in dataset_entry:
gguf_writer.add_dataset_uuid(key, dataset_entry["uuid"])
if "repo_url" in dataset_entry:
gguf_writer.add_dataset_repo_url(key, dataset_entry["repo_url"])

if self.tags is not None:
gguf_writer.add_tags(self.tags)
if self.languages is not None:
gguf_writer.add_languages(self.languages)
if self.datasets is not None:
gguf_writer.add_datasets(self.datasets)
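For a single structured dataset entry, the branch above emits per-dataset key/value pairs rather than the old flat general.datasets array. Assuming the usual general.dataset.{id}.* key layout (mirroring the existing general.base_model.{id}.* keys; the exact spellings live in gguf.constants), the result would look roughly like:

# general.dataset.count          = 1
# general.dataset.0.name         = "OpenOrca"
# general.dataset.0.organization = "Open-Orca"
# general.dataset.0.repo_url     = "https://huggingface.co/datasets/Open-Orca/OpenOrca"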
1471
src/gguf/quants.py
File diff suppressed because it is too large
|
@ -7,463 +7,574 @@
|
|||
|
||||
class TensorNameMap:
|
||||
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||
# Token embeddings
|
||||
MODEL_TENSOR.TOKEN_EMBD: (
|
||||
"gpt_neox.embed_in",
|
||||
"transformer.wte",
|
||||
"transformer.word_embeddings",
|
||||
"word_embeddings",
|
||||
"model.embed_tokens",
|
||||
"tok_embeddings",
|
||||
"embeddings.word_embeddings",
|
||||
"language_model.embedding.word_embeddings",
|
||||
"wte",
|
||||
"transformer.embd.wte",
|
||||
"model.tok_embeddings",
|
||||
"model.embedding",
|
||||
"backbone.embedding",
|
||||
"backbone.embeddings",
|
||||
"transformer.in_out_embed",
|
||||
"embedding.word_embeddings",
|
||||
"transformer.token_embeddings",
|
||||
"shared",
|
||||
"rwkv.embeddings",
|
||||
"gpt_neox.embed_in", # gptneox
|
||||
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
|
||||
"transformer.word_embeddings", # falcon
|
||||
"word_embeddings", # bloom
|
||||
"model.embed_tokens", # llama-hf nemotron olmoe olmo2
|
||||
"tok_embeddings", # llama-pth
|
||||
"embeddings.word_embeddings", # bert nomic-bert
|
||||
"language_model.embedding.word_embeddings", # persimmon
|
||||
"wte", # gpt2
|
||||
"transformer.embd.wte", # phi2
|
||||
"model.tok_embeddings", # internlm2
|
||||
"model.embedding", # mamba-qbert
|
||||
"backbone.embedding", # mamba
|
||||
"backbone.embeddings", # mamba-hf
|
||||
"transformer.in_out_embed", # Grok
|
||||
"embedding.word_embeddings", # chatglm
|
||||
"transformer.token_embeddings", # openelm
|
||||
"shared", # t5
|
||||
"rwkv.embeddings", # rwkv
|
||||
),
|
||||
MODEL_TENSOR.TOKEN_TYPES: ("embeddings.token_type_embeddings",),
|
||||
# Token type embeddings
|
||||
MODEL_TENSOR.TOKEN_TYPES: (
|
||||
"embeddings.token_type_embeddings", # bert nomic-bert
|
||||
),
|
||||
# Normalization of token embeddings
|
||||
MODEL_TENSOR.TOKEN_EMBD_NORM: (
|
||||
"word_embeddings_layernorm",
|
||||
"embeddings.LayerNorm",
|
||||
"emb_ln",
|
||||
"transformer.norm",
|
||||
"rwkv.blocks.0.pre_ln",
|
||||
"word_embeddings_layernorm", # bloom
|
||||
"embeddings.LayerNorm", # bert
|
||||
"emb_ln", # nomic-bert
|
||||
"transformer.norm", # openelm
|
||||
"rwkv.blocks.0.pre_ln", # rwkv
|
||||
"backbone.norm", # wavtokenizer
|
||||
),
|
||||
# Position embeddings
|
||||
MODEL_TENSOR.POS_EMBD: (
|
||||
"transformer.wpe",
|
||||
"embeddings.position_embeddings",
|
||||
"wpe",
|
||||
"transformer.wpe", # gpt2
|
||||
"embeddings.position_embeddings", # bert
|
||||
"wpe", # gpt2
|
||||
),
|
||||
# Output
|
||||
MODEL_TENSOR.OUTPUT: (
|
||||
"embed_out",
|
||||
"lm_head",
|
||||
"output",
|
||||
"word_embeddings_for_head",
|
||||
"lm_head.linear",
|
||||
"output_layer",
|
||||
"head",
|
||||
"embed_out", # gptneox
|
||||
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
|
||||
"output", # llama-pth bloom internlm2
|
||||
"word_embeddings_for_head", # persimmon
|
||||
"lm_head.linear", # phi2
|
||||
"output_layer", # chatglm
|
||||
"head", # rwkv
|
||||
"head.out", # wavtokenizer
|
||||
),
|
||||
# Output norm
|
||||
MODEL_TENSOR.OUTPUT_NORM: (
|
||||
"gpt_neox.final_layer_norm",
|
||||
"transformer.ln_f",
|
||||
"model.norm",
|
||||
"norm",
|
||||
"transformer.norm_f",
|
||||
"ln_f",
|
||||
"language_model.encoder.final_layernorm",
|
||||
"model.final_layernorm",
|
||||
"lm_head.ln",
|
||||
"model.norm_f",
|
||||
"backbone.norm_f",
|
||||
"transformer.rms_norm",
|
||||
"encoder.final_layernorm",
|
||||
"transformer.norm",
|
||||
"model.norm",
|
||||
"rwkv.ln_out",
|
||||
"gpt_neox.final_layer_norm", # gptneox
|
||||
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
|
||||
"model.norm", # llama-hf baichuan internlm2 olmoe olmo2
|
||||
"norm", # llama-pth
|
||||
"transformer.norm_f", # mpt dbrx
|
||||
"ln_f", # refact bloom qwen gpt2
|
||||
"language_model.encoder.final_layernorm", # persimmon
|
||||
"model.final_layernorm", # persimmon
|
||||
"lm_head.ln", # phi2
|
||||
"model.norm_f", # mamba-qbert
|
||||
"backbone.norm_f", # mamba
|
||||
"transformer.rms_norm", # Grok
|
||||
"encoder.final_layernorm", # chatglm
|
||||
"transformer.norm", # openelm
|
||||
"model.norm", # nemotron
|
||||
"rwkv.ln_out", # rwkv
|
||||
"backbone.final_layer_norm", # wavtokenizer
|
||||
),
|
||||
# Rope frequencies
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"rope.freqs",
|
||||
"rotary_pos_emb.inv_freq",
|
||||
"rope.freqs", # llama-pth
|
||||
"rotary_pos_emb.inv_freq", # chatglm
|
||||
),
|
||||
MODEL_TENSOR.ROPE_FACTORS_LONG: (),
|
||||
MODEL_TENSOR.ROPE_FACTORS_SHORT: (),
|
||||
MODEL_TENSOR.CONV1D: ("backbone.embed",), # roberta
|
||||
}
|
||||
|
||||
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
|
||||
# Attention norm
|
||||
MODEL_TENSOR.ATTN_NORM: (
|
||||
"gpt_neox.layers.{bid}.input_layernorm",
|
||||
"transformer.h.{bid}.ln_1",
|
||||
"transformer.blocks.{bid}.norm_1",
|
||||
"transformer.h.{bid}.input_layernorm",
|
||||
"h.{bid}.input_layernorm",
|
||||
"transformer.h.{bid}.ln_mlp",
|
||||
"model.layers.{bid}.input_layernorm",
|
||||
"layers.{bid}.attention_norm",
|
||||
"language_model.encoder.layers.{bid}.input_layernorm",
|
||||
"model.layers.{bid}.ln1",
|
||||
"h.{bid}.ln_1",
|
||||
"transformer.h.{bid}.ln",
|
||||
"model.layers.layers.{bid}.norm",
|
||||
"model.layers.{bid}.attention_norm",
|
||||
"model.layers.{bid}.norm",
|
||||
"backbone.layers.{bid}.norm",
|
||||
"transformer.decoder_layer.{bid}.rms_norm",
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1",
|
||||
"encoder.layers.{bid}.input_layernorm",
|
||||
"transformer.layers.{bid}.attn_norm",
|
||||
"rwkv.blocks.{bid}.ln1",
|
||||
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
|
||||
"transformer.blocks.{bid}.norm_1", # mpt
|
||||
"transformer.h.{bid}.input_layernorm", # falcon7b
|
||||
"h.{bid}.input_layernorm", # bloom
|
||||
"transformer.h.{bid}.ln_mlp", # falcon40b
|
||||
"model.layers.{bid}.input_layernorm", # llama-hf nemotron olmoe
|
||||
"layers.{bid}.attention_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.input_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln1", # yi
|
||||
"h.{bid}.ln_1", # gpt2
|
||||
"transformer.h.{bid}.ln", # phi2
|
||||
"model.layers.layers.{bid}.norm", # plamo
|
||||
"model.layers.{bid}.attention_norm", # internlm2
|
||||
"model.layers.{bid}.norm", # mamba-qbert
|
||||
"backbone.layers.{bid}.norm", # mamba
|
||||
"transformer.decoder_layer.{bid}.rms_norm", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
|
||||
"encoder.layers.{bid}.input_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.attn_norm", # openelm
|
||||
"rwkv.blocks.{bid}.ln1", # rwkv
|
||||
),
|
||||
# Attention norm 2
|
||||
MODEL_TENSOR.ATTN_NORM_2: (
|
||||
"transformer.h.{bid}.ln_attn",
|
||||
"encoder.layer.{bid}.layer_norm_1",
|
||||
"rwkv.blocks.{bid}.ln2",
|
||||
"transformer.h.{bid}.ln_attn", # falcon40b
|
||||
"encoder.layer.{bid}.layer_norm_1", # jina-v2-code
|
||||
"rwkv.blocks.{bid}.ln2", # rwkv
|
||||
),
|
||||
# Attention query-key-value
|
||||
MODEL_TENSOR.ATTN_QKV: (
|
||||
"gpt_neox.layers.{bid}.attention.query_key_value",
|
||||
"transformer.h.{bid}.attn.c_attn",
|
||||
"transformer.blocks.{bid}.attn.Wqkv",
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv",
|
||||
"transformer.h.{bid}.self_attention.query_key_value",
|
||||
"h.{bid}.self_attention.query_key_value",
|
||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value",
|
||||
"model.layers.{bid}.self_attn.query_key_value",
|
||||
"h.{bid}.attn.c_attn",
|
||||
"transformer.h.{bid}.mixer.Wqkv",
|
||||
"encoder.layers.{bid}.attn.Wqkv",
|
||||
"model.layers.{bid}.self_attn.qkv_proj",
|
||||
"encoder.layers.{bid}.self_attention.query_key_value",
|
||||
"transformer.layers.{bid}.attn.qkv_proj",
|
||||
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
|
||||
"transformer.h.{bid}.attn.c_attn", # gpt2 qwen jais
|
||||
"transformer.blocks.{bid}.attn.Wqkv", # mpt
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.Wqkv", # dbrx
|
||||
"transformer.h.{bid}.self_attention.query_key_value", # falcon
|
||||
"h.{bid}.self_attention.query_key_value", # bloom
|
||||
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
|
||||
"model.layers.{bid}.self_attn.query_key_value", # persimmon
|
||||
"h.{bid}.attn.c_attn", # gpt2
|
||||
"transformer.h.{bid}.mixer.Wqkv", # phi2
|
||||
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
|
||||
"model.layers.{bid}.self_attn.qkv_proj", # phi3
|
||||
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
|
||||
"transformer.layers.{bid}.attn.qkv_proj", # openelm
|
||||
),
|
||||
# Attention query
|
||||
MODEL_TENSOR.ATTN_Q: (
|
||||
"model.layers.{bid}.self_attn.q_proj",
|
||||
"layers.{bid}.attention.wq",
|
||||
"encoder.layer.{bid}.attention.self.query",
|
||||
"transformer.h.{bid}.attn.q_proj",
|
||||
"model.layers.layers.{bid}.self_attn.q_proj",
|
||||
"model.layers.{bid}.attention.wq",
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query",
|
||||
"transformer.h.{bid}.attn.attention.q_proj",
|
||||
"model.layers.{bid}.self_attn.q_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.q_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wq", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.query", # bert
|
||||
"transformer.h.{bid}.attn.q_proj", # gpt-j
|
||||
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
|
||||
"model.layers.{bid}.attention.wq", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.query", # Grok
|
||||
"transformer.h.{bid}.attn.attention.q_proj", # exaone
|
||||
),
|
||||
# Attention key
|
||||
MODEL_TENSOR.ATTN_K: (
|
||||
"model.layers.{bid}.self_attn.k_proj",
|
||||
"layers.{bid}.attention.wk",
|
||||
"encoder.layer.{bid}.attention.self.key",
|
||||
"transformer.h.{bid}.attn.k_proj",
|
||||
"transformer.h.{bid}.attn.k",
|
||||
"model.layers.layers.{bid}.self_attn.k_proj",
|
||||
"model.layers.{bid}.attention.wk",
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key",
|
||||
"transformer.h.{bid}.attn.attention.k_proj",
|
||||
"model.layers.{bid}.self_attn.k_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.k_proj_no_perm", # llama-custom
|
||||
"layers.{bid}.attention.wk", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.key", # bert
|
||||
"transformer.h.{bid}.attn.k_proj", # gpt-j
|
||||
"transformer.h.{bid}.attn.k", # refact
|
||||
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
|
||||
"model.layers.{bid}.attention.wk", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.key", # Grok
|
||||
"transformer.h.{bid}.attn.attention.k_proj", # exaone
|
||||
),
|
||||
# Attention value
|
||||
MODEL_TENSOR.ATTN_V: (
|
||||
"model.layers.{bid}.self_attn.v_proj",
|
||||
"layers.{bid}.attention.wv",
|
||||
"encoder.layer.{bid}.attention.self.value",
|
||||
"transformer.h.{bid}.attn.v_proj",
|
||||
"transformer.h.{bid}.attn.v",
|
||||
"model.layers.layers.{bid}.self_attn.v_proj",
|
||||
"model.layers.{bid}.attention.wv",
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value",
|
||||
"transformer.h.{bid}.attn.attention.v_proj",
|
||||
"model.layers.{bid}.self_attn.v_proj", # llama-hf nemotron olmoe olmo2
|
||||
"layers.{bid}.attention.wv", # llama-pth
|
||||
"encoder.layer.{bid}.attention.self.value", # bert
|
||||
"transformer.h.{bid}.attn.v_proj", # gpt-j
|
||||
"transformer.h.{bid}.attn.v", # refact
|
||||
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
|
||||
"model.layers.{bid}.attention.wv", # internlm2
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.value", # Grok
|
||||
"transformer.h.{bid}.attn.attention.v_proj", # exaone
|
||||
),
|
||||
# Attention output
|
||||
MODEL_TENSOR.ATTN_OUT: (
|
||||
"gpt_neox.layers.{bid}.attention.dense",
|
||||
"transformer.h.{bid}.attn.c_proj",
|
||||
"transformer.blocks.{bid}.attn.out_proj",
|
||||
"transformer.h.{bid}.self_attention.dense",
|
||||
"h.{bid}.self_attention.dense",
|
||||
"model.layers.{bid}.self_attn.o_proj",
|
||||
"layers.{bid}.attention.wo",
|
||||
"encoder.layer.{bid}.attention.output.dense",
|
||||
"transformer.h.{bid}.attn.out_proj",
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense",
|
||||
"model.layers.{bid}.self_attn.dense",
|
||||
"h.{bid}.attn.c_proj",
|
||||
"transformer.h.{bid}.mixer.out_proj",
|
||||
"model.layers.layers.{bid}.self_attn.o_proj",
|
||||
"model.layers.{bid}.attention.wo",
|
||||
"encoder.layers.{bid}.attn.out_proj",
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear",
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",
|
||||
"encoder.layers.{bid}.self_attention.dense",
|
||||
"transformer.layers.{bid}.attn.out_proj",
|
||||
"transformer.h.{bid}.attn.attention.out_proj",
|
||||
"gpt_neox.layers.{bid}.attention.dense", # gptneox
|
||||
"transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen jais
|
||||
"transformer.blocks.{bid}.attn.out_proj", # mpt
|
||||
"transformer.h.{bid}.self_attention.dense", # falcon
|
||||
"h.{bid}.self_attention.dense", # bloom
|
||||
"model.layers.{bid}.self_attn.o_proj", # llama-hf nemotron olmoe olmo2
|
||||
"model.layers.{bid}.self_attn.linear_attn", # deci
|
||||
"layers.{bid}.attention.wo", # llama-pth
|
||||
"encoder.layer.{bid}.attention.output.dense", # bert
|
||||
"transformer.h.{bid}.attn.out_proj", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
|
||||
"model.layers.{bid}.self_attn.dense", # persimmon
|
||||
"h.{bid}.attn.c_proj", # gpt2
|
||||
"transformer.h.{bid}.mixer.out_proj", # phi2
|
||||
"model.layers.layers.{bid}.self_attn.o_proj", # plamo
|
||||
"model.layers.{bid}.attention.wo", # internlm2
|
||||
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
|
||||
"encoder.layers.{bid}.self_attention.dense", # chatglm
|
||||
"transformer.layers.{bid}.attn.out_proj", # openelm
|
||||
"transformer.h.{bid}.attn.attention.out_proj", # exaone
|
||||
),
|
||||
# Attention output norm
|
||||
MODEL_TENSOR.ATTN_OUT_NORM: (
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm",
|
||||
"encoder.layers.{bid}.norm1",
|
||||
"transformer.decoder_layer.{bid}.rms_norm_1",
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_2",
|
||||
"encoder.layer.{bid}.attention.output.LayerNorm", # bert
|
||||
"encoder.layers.{bid}.norm1", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.rms_norm_1", # Grok
|
||||
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
|
||||
),
|
||||
MODEL_TENSOR.ATTN_POST_NORM: ("model.layers.{bid}.post_attention_layernorm",),
|
||||
MODEL_TENSOR.ATTN_POST_NORM: (
|
||||
"model.layers.{bid}.post_attention_layernorm", # gemma2 olmo2
|
||||
),
|
||||
# Rotary embeddings
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD: (
|
||||
"model.layers.{bid}.self_attn.rotary_emb.inv_freq",
|
||||
"layers.{bid}.attention.inner_attention.rope.freqs",
|
||||
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq",
|
||||
"transformer.h.{bid}.attn.rotary_emb.inv_freq",
|
||||
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
|
||||
"layers.{bid}.attention.inner_attention.rope.freqs", # llama-pth
|
||||
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq", # plamo
|
||||
"transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
|
||||
),
|
||||
# Feed-forward norm
|
||||
MODEL_TENSOR.FFN_NORM: (
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm",
|
||||
"transformer.h.{bid}.ln_2",
|
||||
"h.{bid}.post_attention_layernorm",
|
||||
"transformer.blocks.{bid}.norm_2",
|
||||
"model.layers.{bid}.post_attention_layernorm",
|
||||
"layers.{bid}.ffn_norm",
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm",
|
||||
"model.layers.{bid}.ln2",
|
||||
"h.{bid}.ln_2",
|
||||
"model.layers.{bid}.ffn_norm",
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2",
|
||||
"encoder.layers.{bid}.post_attention_layernorm",
|
||||
"transformer.layers.{bid}.ffn_norm",
|
||||
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
|
||||
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
|
||||
"h.{bid}.post_attention_layernorm", # bloom
|
||||
"transformer.blocks.{bid}.norm_2", # mpt
|
||||
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron olmoe
|
||||
"layers.{bid}.ffn_norm", # llama-pth
|
||||
"language_model.encoder.layers.{bid}.post_attention_layernorm", # persimmon
|
||||
"model.layers.{bid}.ln2", # yi
|
||||
"h.{bid}.ln_2", # gpt2
|
||||
"model.layers.{bid}.ffn_norm", # internlm2
|
||||
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
|
||||
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
|
||||
"transformer.layers.{bid}.ffn_norm", # openelm
|
||||
),
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_PRE_NORM: (
|
||||
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
|
||||
),
|
||||
# Post feed-forward norm
|
||||
MODEL_TENSOR.FFN_POST_NORM: (
|
||||
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
|
||||
),
|
||||
MODEL_TENSOR.FFN_PRE_NORM: ("model.layers.{bid}.pre_feedforward_layernorm",),
|
||||
MODEL_TENSOR.FFN_POST_NORM: ("model.layers.{bid}.post_feedforward_layernorm",),
|
||||
MODEL_TENSOR.FFN_GATE_INP: (
|
||||
"layers.{bid}.feed_forward.gate",
|
||||
"model.layers.{bid}.block_sparse_moe.gate",
|
||||
"model.layers.{bid}.mlp.gate",
|
||||
"transformer.decoder_layer.{bid}.router",
|
||||
"transformer.blocks.{bid}.ffn.router.layer",
|
||||
"layers.{bid}.feed_forward.gate", # mixtral
|
||||
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
|
||||
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
|
||||
"transformer.decoder_layer.{bid}.router", # Grok
|
||||
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
|
||||
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
|
||||
),
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ("model.layers.{bid}.mlp.shared_expert_gate",),
|
||||
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert_gate", # qwen2moe
|
||||
),
|
||||
MODEL_TENSOR.FFN_EXP_PROBS_B: (
|
||||
"model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
|
||||
),
|
||||
# Feed-forward up
|
||||
MODEL_TENSOR.FFN_UP: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h",
|
||||
"transformer.h.{bid}.mlp.c_fc",
|
||||
"transformer.blocks.{bid}.ffn.up_proj",
|
||||
"transformer.h.{bid}.mlp.dense_h_to_4h",
|
||||
"h.{bid}.mlp.dense_h_to_4h",
|
||||
"model.layers.{bid}.mlp.up_proj",
|
||||
"layers.{bid}.feed_forward.w3",
|
||||
"encoder.layer.{bid}.intermediate.dense",
|
||||
"transformer.h.{bid}.mlp.fc_in",
|
||||
"transformer.h.{bid}.mlp.linear_3",
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",
|
||||
"model.layers.{bid}.mlp.dense_h_to_4h",
|
||||
"transformer.h.{bid}.mlp.w1",
|
||||
"h.{bid}.mlp.c_fc",
|
||||
"transformer.h.{bid}.mlp.fc1",
|
||||
"model.layers.{bid}.mlp.fc1",
|
||||
"model.layers.{bid}.mlp.gate_up_proj",
|
||||
"model.layers.layers.{bid}.mlp.up_proj",
|
||||
"model.layers.{bid}.feed_forward.w3",
|
||||
"encoder.layers.{bid}.mlp.fc11",
|
||||
"model.layers.{bid}.mlp.c_fc",
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v",
|
||||
"model.layers.{bid}.residual_mlp.w3",
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h",
|
||||
"transformer.h.{bid}.mlp.c_fc_1",
|
||||
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_fc", # gpt2 jais
|
||||
"transformer.blocks.{bid}.ffn.up_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_h_to_4h", # falcon
|
||||
"h.{bid}.mlp.dense_h_to_4h", # bloom
|
||||
"model.layers.{bid}.mlp.up_proj", # llama-hf refact nemotron olmo2
|
||||
"layers.{bid}.feed_forward.w3", # llama-pth
|
||||
"encoder.layer.{bid}.intermediate.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_in", # gpt-j
|
||||
"transformer.h.{bid}.mlp.linear_3", # refact
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||
"model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
|
||||
"transformer.h.{bid}.mlp.w1", # qwen
|
||||
"h.{bid}.mlp.c_fc", # gpt2
|
||||
"transformer.h.{bid}.mlp.fc1", # phi2
|
||||
"model.layers.{bid}.mlp.fc1", # phi2
|
||||
"model.layers.{bid}.mlp.gate_up_proj", # phi3
|
||||
"model.layers.layers.{bid}.mlp.up_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w3", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc11", # nomic-bert
|
||||
"model.layers.{bid}.mlp.c_fc", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
|
||||
"model.layers.{bid}.residual_mlp.w3", # arctic
|
||||
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
|
||||
"transformer.h.{bid}.mlp.c_fc_1", # exaone
|
||||
),
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w3",
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v",
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1",
|
||||
"model.layers.{bid}.mlp.experts.up_proj",
|
||||
"layers.{bid}.feed_forward.experts.w3", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_v", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.v1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.up_proj", # qwen2moe olmoe (merged)
|
||||
),
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj",
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj",
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
|
||||
),
|
||||
MODEL_TENSOR.FFN_ACT: ("transformer.blocks.{bid}.ffn.act",),
|
||||
# AWQ-activation gate
|
||||
MODEL_TENSOR.FFN_ACT: ("transformer.blocks.{bid}.ffn.act",), # mpt
|
||||
# Feed-forward gate
|
||||
MODEL_TENSOR.FFN_GATE: (
|
||||
"model.layers.{bid}.mlp.gate_proj",
|
||||
"layers.{bid}.feed_forward.w1",
|
||||
"transformer.h.{bid}.mlp.w2",
|
||||
"transformer.h.{bid}.mlp.c_fc2",
|
||||
"model.layers.layers.{bid}.mlp.gate_proj",
|
||||
"model.layers.{bid}.feed_forward.w1",
|
||||
"encoder.layers.{bid}.mlp.fc12",
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w",
|
||||
"transformer.h.{bid}.mlp.linear_1",
|
||||
"model.layers.{bid}.residual_mlp.w1",
|
||||
"transformer.h.{bid}.mlp.c_fc_0",
|
||||
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact olmo2
|
||||
"layers.{bid}.feed_forward.w1", # llama-pth
|
||||
"transformer.h.{bid}.mlp.w2", # qwen
|
||||
"transformer.h.{bid}.mlp.c_fc2", # jais
|
||||
"model.layers.layers.{bid}.mlp.gate_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w1", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc12", # nomic-bert
|
||||
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
|
||||
"transformer.h.{bid}.mlp.linear_1", # refact
|
||||
"model.layers.{bid}.residual_mlp.w1", # arctic
|
||||
"transformer.h.{bid}.mlp.c_fc_0", # exaone
|
||||
),
|
||||
MODEL_TENSOR.FFN_GATE_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w1",
|
||||
"transformer.decoder_layer.{bid}.moe.linear",
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1",
|
||||
"model.layers.{bid}.mlp.experts.gate_proj",
|
||||
"layers.{bid}.feed_forward.experts.w1", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w1", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged)
|
||||
),
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj",
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj",
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek deepseek2
|
||||
),
|
||||
# Feed-forward down
|
||||
MODEL_TENSOR.FFN_DOWN: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h",
|
||||
"transformer.h.{bid}.mlp.c_proj",
|
||||
"transformer.blocks.{bid}.ffn.down_proj",
|
||||
"transformer.h.{bid}.mlp.dense_4h_to_h",
|
||||
"h.{bid}.mlp.dense_4h_to_h",
|
||||
"model.layers.{bid}.mlp.down_proj",
|
||||
"layers.{bid}.feed_forward.w2",
|
||||
"encoder.layer.{bid}.output.dense",
|
||||
"transformer.h.{bid}.mlp.fc_out",
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",
|
||||
"model.layers.{bid}.mlp.dense_4h_to_h",
|
||||
"h.{bid}.mlp.c_proj",
|
||||
"transformer.h.{bid}.mlp.fc2",
|
||||
"model.layers.{bid}.mlp.fc2",
|
||||
"model.layers.layers.{bid}.mlp.down_proj",
|
||||
"model.layers.{bid}.feed_forward.w2",
|
||||
"encoder.layers.{bid}.mlp.fc2",
|
||||
"model.layers.{bid}.mlp.c_proj",
|
||||
"encoder.layer.{bid}.mlp.wo",
|
||||
"transformer.layers.{bid}.ffn.proj_2",
|
||||
"model.layers.{bid}.residual_mlp.w2",
|
||||
"encoder.layer.{bid}.mlp.down_layer",
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h",
|
||||
"model.layers.h.{bid}.mlp.c_proj",
|
||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
||||
"transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen jais
|
||||
"transformer.blocks.{bid}.ffn.down_proj", # mpt
|
||||
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
|
||||
"h.{bid}.mlp.dense_4h_to_h", # bloom
|
||||
"model.layers.{bid}.mlp.down_proj", # llama-hf nemotron olmo2
|
||||
"layers.{bid}.feed_forward.w2", # llama-pth
|
||||
"encoder.layer.{bid}.output.dense", # bert
|
||||
"transformer.h.{bid}.mlp.fc_out", # gpt-j
|
||||
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
|
||||
"h.{bid}.mlp.c_proj", # gpt2
|
||||
"transformer.h.{bid}.mlp.fc2", # phi2
|
||||
"model.layers.{bid}.mlp.fc2", # phi2
|
||||
"model.layers.layers.{bid}.mlp.down_proj", # plamo
|
||||
"model.layers.{bid}.feed_forward.w2", # internlm2
|
||||
"encoder.layers.{bid}.mlp.fc2", # nomic-bert
|
||||
"model.layers.{bid}.mlp.c_proj", # starcoder2
|
||||
"encoder.layer.{bid}.mlp.wo", # jina-bert-v2
|
||||
"transformer.layers.{bid}.ffn.proj_2", # openelm
|
||||
"model.layers.{bid}.residual_mlp.w2", # arctic
|
||||
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
|
||||
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
|
||||
"model.layers.h.{bid}.mlp.c_proj", # exaone
|
||||
),
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
"layers.{bid}.feed_forward.experts.w2",
|
||||
"transformer.decoder_layer.{bid}.moe.linear_1",
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2",
|
||||
"model.layers.{bid}.mlp.experts.down_proj",
|
||||
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
|
||||
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
|
||||
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
|
||||
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
|
||||
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
|
||||
),
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj",
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek deepseek2
|
||||
),
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_norm",
|
||||
"transformer.blocks.{bid}.attn.q_ln",
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q",
|
||||
"transformer.layers.{bid}.attn.q_norm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.q_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.q_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.q_norm", # openelm
|
||||
),
|
||||
MODEL_TENSOR.ATTN_K_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm",
|
||||
"model.layers.{bid}.self_attn.k_norm",
|
||||
"transformer.blocks.{bid}.attn.k_ln",
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k",
|
||||
"transformer.layers.{bid}.attn.k_norm",
|
||||
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
|
||||
"model.layers.{bid}.self_attn.k_norm", # cohere olmoe chameleon olmo2
|
||||
"transformer.blocks.{bid}.attn.k_ln", # sea-lion
|
||||
"encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
|
||||
"transformer.layers.{bid}.attn.k_norm", # openelm
|
||||
),
|
||||
MODEL_TENSOR.ROPE_FREQS: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",
|
||||
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
|
||||
),
|
||||
MODEL_TENSOR.LAYER_OUT_NORM: (
|
||||
"encoder.layer.{bid}.output.LayerNorm",
|
||||
"encoder.layers.{bid}.norm2",
|
||||
"transformer.decoder_layer.{bid}.rms_norm_3",
|
||||
"encoder.layer.{bid}.mlp.layernorm",
|
||||
"encoder.layer.{bid}.layer_norm_2",
|
||||
"encoder.layer.{bid}.output.LayerNorm", # bert
|
||||
"encoder.layers.{bid}.norm2", # nomic-bert
|
||||
"transformer.decoder_layer.{bid}.rms_norm_3", # Grok
|
||||
"encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2
|
||||
"encoder.layer.{bid}.layer_norm_2", # jina-v2-code
|
||||
),
|
||||
MODEL_TENSOR.SSM_IN: (
|
||||
"model.layers.{bid}.in_proj",
|
||||
"backbone.layers.{bid}.mixer.in_proj",
|
||||
"model.layers.{bid}.mamba.in_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_CONV1D: (
|
||||
"model.layers.{bid}.conv1d",
|
||||
"backbone.layers.{bid}.mixer.conv1d",
|
||||
"model.layers.{bid}.mamba.conv1d",
|
||||
),
|
||||
MODEL_TENSOR.SSM_X: (
|
||||
"model.layers.{bid}.x_proj",
|
||||
"backbone.layers.{bid}.mixer.x_proj",
|
||||
"model.layers.{bid}.mamba.x_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_DT: (
|
||||
"model.layers.{bid}.dt_proj",
|
||||
"backbone.layers.{bid}.mixer.dt_proj",
|
||||
"model.layers.{bid}.mamba.dt_proj",
|
||||
),
|
||||
MODEL_TENSOR.SSM_DT_NORM: ("model.layers.{bid}.mamba.dt_layernorm",),
|
||||
MODEL_TENSOR.SSM_A: (
|
||||
"model.layers.{bid}.A_log",
|
||||
"backbone.layers.{bid}.mixer.A_log",
|
||||
"model.layers.{bid}.mamba.A_log",
|
||||
),
|
||||
MODEL_TENSOR.SSM_B_NORM: (
|
||||
"model.layers.{bid}.mamba.b_layernorm",
|
||||
"model.layers.{bid}.mamba.B_layernorm",
|
||||
),
|
||||
MODEL_TENSOR.SSM_C_NORM: (
|
||||
"model.layers.{bid}.mamba.c_layernorm",
|
||||
"model.layers.{bid}.mamba.C_layernorm",
|
||||
),
|
||||
MODEL_TENSOR.SSM_D: (
|
||||
"model.layers.{bid}.D",
|
||||
"backbone.layers.{bid}.mixer.D",
|
||||
"model.layers.{bid}.mamba.D",
|
||||
),
|
||||
MODEL_TENSOR.SSM_OUT: (
|
||||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
"model.layers.{bid}.mamba.out_proj",
|
||||
),
|
||||
    MODEL_TENSOR.TIME_MIX_W1: ("rwkv.blocks.{bid}.attention.time_maa_w1",),
    MODEL_TENSOR.TIME_MIX_W2: ("rwkv.blocks.{bid}.attention.time_maa_w2",),
    MODEL_TENSOR.TIME_MIX_LERP_X: ("rwkv.blocks.{bid}.attention.time_maa_x",),
    MODEL_TENSOR.TIME_MIX_LERP_K: ("rwkv.blocks.{bid}.attention.time_maa_k",),
    MODEL_TENSOR.TIME_MIX_LERP_V: ("rwkv.blocks.{bid}.attention.time_maa_v",),
    MODEL_TENSOR.TIME_MIX_LERP_R: ("rwkv.blocks.{bid}.attention.time_maa_r",),
    MODEL_TENSOR.TIME_MIX_LERP_G: ("rwkv.blocks.{bid}.attention.time_maa_g",),
    MODEL_TENSOR.TIME_MIX_LERP_W: ("rwkv.blocks.{bid}.attention.time_maa_w",),
    MODEL_TENSOR.TIME_MIX_FIRST: ("rwkv.blocks.{bid}.attention.time_faaaa",),
    MODEL_TENSOR.TIME_MIX_DECAY: ("rwkv.blocks.{bid}.attention.time_decay",),
    MODEL_TENSOR.TIME_MIX_DECAY_W1: ("rwkv.blocks.{bid}.attention.time_decay_w1",),
    MODEL_TENSOR.TIME_MIX_DECAY_W2: ("rwkv.blocks.{bid}.attention.time_decay_w2",),
    MODEL_TENSOR.TIME_MIX_KEY: ("rwkv.blocks.{bid}.attention.key",),
    MODEL_TENSOR.TIME_MIX_VALUE: ("rwkv.blocks.{bid}.attention.value",),
    MODEL_TENSOR.TIME_MIX_RECEPTANCE: ("rwkv.blocks.{bid}.attention.receptance",),
    MODEL_TENSOR.TIME_MIX_GATE: ("rwkv.blocks.{bid}.attention.gate",),
    MODEL_TENSOR.TIME_MIX_LN: ("rwkv.blocks.{bid}.attention.ln_x",),
    MODEL_TENSOR.TIME_MIX_OUTPUT: ("rwkv.blocks.{bid}.attention.output",),
    MODEL_TENSOR.CHANNEL_MIX_LERP_K: ("rwkv.blocks.{bid}.feed_forward.time_maa_k",),
    MODEL_TENSOR.CHANNEL_MIX_LERP_R: ("rwkv.blocks.{bid}.feed_forward.time_maa_r",),
    MODEL_TENSOR.CHANNEL_MIX_KEY: ("rwkv.blocks.{bid}.feed_forward.key",),
    MODEL_TENSOR.TIME_MIX_W1: (
        "rwkv.blocks.{bid}.attention.time_maa_w1",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_W2: (
        "rwkv.blocks.{bid}.attention.time_maa_w2",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_X: (
        "rwkv.blocks.{bid}.attention.time_maa_x",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_K: (
        "rwkv.blocks.{bid}.attention.time_maa_k",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_V: (
        "rwkv.blocks.{bid}.attention.time_maa_v",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_R: (
        "rwkv.blocks.{bid}.attention.time_maa_r",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_G: (
        "rwkv.blocks.{bid}.attention.time_maa_g",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_LERP_W: (
        "rwkv.blocks.{bid}.attention.time_maa_w",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_FIRST: (
        "rwkv.blocks.{bid}.attention.time_faaaa",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_DECAY: (
        "rwkv.blocks.{bid}.attention.time_decay",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_DECAY_W1: (
        "rwkv.blocks.{bid}.attention.time_decay_w1",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_DECAY_W2: (
        "rwkv.blocks.{bid}.attention.time_decay_w2",  # rwkv v6
    ),
    MODEL_TENSOR.TIME_MIX_KEY: ("rwkv.blocks.{bid}.attention.key",),  # rwkv
    MODEL_TENSOR.TIME_MIX_VALUE: ("rwkv.blocks.{bid}.attention.value",),  # rwkv
    MODEL_TENSOR.TIME_MIX_RECEPTANCE: (
        "rwkv.blocks.{bid}.attention.receptance",  # rwkv
    ),
    MODEL_TENSOR.TIME_MIX_GATE: ("rwkv.blocks.{bid}.attention.gate",),  # rwkv
    MODEL_TENSOR.TIME_MIX_LN: ("rwkv.blocks.{bid}.attention.ln_x",),  # rwkv
    MODEL_TENSOR.TIME_MIX_OUTPUT: ("rwkv.blocks.{bid}.attention.output",),  # rwkv
    MODEL_TENSOR.CHANNEL_MIX_LERP_K: (
        "rwkv.blocks.{bid}.feed_forward.time_maa_k",  # rwkv v6
    ),
    MODEL_TENSOR.CHANNEL_MIX_LERP_R: (
        "rwkv.blocks.{bid}.feed_forward.time_maa_r",  # rwkv v6
    ),
    MODEL_TENSOR.CHANNEL_MIX_KEY: ("rwkv.blocks.{bid}.feed_forward.key",),  # rwkv
    MODEL_TENSOR.CHANNEL_MIX_RECEPTANCE: (
        "rwkv.blocks.{bid}.feed_forward.receptance",
        "rwkv.blocks.{bid}.feed_forward.receptance",  # rwkv
    ),
    MODEL_TENSOR.CHANNEL_MIX_VALUE: ("rwkv.blocks.{bid}.feed_forward.value",),
    MODEL_TENSOR.ATTN_Q_A: ("model.layers.{bid}.self_attn.q_a_proj",),
    MODEL_TENSOR.ATTN_Q_B: ("model.layers.{bid}.self_attn.q_b_proj",),
    MODEL_TENSOR.CHANNEL_MIX_VALUE: (
        "rwkv.blocks.{bid}.feed_forward.value",  # rwkv
    ),
    MODEL_TENSOR.ATTN_Q_A: ("model.layers.{bid}.self_attn.q_a_proj",),  # deepseek2
    MODEL_TENSOR.ATTN_Q_B: ("model.layers.{bid}.self_attn.q_b_proj",),  # deepseek2
    MODEL_TENSOR.ATTN_KV_A_MQA: (
        "model.layers.{bid}.self_attn.kv_a_proj_with_mqa",
        "model.layers.{bid}.self_attn.kv_a_proj_with_mqa",  # deepseek2
    ),
    MODEL_TENSOR.ATTN_KV_B: (
        "model.layers.{bid}.self_attn.kv_b_proj",  # deepseek2
    ),
    MODEL_TENSOR.ATTN_Q_A_NORM: (
        "model.layers.{bid}.self_attn.q_a_layernorm",  # deepseek2
    ),
    MODEL_TENSOR.ATTN_KV_A_NORM: (
        "model.layers.{bid}.self_attn.kv_a_layernorm",  # deepseek2
    ),
    MODEL_TENSOR.ATTN_SUB_NORM: (
        "model.layers.{bid}.self_attn.inner_attn_ln",  # bitnet
    ),
    MODEL_TENSOR.FFN_SUB_NORM: ("model.layers.{bid}.mlp.ffn_layernorm",),  # bitnet
    MODEL_TENSOR.DEC_ATTN_NORM: ("decoder.block.{bid}.layer.0.layer_norm",),  # t5
    MODEL_TENSOR.DEC_ATTN_Q: ("decoder.block.{bid}.layer.0.SelfAttention.q",),  # t5
    MODEL_TENSOR.DEC_ATTN_K: ("decoder.block.{bid}.layer.0.SelfAttention.k",),  # t5
    MODEL_TENSOR.DEC_ATTN_V: ("decoder.block.{bid}.layer.0.SelfAttention.v",),  # t5
    MODEL_TENSOR.DEC_ATTN_OUT: (
        "decoder.block.{bid}.layer.0.SelfAttention.o",  # t5
    ),
    MODEL_TENSOR.ATTN_KV_B: ("model.layers.{bid}.self_attn.kv_b_proj",),
    MODEL_TENSOR.ATTN_Q_A_NORM: ("model.layers.{bid}.self_attn.q_a_layernorm",),
    MODEL_TENSOR.ATTN_KV_A_NORM: ("model.layers.{bid}.self_attn.kv_a_layernorm",),
    MODEL_TENSOR.ATTN_SUB_NORM: ("model.layers.{bid}.self_attn.inner_attn_ln",),
    MODEL_TENSOR.FFN_SUB_NORM: ("model.layers.{bid}.mlp.ffn_layernorm",),
    MODEL_TENSOR.DEC_ATTN_NORM: ("decoder.block.{bid}.layer.0.layer_norm",),
    MODEL_TENSOR.DEC_ATTN_Q: ("decoder.block.{bid}.layer.0.SelfAttention.q",),
    MODEL_TENSOR.DEC_ATTN_K: ("decoder.block.{bid}.layer.0.SelfAttention.k",),
    MODEL_TENSOR.DEC_ATTN_V: ("decoder.block.{bid}.layer.0.SelfAttention.v",),
    MODEL_TENSOR.DEC_ATTN_OUT: ("decoder.block.{bid}.layer.0.SelfAttention.o",),
    MODEL_TENSOR.DEC_ATTN_REL_B: (
        "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
        "decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
        "decoder.block.{bid}.layer.1.layer_norm",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_NORM: ("decoder.block.{bid}.layer.1.layer_norm",),
    MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
        "decoder.block.{bid}.layer.1.EncDecAttention.q",
        "decoder.block.{bid}.layer.1.EncDecAttention.q",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_K: (
        "decoder.block.{bid}.layer.1.EncDecAttention.k",
        "decoder.block.{bid}.layer.1.EncDecAttention.k",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_V: (
        "decoder.block.{bid}.layer.1.EncDecAttention.v",
        "decoder.block.{bid}.layer.1.EncDecAttention.v",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
        "decoder.block.{bid}.layer.1.EncDecAttention.o",
        "decoder.block.{bid}.layer.1.EncDecAttention.o",  # t5
    ),
    MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
        "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias",
        "decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias",  # t5
    ),
    MODEL_TENSOR.DEC_FFN_NORM: ("decoder.block.{bid}.layer.2.layer_norm",),  # t5
    MODEL_TENSOR.DEC_FFN_GATE: (
        "decoder.block.{bid}.layer.2.DenseReluDense.wi_0",  # flan-t5
    ),
    MODEL_TENSOR.DEC_FFN_NORM: ("decoder.block.{bid}.layer.2.layer_norm",),
    MODEL_TENSOR.DEC_FFN_GATE: ("decoder.block.{bid}.layer.2.DenseReluDense.wi_0",),
    MODEL_TENSOR.DEC_FFN_UP: (
        "decoder.block.{bid}.layer.2.DenseReluDense.wi",
        "decoder.block.{bid}.layer.2.DenseReluDense.wi_1",
        "decoder.block.{bid}.layer.2.DenseReluDense.wi",  # t5
        "decoder.block.{bid}.layer.2.DenseReluDense.wi_1",  # flan-t5
    ),
    MODEL_TENSOR.DEC_FFN_DOWN: (
        "decoder.block.{bid}.layer.2.DenseReluDense.wo",  # t5
    ),
    MODEL_TENSOR.DEC_OUTPUT_NORM: ("decoder.final_layer_norm",),  # t5
    MODEL_TENSOR.ENC_ATTN_NORM: ("encoder.block.{bid}.layer.0.layer_norm",),  # t5
    MODEL_TENSOR.ENC_ATTN_Q: ("encoder.block.{bid}.layer.0.SelfAttention.q",),  # t5
    MODEL_TENSOR.ENC_ATTN_K: ("encoder.block.{bid}.layer.0.SelfAttention.k",),  # t5
    MODEL_TENSOR.ENC_ATTN_V: ("encoder.block.{bid}.layer.0.SelfAttention.v",),  # t5
    MODEL_TENSOR.ENC_ATTN_OUT: (
        "encoder.block.{bid}.layer.0.SelfAttention.o",  # t5
    ),
    MODEL_TENSOR.DEC_FFN_DOWN: ("decoder.block.{bid}.layer.2.DenseReluDense.wo",),
    MODEL_TENSOR.DEC_OUTPUT_NORM: ("decoder.final_layer_norm",),
    MODEL_TENSOR.ENC_ATTN_NORM: ("encoder.block.{bid}.layer.0.layer_norm",),
    MODEL_TENSOR.ENC_ATTN_Q: ("encoder.block.{bid}.layer.0.SelfAttention.q",),
    MODEL_TENSOR.ENC_ATTN_K: ("encoder.block.{bid}.layer.0.SelfAttention.k",),
    MODEL_TENSOR.ENC_ATTN_V: ("encoder.block.{bid}.layer.0.SelfAttention.v",),
    MODEL_TENSOR.ENC_ATTN_OUT: ("encoder.block.{bid}.layer.0.SelfAttention.o",),
    MODEL_TENSOR.ENC_ATTN_REL_B: (
        "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
        "encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",  # t5
    ),
    MODEL_TENSOR.ENC_FFN_NORM: ("encoder.block.{bid}.layer.1.layer_norm",),  # t5
    MODEL_TENSOR.ENC_FFN_GATE: (
        "encoder.block.{bid}.layer.1.DenseReluDense.wi_0",  # flan-t5
    ),
    MODEL_TENSOR.ENC_FFN_NORM: ("encoder.block.{bid}.layer.1.layer_norm",),
    MODEL_TENSOR.ENC_FFN_GATE: ("encoder.block.{bid}.layer.1.DenseReluDense.wi_0",),
    MODEL_TENSOR.ENC_FFN_UP: (
        "encoder.block.{bid}.layer.1.DenseReluDense.wi",
        "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
        "encoder.block.{bid}.layer.1.DenseReluDense.wi",  # t5
        "encoder.block.{bid}.layer.1.DenseReluDense.wi_1",  # flan-t5
    ),
    MODEL_TENSOR.ENC_FFN_DOWN: (
        "encoder.block.{bid}.layer.1.DenseReluDense.wo",  # t5
    ),
    ############################################################################
    # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
    MODEL_TENSOR.ENC_OUTPUT_NORM: ("encoder.final_layer_norm",),  # t5
    MODEL_TENSOR.CLS: (
        "classifier",  # jina
        "classifier.dense",  # roberta
    ),
    MODEL_TENSOR.CLS_OUT: ("classifier.out_proj",),  # roberta
    ############################################################################
    MODEL_TENSOR.CONVNEXT_DW: ("backbone.convnext.{bid}.dwconv",),  # wavtokenizer
    MODEL_TENSOR.CONVNEXT_NORM: ("backbone.convnext.{bid}.norm",),  # wavtokenizer
    MODEL_TENSOR.CONVNEXT_PW1: ("backbone.convnext.{bid}.pwconv1",),  # wavtokenizer
    MODEL_TENSOR.CONVNEXT_PW2: ("backbone.convnext.{bid}.pwconv2",),  # wavtokenizer
    MODEL_TENSOR.CONVNEXT_GAMMA: ("backbone.convnext.{bid}.gamma",),  # wavtokenizer
    MODEL_TENSOR.POSNET_CONV1: ("backbone.posnet.{bid}.conv1",),  # wavtokenizer
    MODEL_TENSOR.POSNET_CONV2: ("backbone.posnet.{bid}.conv2",),  # wavtokenizer
    MODEL_TENSOR.POSNET_NORM: ("backbone.posnet.{bid}.norm",),  # wavtokenizer
    MODEL_TENSOR.POSNET_NORM1: ("backbone.posnet.{bid}.norm1",),  # wavtokenizer
    MODEL_TENSOR.POSNET_NORM2: ("backbone.posnet.{bid}.norm2",),  # wavtokenizer
    MODEL_TENSOR.POSNET_ATTN_NORM: ("backbone.posnet.{bid}.norm",),  # wavtokenizer
    MODEL_TENSOR.POSNET_ATTN_Q: ("backbone.posnet.{bid}.q",),  # wavtokenizer
    MODEL_TENSOR.POSNET_ATTN_K: ("backbone.posnet.{bid}.k",),  # wavtokenizer
    MODEL_TENSOR.POSNET_ATTN_V: ("backbone.posnet.{bid}.v",),  # wavtokenizer
    MODEL_TENSOR.POSNET_ATTN_OUT: (
        "backbone.posnet.{bid}.proj_out",  # wavtokenizer
    ),
    MODEL_TENSOR.ENC_FFN_DOWN: ("encoder.block.{bid}.layer.1.DenseReluDense.wo",),
    MODEL_TENSOR.ENC_OUTPUT_NORM: ("encoder.final_layer_norm",),
}

# architecture-specific block mappings
arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
    MODEL_ARCH.ARCTIC: {
        MODEL_TENSOR.FFN_NORM: ("model.layers.{bid}.residual_layernorm",),
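For orientation, here is a minimal sketch of how templated entries like these can be resolved; it is not the bundled implementation, and the helper name resolve_candidates is illustrative only. An architecture-specific table such as arch_block_mappings_cfg takes precedence over the generic block_mappings_cfg, and the {bid} placeholder is then filled in with a concrete block index.

# Illustrative sketch only; the real tensor-name mapping class differs.
def resolve_candidates(base, overrides, arch, tensor, bid: int) -> tuple[str, ...]:
    # Architecture-specific entries override the generic table.
    table = {**base, **overrides.get(arch, {})}
    # Substitute the block index into each templated checkpoint name.
    return tuple(name.format(bid=bid) for name in table.get(tensor, ()))

# e.g. resolve_candidates(block_mappings_cfg, arch_block_mappings_cfg,
#                         MODEL_ARCH.ARCTIC, MODEL_TENSOR.FFN_NORM, bid=3)
# -> ("model.layers.3.residual_layernorm",)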
@@ -157,8 +157,36 @@ def _try_load_from_tokenizer_json(self, path: Path) -> bool:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = tokenizer.get("model", {}).get("merges")
                if isinstance(merges, list) and merges and isinstance(merges[0], str):
                    self.merges = merges
                if isinstance(merges, list) and merges:
                    if isinstance(merges[0], str):
                        self.merges = merges
                    elif (
                        isinstance(merges[0], list)
                        and len(merges[0]) == 2
                        and isinstance(merges[0][0], str)
                    ):
                        # New format since transformers 4.45 to support spaces in merges
                        # ref: https://github.com/ggerganov/llama.cpp/issues/9692
                        # TODO: internally store as the new format instead of converting to old
                        if any(" " in s for pair in merges for s in pair):
                            logger.warning(
                                f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}'
                            )
                        self.merges = [
                            " ".join(
                                [
                                    # ensure the spaces are properly encoded
                                    "".join(
                                        chr(ord(c) + 256) if c == " " else c
                                        for c in part
                                    )
                                    for part in pair
                                ]
                            )
                            for pair in merges
                        ]
                    else:
                        raise ValueError("Unknown tokenizer merges format")
                added_tokens = tokenizer.get("added_tokens", {})
            else:
                added_tokens = {}
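As a standalone illustration of the space handling above (not part of the diff; encode_merge_pair is a hypothetical helper), each literal space inside a merge token is shifted up by 256 code points, which yields "Ġ" (U+0120), so the two tokens can still be joined by one real space without ambiguity.

# Sketch of the space-encoding scheme shown in the hunk above.
def encode_merge_pair(pair: list[str]) -> str:
    # Replace literal spaces inside either token with chr(ord(" ") + 256),
    # then join the two tokens with a single real space.
    encoded = [
        "".join(chr(ord(c) + 256) if c == " " else c for c in part)
        for part in pair
    ]
    return " ".join(encoded)

print(encode_merge_pair(["hello", " world"]))  # -> 'hello Ġworld'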
@@ -225,7 +253,6 @@ class Vocab(BaseVocab, Protocol):
    fname_tokenizer: Path

    def __init__(self, base_path: Path): ...

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...

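For context, a minimal class satisfying the protocol shape shown above might look like the following sketch; it is illustrative only, the real vocab classes carry more state, and the file name and token data here are hypothetical.

# Illustrative sketch of a class matching the Vocab protocol members above.
from pathlib import Path
from typing import Iterable

import gguf


class DummyVocab:
    fname_tokenizer: Path

    def __init__(self, base_path: Path):
        self.fname_tokenizer = base_path / "tokenizer.json"  # hypothetical path
        self._tokens = [b"<unk>", b"hello"]

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        for i, tok in enumerate(self._tokens):
            yield tok, float(i), gguf.TokenType.NORMAL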
@@ -80,11 +80,15 @@ def load_dotenv(self=Any) -> None:


    def show_about(self) -> None:
        about_text = (
            "AutoGGUF\n\n"
            f"Version: {AUTOGGUF_VERSION}\n\n"
            "A tool for managing and converting GGUF models."
        )
        about_text = f"""AutoGGUF

Version: {AUTOGGUF_VERSION}

A tool for managing and converting GGUF models.
This application is licensed under the Apache License 2.0.
Copyright (c) 2025 leafspark.
It also utilizes llama.cpp, licensed under the MIT License.
Copyright (c) 2023-2024 The ggml authors."""
        QMessageBox.about(self, "About AutoGGUF", about_text)