refactor: optimize GGUF imports

- sort imports and drop unused ones in the GGUF conversion utilities
- rename the bundled gguf library directory from src/gguf-py to src/gguf
- update .gitignore and the build workflow for the new layout
BuildTools 2024-09-14 10:11:43 -07:00
parent 3804da0a3f
commit 747aa7b9a8
17 changed files with 605 additions and 762 deletions

@ -61,8 +61,8 @@ jobs:
if: matrix.os == 'windows-latest'
run: |
$distPath = if ("${{ github.event.inputs.build_type }}" -eq "RELEASE") { "build\release\dist" } else { "build\dev\dist" }
New-Item -ItemType Directory -Force -Path "$distPath\src\gguf-py"
Copy-Item -Path "src\gguf-py\*" -Destination "$distPath\src\gguf-py" -Recurse
New-Item -ItemType Directory -Force -Path "$distPath\src\gguf"
Copy-Item -Path "src\gguf\*" -Destination "$distPath\src\gguf" -Recurse
Copy-Item -Path "src\convert_hf_to_gguf.py" -Destination "$distPath\src"
Copy-Item -Path "src\convert_lora_to_gguf.py" -Destination "$distPath\src"
Copy-Item -Path "src\convert_lora_to_ggml.py" -Destination "$distPath\src"
@ -72,8 +72,8 @@ jobs:
if: matrix.os != 'windows-latest'
run: |
distPath=$(if [ "${{ github.event.inputs.build_type }}" = "RELEASE" ]; then echo "build/release/dist"; else echo "build/dev/dist"; fi)
mkdir -p $distPath/src/gguf-py
cp -R src/gguf-py/* $distPath/src/gguf-py/
mkdir -p $distPath/src/gguf
cp -R src/gguf/* $distPath/src/gguf/
cp src/convert_hf_to_gguf.py $distPath/src/
cp src/convert_lora_to_gguf.py $distPath/src/
cp src/convert_lora_to_ggml.py $distPath/src/
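Both packaging branches now mirror the rename: the bundled library lands at <dist>/src/gguf instead of <dist>/src/gguf-py, next to the three conversion scripts. A minimal layout check, sketched below (the build/dev/dist path is the workflow's non-release default; the check itself is illustrative, not part of the workflow):

```python
# Sketch: verify the packaged tree. Paths assume the dev layout above.
import sys
from pathlib import Path

dist_src = Path("build/dev/dist/src")
for required in ("gguf", "convert_hf_to_gguf.py",
                 "convert_lora_to_gguf.py", "convert_lora_to_ggml.py"):
    assert (dist_src / required).exists(), f"missing {required}"

# Putting dist/src first on sys.path makes `import gguf` resolve to the
# bundled copy rather than any pip-installed one.
sys.path.insert(0, str(dist_src))
import gguf  # noqa: E402
```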

.gitignore

@ -20,6 +20,9 @@ __pycache__/
!src/
src/*
!src/*.py
!src/gguf
src/gguf/*
!src/gguf/*.py
# Allow docs folder and its .py files
!docs/
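The three added lines follow git's re-inclusion rule: a file cannot be un-ignored while its parent directory is excluded, so the directory is whitelisted first (!src/gguf), its contents are re-excluded (src/gguf/*), and only the Python sources are allowed back (!src/gguf/*.py). A quick way to test such patterns outside git, sketched with the third-party pathspec package:

```python
# Sketch: pathspec implements gitwildmatch, including ! negations.
import pathspec

patterns = [
    "src/*",
    "!src/*.py",
    "!src/gguf",       # re-include the directory itself...
    "src/gguf/*",      # ...re-exclude its contents...
    "!src/gguf/*.py",  # ...then allow only the Python sources back
]
spec = pathspec.PathSpec.from_lines("gitwildmatch", patterns)

assert not spec.match_file("src/gguf/constants.py")  # tracked
assert spec.match_file("src/gguf/__pycache__")       # still ignored
```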

@ -1,8 +1,8 @@
import importlib
import json
import shutil
import urllib.request
import urllib.error
import urllib.request
from datetime import datetime
from functools import partial, wraps
from typing import Any, Dict, List, Tuple
@ -24,10 +24,10 @@
from error_handling import handle_error, show_error
from imports_and_globals import (
ensure_directory,
load_dotenv,
open_file_safe,
resource_path,
show_about,
load_dotenv,
)
@ -41,21 +41,18 @@ def wrapper(self, *args, **kwargs):
# Length check
if len(value) > 1024:
show_error(f"{field} exceeds maximum length")
show_error(self.logger, f"{field} exceeds maximum length")
# Normalize path
normalized_path = os.path.normpath(value)
# Check for path traversal attempts
if ".." in normalized_path:
show_error(f"Invalid path in {field}")
show_error(self.logger, f"Invalid path in {field}")
# Disallow control characters and null bytes
if re.search(r"[\x00-\x1f\x7f]", value):
show_error(f"Invalid characters in {field}")
# Update the field with normalized path
getattr(self, field).setText(normalized_path)
show_error(self.logger, f"Invalid characters in {field}")
return func(self, *args, **kwargs)
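The wrapper hunks thread self.logger into each show_error call and drop the setText write-back, so validation no longer rewrites the field it checks. The overall shape, as a self-contained sketch (the decorator name, the show_error stub, and the .text() accessor are assumptions, not the project's API):

```python
import os
import re
from functools import wraps

def show_error(logger, message):  # stand-in for the project's helper
    logger.error(message)

def validate_paths(*fields):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            for field in fields:
                value = getattr(self, field).text()
                if len(value) > 1024:
                    show_error(self.logger, f"{field} exceeds maximum length")
                    return None
                normalized_path = os.path.normpath(value)
                if ".." in normalized_path:
                    show_error(self.logger, f"Invalid path in {field}")
                    return None
                if re.search(r"[\x00-\x1f\x7f]", value):
                    show_error(self.logger, f"Invalid characters in {field}")
                    return None
            return func(self, *args, **kwargs)
        return wrapper
    return decorator
```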

@ -30,8 +30,6 @@
if TYPE_CHECKING:
from torch import Tensor
if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
import gguf
logger = logging.getLogger("hf-to-gguf")

@ -1,19 +1,17 @@
from __future__ import annotations
import logging
import json
import logging
import os
import struct
import sys
from pathlib import Path
from typing import Any, BinaryIO, Sequence
from typing import BinaryIO
import numpy as np
import torch
if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
import gguf
from gguf.constants import *
from gguf.tensor_mapping import *
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("lora-to-gguf")
@ -51,11 +49,6 @@ def write_tensor_header(
fout.seek((fout.tell() + 31) & -32)
def pyinstaller_include():
# PyInstaller import
pass
if __name__ == "__main__":
if len(sys.argv) < 2:
logger.info(f"Usage: python {sys.argv[0]} <path> <output_path> [arch]")
@ -63,7 +56,7 @@ def pyinstaller_include():
"Path must contain HuggingFace PEFT LoRA files 'adapter_config.json' and 'adapter_model.bin'"
)
logger.info(
f"Arch must be one of {list(gguf.MODEL_ARCH_NAMES.values())} (default: llama)"
f"Arch must be one of {list(MODEL_ARCH_NAMES.values())} (default: llama)"
)
sys.exit(1)
@ -82,14 +75,14 @@ def pyinstaller_include():
arch_name = sys.argv[3] if len(sys.argv) == 4 else "llama"
if arch_name not in gguf.MODEL_ARCH_NAMES.values():
if arch_name not in MODEL_ARCH_NAMES.values():
logger.error(f"Error: unsupported architecture {arch_name}")
sys.exit(1)
arch = list(gguf.MODEL_ARCH_NAMES.keys())[
list(gguf.MODEL_ARCH_NAMES.values()).index(arch_name)
arch = list(MODEL_ARCH_NAMES.keys())[
list(MODEL_ARCH_NAMES.values()).index(arch_name)
]
name_map = gguf.TensorNameMap(arch, 500)
name_map = TensorNameMap(arch, 500)
with open(input_json, "r") as f:
params = json.load(f)
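The star imports drop the gguf. prefix, but the arch lookup still scans MODEL_ARCH_NAMES.values() twice. Since MODEL_ARCH_NAMES maps each MODEL_ARCH to its string name, inverting the dict once does the same job (a sketch, not part of this commit):

```python
from gguf.constants import MODEL_ARCH_NAMES

arch_name = "llama"  # e.g. from sys.argv
# Invert MODEL_ARCH -> name into name -> MODEL_ARCH for a direct lookup.
arch_by_name = {name: arch for arch, name in MODEL_ARCH_NAMES.items()}
arch = arch_by_name.get(arch_name)
if arch is None:
    raise SystemExit(f"Error: unsupported architecture {arch_name}")
```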

@ -24,9 +24,7 @@
if TYPE_CHECKING:
from torch import Tensor
if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
import gguf
from gguf.constants import *
from convert_hf_to_gguf import LazyTorchTensor, Model

@ -1,15 +0,0 @@
# This file left for compatibility. If you want to use the GGUF API from Python
# then don't import gguf/gguf.py directly. If you're looking for examples, see the
# examples/ directory for gguf-py
import importlib
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
# Compatibility for people trying to import gguf/gguf.py directly instead of as a package.
importlib.invalidate_caches()
import gguf # noqa: E402
importlib.reload(gguf)

@ -8,6 +8,7 @@
GGUF_DEFAULT_ALIGNMENT = 32
GGML_QUANT_VERSION = 2
class Keys:
class General:
TYPE = "general.type"
@ -150,10 +151,12 @@ class Adapter:
TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha"
class GGUFType:
MODEL = "model"
ADAPTER = "adapter"
class MODEL_ARCH(IntEnum):
LLAMA = auto()
FALCON = auto()
@ -199,6 +202,7 @@ class MODEL_ARCH(IntEnum):
NEMOTRON = auto()
EXAONE = auto()
class MODEL_TENSOR(IntEnum):
TOKEN_EMBD = auto()
TOKEN_EMBD_NORM = auto()
@ -282,6 +286,7 @@ class MODEL_TENSOR(IntEnum):
ENC_FFN_UP = auto()
ENC_OUTPUT_NORM = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.LLAMA: "llama",
MODEL_ARCH.FALCON: "falcon",
@ -1075,7 +1080,6 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
}
MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -1120,6 +1124,7 @@ class MODEL_TENSOR(IntEnum):
],
}
class TokenType(IntEnum):
NORMAL = 1
UNKNOWN = 2
@ -1128,16 +1133,19 @@ class TokenType(IntEnum):
UNUSED = 5
BYTE = 6
class RopeScalingType(Enum):
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
NONE = "none"
LINEAR = "linear"
YARN = "yarn"
class PoolingType(IntEnum):
NONE = 0
MEAN = 1
CLS = 2
class GGMLQuantizationType(IntEnum):
F32 = 0
F16 = 1
@ -1172,6 +1180,7 @@ class GGMLQuantizationType(IntEnum):
Q4_0_4_8 = 32
Q4_0_8_8 = 33
class LlamaFileType(IntEnum):
ALL_F32 = 0
MOSTLY_F16 = 1
@ -1210,10 +1219,12 @@ class LlamaFileType(IntEnum):
GUESSED = 1024
class GGUFEndian(IntEnum):
LITTLE = 0
BIG = 1
class GGUFValueType(IntEnum):
UINT8 = 0
INT8 = 1
@ -1245,6 +1256,7 @@ def get_type(val: Any) -> GGUFValueType:
else:
raise ValueError(f"Unknown type: {type(val)}")
QK_K = 256
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F32: (1, 4),
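GGML_QUANT_SIZES pairs each quantization type with (block_size, type_size), which is what size math elsewhere in the library builds on: a tensor occupies n_elements // block_size * type_size bytes. For example, with the F32 entry shown above (a sketch):

```python
from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.F32]  # (1, 4)
n_elements = 4096 * 4096
print(n_elements // block_size * type_size)  # 67108864 bytes, 4 per float
```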

@ -441,9 +441,9 @@ def apply_metadata_heuristic(
org_component is not None
and model_full_name_component is not None
):
base_model[
"repo_url"
] = f"https://huggingface.co/{org_component}/{model_full_name_component}"
base_model["repo_url"] = (
f"https://huggingface.co/{org_component}/{model_full_name_component}"
)
metadata.base_models.append(base_model)
if "license" in model_card and metadata.license is None:

@ -4,9 +4,9 @@
from .constants import MODEL_ARCH, MODEL_TENSOR, MODEL_TENSORS, TENSOR_NAMES
class TensorNameMap:
mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in",
"transformer.wte",
@ -27,24 +27,18 @@ class TensorNameMap:
"transformer.token_embeddings",
"shared",
),
MODEL_TENSOR.TOKEN_TYPES: (
"embeddings.token_type_embeddings",
),
MODEL_TENSOR.TOKEN_TYPES: ("embeddings.token_type_embeddings",),
MODEL_TENSOR.TOKEN_EMBD_NORM: (
"word_embeddings_layernorm",
"embeddings.LayerNorm",
"emb_ln",
"transformer.norm",
),
MODEL_TENSOR.POS_EMBD: (
"transformer.wpe",
"embeddings.position_embeddings",
"wpe",
),
MODEL_TENSOR.OUTPUT: (
"embed_out",
"lm_head",
@ -53,7 +47,6 @@ class TensorNameMap:
"lm_head.linear",
"output_layer",
),
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm",
"transformer.ln_f",
@ -71,7 +64,6 @@ class TensorNameMap:
"transformer.norm",
"model.norm",
),
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs",
"rotary_pos_emb.inv_freq",
@ -79,7 +71,6 @@ class TensorNameMap:
}
block_mappings_cfg: dict[MODEL_TENSOR, tuple[str, ...]] = {
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm",
"transformer.h.{bid}.ln_1",
@ -102,12 +93,10 @@ class TensorNameMap:
"encoder.layers.{bid}.input_layernorm",
"transformer.layers.{bid}.attn_norm",
),
MODEL_TENSOR.ATTN_NORM_2: (
"transformer.h.{bid}.ln_attn",
"encoder.layer.{bid}.layer_norm_1",
),
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value",
"transformer.h.{bid}.attn.c_attn",
@ -124,7 +113,6 @@ class TensorNameMap:
"encoder.layers.{bid}.self_attention.query_key_value",
"transformer.layers.{bid}.attn.qkv_proj",
),
MODEL_TENSOR.ATTN_Q: (
"model.layers.{bid}.self_attn.q_proj",
"layers.{bid}.attention.wq",
@ -135,7 +123,6 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.multi_head_attention.query",
"transformer.h.{bid}.attn.attention.q_proj",
),
MODEL_TENSOR.ATTN_K: (
"model.layers.{bid}.self_attn.k_proj",
"layers.{bid}.attention.wk",
@ -147,7 +134,6 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.multi_head_attention.key",
"transformer.h.{bid}.attn.attention.k_proj",
),
MODEL_TENSOR.ATTN_V: (
"model.layers.{bid}.self_attn.v_proj",
"layers.{bid}.attention.wv",
@ -159,7 +145,6 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.multi_head_attention.value",
"transformer.h.{bid}.attn.attention.v_proj",
),
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense",
"transformer.h.{bid}.attn.c_proj",
@ -183,25 +168,19 @@ class TensorNameMap:
"transformer.layers.{bid}.attn.out_proj",
"transformer.h.{bid}.attn.attention.out_proj",
),
MODEL_TENSOR.ATTN_OUT_NORM: (
"encoder.layer.{bid}.attention.output.LayerNorm",
"encoder.layers.{bid}.norm1",
"transformer.decoder_layer.{bid}.rms_norm_1",
"transformer.blocks.{bid}.norm_attn_norm.norm_2",
),
MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm",
),
MODEL_TENSOR.ATTN_POST_NORM: ("model.layers.{bid}.post_attention_layernorm",),
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq",
"layers.{bid}.attention.inner_attention.rope.freqs",
"model.layers.layers.{bid}.self_attn.rotary_emb.inv_freq",
"transformer.h.{bid}.attn.rotary_emb.inv_freq",
),
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm",
"transformer.h.{bid}.ln_2",
@ -217,15 +196,8 @@ class TensorNameMap:
"encoder.layers.{bid}.post_attention_layernorm",
"transformer.layers.{bid}.ffn_norm",
),
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm",
),
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm",
),
MODEL_TENSOR.FFN_PRE_NORM: ("model.layers.{bid}.pre_feedforward_layernorm",),
MODEL_TENSOR.FFN_POST_NORM: ("model.layers.{bid}.post_feedforward_layernorm",),
MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate",
"model.layers.{bid}.block_sparse_moe.gate",
@ -233,11 +205,7 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.router",
"transformer.blocks.{bid}.ffn.router.layer",
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert_gate",
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: ("model.layers.{bid}.mlp.shared_expert_gate",),
MODEL_TENSOR.FFN_UP: (
"gpt_neox.layers.{bid}.mlp.dense_h_to_4h",
"transformer.h.{bid}.mlp.c_fc",
@ -265,23 +233,17 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.dense_h_to_4h",
"transformer.h.{bid}.mlp.c_fc_1",
),
MODEL_TENSOR.FFN_UP_EXP: (
"layers.{bid}.feed_forward.experts.w3",
"transformer.decoder_layer.{bid}.moe.linear_v",
"transformer.blocks.{bid}.ffn.experts.mlp.v1",
"model.layers.{bid}.mlp.experts.up_proj",
),
MODEL_TENSOR.FFN_UP_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.up_proj",
"model.layers.{bid}.mlp.shared_experts.up_proj",
),
MODEL_TENSOR.FFN_ACT: (
"transformer.blocks.{bid}.ffn.act",
),
MODEL_TENSOR.FFN_ACT: ("transformer.blocks.{bid}.ffn.act",),
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj",
"layers.{bid}.feed_forward.w1",
@ -295,19 +257,16 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w1",
"transformer.h.{bid}.mlp.c_fc_0",
),
MODEL_TENSOR.FFN_GATE_EXP: (
"layers.{bid}.feed_forward.experts.w1",
"transformer.decoder_layer.{bid}.moe.linear",
"transformer.blocks.{bid}.ffn.experts.mlp.w1",
"model.layers.{bid}.mlp.experts.gate_proj",
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.gate_proj",
"model.layers.{bid}.mlp.shared_experts.gate_proj",
),
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h",
"transformer.h.{bid}.mlp.c_proj",
@ -334,19 +293,16 @@ class TensorNameMap:
"encoder.layers.{bid}.mlp.dense_4h_to_h",
"model.layers.h.{bid}.mlp.c_proj",
),
MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2",
"transformer.decoder_layer.{bid}.moe.linear_1",
"transformer.blocks.{bid}.ffn.experts.mlp.w2",
"model.layers.{bid}.mlp.experts.down_proj",
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
"model.layers.{bid}.mlp.shared_expert.down_proj",
"model.layers.{bid}.mlp.shared_experts.down_proj",
),
MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm",
@ -355,7 +311,6 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_q",
"transformer.layers.{bid}.attn.q_norm",
),
MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm",
@ -364,209 +319,108 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.self.layer_norm_k",
"transformer.layers.{bid}.attn.k_norm",
),
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",
),
MODEL_TENSOR.LAYER_OUT_NORM: (
"encoder.layer.{bid}.output.LayerNorm",
"encoder.layers.{bid}.norm2",
"transformer.decoder_layer.{bid}.rms_norm_3",
"encoder.layer.{bid}.mlp.layernorm",
"encoder.layer.{bid}.layer_norm_2"
"encoder.layer.{bid}.layer_norm_2",
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
),
MODEL_TENSOR.SSM_X: (
"model.layers.{bid}.x_proj",
"backbone.layers.{bid}.mixer.x_proj",
),
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
),
MODEL_TENSOR.ATTN_Q_A: (
"model.layers.{bid}.self_attn.q_a_proj",
),
MODEL_TENSOR.ATTN_Q_B: (
"model.layers.{bid}.self_attn.q_b_proj",
),
MODEL_TENSOR.ATTN_Q_A: ("model.layers.{bid}.self_attn.q_a_proj",),
MODEL_TENSOR.ATTN_Q_B: ("model.layers.{bid}.self_attn.q_b_proj",),
MODEL_TENSOR.ATTN_KV_A_MQA: (
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa",
),
MODEL_TENSOR.ATTN_KV_B: (
"model.layers.{bid}.self_attn.kv_b_proj",
),
MODEL_TENSOR.ATTN_Q_A_NORM: (
"model.layers.{bid}.self_attn.q_a_layernorm",
),
MODEL_TENSOR.ATTN_KV_A_NORM: (
"model.layers.{bid}.self_attn.kv_a_layernorm",
),
MODEL_TENSOR.ATTN_SUB_NORM: (
"model.layers.{bid}.self_attn.inner_attn_ln",
),
MODEL_TENSOR.FFN_SUB_NORM: (
"model.layers.{bid}.mlp.ffn_layernorm",
),
MODEL_TENSOR.DEC_ATTN_NORM: (
"decoder.block.{bid}.layer.0.layer_norm",
),
MODEL_TENSOR.DEC_ATTN_Q: (
"decoder.block.{bid}.layer.0.SelfAttention.q",
),
MODEL_TENSOR.DEC_ATTN_K: (
"decoder.block.{bid}.layer.0.SelfAttention.k",
),
MODEL_TENSOR.DEC_ATTN_V: (
"decoder.block.{bid}.layer.0.SelfAttention.v",
),
MODEL_TENSOR.DEC_ATTN_OUT: (
"decoder.block.{bid}.layer.0.SelfAttention.o",
),
MODEL_TENSOR.ATTN_KV_B: ("model.layers.{bid}.self_attn.kv_b_proj",),
MODEL_TENSOR.ATTN_Q_A_NORM: ("model.layers.{bid}.self_attn.q_a_layernorm",),
MODEL_TENSOR.ATTN_KV_A_NORM: ("model.layers.{bid}.self_attn.kv_a_layernorm",),
MODEL_TENSOR.ATTN_SUB_NORM: ("model.layers.{bid}.self_attn.inner_attn_ln",),
MODEL_TENSOR.FFN_SUB_NORM: ("model.layers.{bid}.mlp.ffn_layernorm",),
MODEL_TENSOR.DEC_ATTN_NORM: ("decoder.block.{bid}.layer.0.layer_norm",),
MODEL_TENSOR.DEC_ATTN_Q: ("decoder.block.{bid}.layer.0.SelfAttention.q",),
MODEL_TENSOR.DEC_ATTN_K: ("decoder.block.{bid}.layer.0.SelfAttention.k",),
MODEL_TENSOR.DEC_ATTN_V: ("decoder.block.{bid}.layer.0.SelfAttention.v",),
MODEL_TENSOR.DEC_ATTN_OUT: ("decoder.block.{bid}.layer.0.SelfAttention.o",),
MODEL_TENSOR.DEC_ATTN_REL_B: (
"decoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
),
MODEL_TENSOR.DEC_CROSS_ATTN_NORM: (
"decoder.block.{bid}.layer.1.layer_norm",
),
MODEL_TENSOR.DEC_CROSS_ATTN_NORM: ("decoder.block.{bid}.layer.1.layer_norm",),
MODEL_TENSOR.DEC_CROSS_ATTN_Q: (
"decoder.block.{bid}.layer.1.EncDecAttention.q",
),
MODEL_TENSOR.DEC_CROSS_ATTN_K: (
"decoder.block.{bid}.layer.1.EncDecAttention.k",
),
MODEL_TENSOR.DEC_CROSS_ATTN_V: (
"decoder.block.{bid}.layer.1.EncDecAttention.v",
),
MODEL_TENSOR.DEC_CROSS_ATTN_OUT: (
"decoder.block.{bid}.layer.1.EncDecAttention.o",
),
MODEL_TENSOR.DEC_CROSS_ATTN_REL_B: (
"decoder.block.{bid}.layer.1.EncDecAttention.relative_attention_bias",
),
MODEL_TENSOR.DEC_FFN_NORM: (
"decoder.block.{bid}.layer.2.layer_norm",
),
MODEL_TENSOR.DEC_FFN_GATE: (
"decoder.block.{bid}.layer.2.DenseReluDense.wi_0",
),
MODEL_TENSOR.DEC_FFN_NORM: ("decoder.block.{bid}.layer.2.layer_norm",),
MODEL_TENSOR.DEC_FFN_GATE: ("decoder.block.{bid}.layer.2.DenseReluDense.wi_0",),
MODEL_TENSOR.DEC_FFN_UP: (
"decoder.block.{bid}.layer.2.DenseReluDense.wi",
"decoder.block.{bid}.layer.2.DenseReluDense.wi_1",
),
MODEL_TENSOR.DEC_FFN_DOWN: (
"decoder.block.{bid}.layer.2.DenseReluDense.wo",
),
MODEL_TENSOR.DEC_OUTPUT_NORM: (
"decoder.final_layer_norm",
),
MODEL_TENSOR.ENC_ATTN_NORM: (
"encoder.block.{bid}.layer.0.layer_norm",
),
MODEL_TENSOR.ENC_ATTN_Q: (
"encoder.block.{bid}.layer.0.SelfAttention.q",
),
MODEL_TENSOR.ENC_ATTN_K: (
"encoder.block.{bid}.layer.0.SelfAttention.k",
),
MODEL_TENSOR.ENC_ATTN_V: (
"encoder.block.{bid}.layer.0.SelfAttention.v",
),
MODEL_TENSOR.ENC_ATTN_OUT: (
"encoder.block.{bid}.layer.0.SelfAttention.o",
),
MODEL_TENSOR.DEC_FFN_DOWN: ("decoder.block.{bid}.layer.2.DenseReluDense.wo",),
MODEL_TENSOR.DEC_OUTPUT_NORM: ("decoder.final_layer_norm",),
MODEL_TENSOR.ENC_ATTN_NORM: ("encoder.block.{bid}.layer.0.layer_norm",),
MODEL_TENSOR.ENC_ATTN_Q: ("encoder.block.{bid}.layer.0.SelfAttention.q",),
MODEL_TENSOR.ENC_ATTN_K: ("encoder.block.{bid}.layer.0.SelfAttention.k",),
MODEL_TENSOR.ENC_ATTN_V: ("encoder.block.{bid}.layer.0.SelfAttention.v",),
MODEL_TENSOR.ENC_ATTN_OUT: ("encoder.block.{bid}.layer.0.SelfAttention.o",),
MODEL_TENSOR.ENC_ATTN_REL_B: (
"encoder.block.{bid}.layer.0.SelfAttention.relative_attention_bias",
),
MODEL_TENSOR.ENC_FFN_NORM: (
"encoder.block.{bid}.layer.1.layer_norm",
),
MODEL_TENSOR.ENC_FFN_GATE: (
"encoder.block.{bid}.layer.1.DenseReluDense.wi_0",
),
MODEL_TENSOR.ENC_FFN_NORM: ("encoder.block.{bid}.layer.1.layer_norm",),
MODEL_TENSOR.ENC_FFN_GATE: ("encoder.block.{bid}.layer.1.DenseReluDense.wi_0",),
MODEL_TENSOR.ENC_FFN_UP: (
"encoder.block.{bid}.layer.1.DenseReluDense.wi",
"encoder.block.{bid}.layer.1.DenseReluDense.wi_1",
),
MODEL_TENSOR.ENC_FFN_DOWN: (
"encoder.block.{bid}.layer.1.DenseReluDense.wo",
),
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm",
),
MODEL_TENSOR.ENC_FFN_DOWN: ("encoder.block.{bid}.layer.1.DenseReluDense.wo",),
MODEL_TENSOR.ENC_OUTPUT_NORM: ("encoder.final_layer_norm",),
}
arch_block_mappings_cfg: dict[MODEL_ARCH, dict[MODEL_TENSOR, tuple[str, ...]]] = {
MODEL_ARCH.ARCTIC: {
MODEL_TENSOR.FFN_NORM: (
"model.layers.{bid}.residual_layernorm",
),
MODEL_TENSOR.FFN_NORM_EXP: (
"model.layers.{bid}.post_attention_layernorm",
),
MODEL_TENSOR.FFN_NORM: ("model.layers.{bid}.residual_layernorm",),
MODEL_TENSOR.FFN_NORM_EXP: ("model.layers.{bid}.post_attention_layernorm",),
},
}
@ -594,7 +448,9 @@ def __init__(self, arch: MODEL_ARCH, n_blocks: int):
key = key.format(bid=bid)
self.mapping[key] = (tensor, tensor_name)
def get_type_and_name(self, key: str, try_suffixes: Sequence[str] = ()) -> tuple[MODEL_TENSOR, str] | None:
def get_type_and_name(
self, key: str, try_suffixes: Sequence[str] = ()
) -> tuple[MODEL_TENSOR, str] | None:
result = self.mapping.get(key)
if result is not None:
return result
@ -611,7 +467,9 @@ def get_name(self, key: str, try_suffixes: Sequence[str] = ()) -> str | None:
return None
return result[1]
def get_type(self, key: str, try_suffixes: Sequence[str] = ()) -> MODEL_TENSOR | None:
def get_type(
self, key: str, try_suffixes: Sequence[str] = ()
) -> MODEL_TENSOR | None:
result = self.get_type_and_name(key, try_suffixes=try_suffixes)
if result is None:
return None
@ -629,5 +487,6 @@ def __contains__(self, key: str) -> bool:
def __repr__(self) -> str:
return repr(self.mapping)
def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> TensorNameMap:
return TensorNameMap(arch, n_blocks)
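Collapsing the one-element tuples is again formatting-only; the mapping behavior is unchanged. For reference, the lookup flow (a sketch; the tensor name and expected result are illustrative of the llama arch):

```python
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

# {bid} placeholders are expanded for every block id at construction time;
# try_suffixes lets callers match ".weight"/".bias" variants of a base name.
name_map = get_tensor_name_map(MODEL_ARCH.LLAMA, 32)
entry = name_map.get_type_and_name(
    "model.layers.0.self_attn.q_proj.weight", try_suffixes=(".weight", ".bias")
)
# expected: (MODEL_TENSOR.ATTN_Q, "blk.0.attn_q.weight")
```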

@ -224,11 +224,9 @@ class Vocab(BaseVocab, Protocol):
added_tokens_list: list[str]
fname_tokenizer: Path
def __init__(self, base_path: Path):
...
def __init__(self, base_path: Path): ...
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
...
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]: ...
class NoVocab(BaseVocab):
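The last hunk is the same formatter collapse applied to Protocol stubs: a body of `...` moves onto the def line. Any class with matching attributes and signatures satisfies the protocol structurally; a minimal sketch with placeholder token data:

```python
from pathlib import Path
from typing import Iterable

import gguf

class DummyVocab:
    added_tokens_list: list[str] = []

    def __init__(self, base_path: Path):
        self.fname_tokenizer = base_path / "tokenizer.model"

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        yield (b"<s>", 0.0, gguf.TokenType.NORMAL)
```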