mirror of https://github.com/leafspark/AutoGGUF
add assets fix, code format, progress bar fix
This commit is contained in: parent 3e8d7b1415 · commit 10911b71a0

@@ -12,6 +12,7 @@
 import platform
 import requests
 import zipfile
+import re
 from datetime import datetime
 from imports_and_globals import ensure_directory, open_file_safe
 from DownloadThread import DownloadThread
@@ -144,10 +145,6 @@ def __init__(self):
         self.quant_type = QComboBox()
         self.quant_type.addItems(
             [
-                "Q4_0",
-                "Q4_1",
-                "Q5_0",
-                "Q5_1",
                 "IQ2_XXS",
                 "IQ2_XS",
                 "IQ2_S",
@@ -173,12 +170,16 @@ def __init__(self):
                 "Q5_K_S",
                 "Q5_K_M",
                 "Q6_K",
-                "Q8_0",
+                "Q8_0",
+                "Q4_0",
+                "Q4_1",
+                "Q5_0",
+                "Q5_1",
                 "Q4_0_4_4",
                 "Q4_0_4_8",
                 "Q4_0_8_8",
-                "F16",
                 "BF16",
+                "F16",
                 "F32",
                 "COPY",
             ]
@@ -796,7 +797,7 @@ def export_lora(self):
         thread = QuantizationThread(command, backend_path, log_file)
         self.quant_threads.append(thread)
 
-        task_item = TaskListItem(EXPORTING_LORA, log_file)
+        task_item = TaskListItem(EXPORTING_LORA, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)
@@ -888,7 +889,7 @@ def convert_lora(self):
         task_name = LORA_CONVERSION_FROM_TO.format(
             os.path.basename(lora_input_path), os.path.basename(lora_output_path)
         )
-        task_item = TaskListItem(task_name, log_file)
+        task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)
@@ -967,13 +968,14 @@ def refresh_releases(self):
             response = requests.get(
                 "https://api.github.com/repos/ggerganov/llama.cpp/releases"
             )
+            response.raise_for_status()  # Raise an exception for bad status codes
            releases = response.json()
             self.release_combo.clear()
             for release in releases:
                 self.release_combo.addItem(release["tag_name"], userData=release)
             self.release_combo.currentIndexChanged.connect(self.update_assets)
             self.update_assets()
-        except Exception as e:
+        except requests.exceptions.RequestException as e:
             self.show_error(ERROR_FETCHING_RELEASES.format(str(e)))
 
     def update_assets(self):
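
Aside: the two changes in this hunk follow the standard requests pattern: raise_for_status() converts HTTP 4xx/5xx responses into exceptions, and catching requests.exceptions.RequestException (rather than bare Exception) avoids swallowing unrelated bugs. A minimal standalone sketch of the same flow, reusing the endpoint from the hunk above:

    import requests

    try:
        response = requests.get(
            "https://api.github.com/repos/ggerganov/llama.cpp/releases"
        )
        response.raise_for_status()  # turn 4xx/5xx into an exception
        releases = response.json()
        print([release["tag_name"] for release in releases[:3]])
    except requests.exceptions.RequestException as e:
        print(f"Error fetching releases: {e}")
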
@@ -981,20 +983,13 @@ def update_assets(self):
         self.asset_combo.clear()
         release = self.release_combo.currentData()
         if release:
-            for asset in release["assets"]:
-                self.asset_combo.addItem(asset["name"], userData=asset)
+            if "assets" in release:
+                for asset in release["assets"]:
+                    self.asset_combo.addItem(asset["name"], userData=asset)
+            else:
+                self.show_error(NO_ASSETS_FOUND_FOR_RELEASE.format(release["tag_name"]))
         self.update_cuda_option()
-
-    def update_cuda_option(self):
-        self.logger.debug(UPDATING_CUDA_OPTIONS)
-        asset = self.asset_combo.currentData()
-        is_cuda = asset and "cudart" in asset["name"].lower()
-        self.cuda_extract_checkbox.setVisible(is_cuda)
-        self.cuda_backend_label.setVisible(is_cuda)
-        self.backend_combo_cuda.setVisible(is_cuda)
-        if is_cuda:
-            self.update_cuda_backends()
 
     def download_llama_cpp(self):
         self.logger.info(STARTING_LLAMACPP_DOWNLOAD)
         asset = self.asset_combo.currentData()
@@ -1017,6 +1012,25 @@ def download_llama_cpp(self):
         self.download_button.setEnabled(False)
         self.download_progress.setValue(0)
 
+    def update_cuda_option(self):
+        self.logger.debug(UPDATING_CUDA_OPTIONS)
+        asset = self.asset_combo.currentData()
+
+        # Handle the case where asset is None
+        if asset is None:
+            self.logger.warning(NO_ASSET_SELECTED_FOR_CUDA_CHECK)
+            self.cuda_extract_checkbox.setVisible(False)
+            self.cuda_backend_label.setVisible(False)
+            self.backend_combo_cuda.setVisible(False)
+            return  # Exit the function early
+
+        is_cuda = asset and "cudart" in asset["name"].lower()
+        self.cuda_extract_checkbox.setVisible(is_cuda)
+        self.cuda_backend_label.setVisible(is_cuda)
+        self.backend_combo_cuda.setVisible(is_cuda)
+        if is_cuda:
+            self.update_cuda_backends()
+
     def update_cuda_backends(self):
         self.logger.debug(UPDATING_CUDA_BACKENDS)
         self.backend_combo_cuda.clear()
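
Aside: the relocated update_cuda_option() now handles QComboBox.currentData() returning None when nothing is selected. A small sketch of the same guard, with a plain dict standing in for the Qt asset data (the file names are illustrative, not taken from a real release):

    def is_cuda_asset(asset):
        # Bail out on None before touching asset["name"], as the guard above does
        if asset is None:
            return False
        return "cudart" in asset["name"].lower()

    assert is_cuda_asset({"name": "cudart-llama-bin-win-x64.zip"})
    assert not is_cuda_asset({"name": "llama-bin-win-avx2-x64.zip"})
    assert not is_cuda_asset(None)
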
@@ -1385,6 +1399,8 @@ def quantize_model(self):
         self.task_list.addItem(list_item)
         self.task_list.setItemWidget(list_item, task_item)
 
+        # Connect the output signal to the new progress parsing function
+        thread.output_signal.connect(lambda line: self.parse_progress(line, task_item))
         thread.status_signal.connect(task_item.update_status)
         thread.finished_signal.connect(lambda: self.task_finished(thread))
         thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
@@ -1401,6 +1417,15 @@ def update_model_info(self, model_info):
         # TODO: Do something with this
         pass
 
+    def parse_progress(self, line, task_item):
+        # Parses the output line for progress information and updates the task item.
+        match = re.search(r"\[(\d+)/(\d+)\]", line)
+        if match:
+            current = int(match.group(1))
+            total = int(match.group(2))
+            progress = int((current / total) * 100)
+            task_item.update_progress(progress)
+
     def task_finished(self, thread):
         self.logger.info(TASK_FINISHED.format(thread.log_file))
         if thread in self.quant_threads:
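
Aside: parse_progress assumes llama.cpp emits a [current/total] tensor counter in its quantization output. A self-contained sketch of the extraction (the sample line is illustrative):

    import re

    def percent(line):
        match = re.search(r"\[(\d+)/(\d+)\]", line)
        if match:
            current, total = int(match.group(1)), int(match.group(2))
            return int((current / total) * 100)
        return None  # no counter on this line

    assert percent("[42/291] blk.4.attn_k.weight") == 14
    assert percent("no counter here") is None
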
@@ -1508,7 +1533,7 @@ def generate_imatrix(self):
         task_name = GENERATING_IMATRIX_FOR.format(
             os.path.basename(self.imatrix_model.text())
         )
-        task_item = TaskListItem(task_name, log_file)
+        task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)

@@ -2,6 +2,7 @@
 from PyQt6.QtCore import pyqtSignal, QRegularExpression
 from PyQt6.QtGui import QDoubleValidator, QIntValidator, QRegularExpressionValidator
 
+
 class KVOverrideEntry(QWidget):
     deleted = pyqtSignal(QWidget)
 
@@ -13,7 +14,7 @@ def __init__(self, parent=None):
         self.key_input = QLineEdit()
         self.key_input.setPlaceholderText("Key")
         # Set validator for key input (letters and dots only)
-        key_validator = QRegularExpressionValidator(QRegularExpression(r'[A-Za-z.]+'))
+        key_validator = QRegularExpressionValidator(QRegularExpression(r"[A-Za-z.]+"))
         self.key_input.setValidator(key_validator)
         layout.addWidget(self.key_input)
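
Aside: besides the quote-style reformat, the validator constrains the key field to dotted identifiers such as GGUF metadata keys. The same character class in plain re, for illustration:

    import re

    def is_valid_key(text):
        # Same pattern as the QRegularExpressionValidator above
        return re.fullmatch(r"[A-Za-z.]+", text) is not None

    assert is_valid_key("tokenizer.ggml.model")
    assert not is_valid_key("tokenizer ggml")  # spaces are rejected
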
@@ -32,7 +33,7 @@ def __init__(self, parent=None):
 
         # Connect type change to validator update
         self.type_combo.currentTextChanged.connect(self.update_validator)
-
-
         # Initialize validator
         self.update_validator(self.type_combo.currentText())
+
+
@@ -16,6 +16,7 @@
 from imports_and_globals import open_file_safe
 
 class QuantizationThread(QThread):
+    # Define custom signals for communication with the main thread
     output_signal = pyqtSignal(str)
     status_signal = pyqtSignal(str)
     finished_signal = pyqtSignal()
@@ -32,48 +33,62 @@ def __init__(self, command, cwd, log_file):
 
     def run(self):
         try:
-            self.process = subprocess.Popen(self.command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                                            text=True, cwd=self.cwd)
-            with open_file_safe(self.log_file, 'w') as log:
+            # Start the subprocess
+            self.process = subprocess.Popen(
+                self.command,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                cwd=self.cwd,
+            )
+            # Open log file and process output
+            with open_file_safe(self.log_file, "w") as log:
                 for line in self.process.stdout:
                     line = line.strip()
                     self.output_signal.emit(line)
-                    log.write(line + '\n')
+                    log.write(line + "\n")
                     log.flush()
                     self.status_signal.emit("In Progress")
                     self.parse_model_info(line)
 
+            # Wait for process to complete
             self.process.wait()
             if self.process.returncode == 0:
                 self.status_signal.emit("Completed")
                 self.model_info_signal.emit(self.model_info)
             else:
-                self.error_signal.emit(f"Process exited with code {self.process.returncode}")
+                self.error_signal.emit(
+                    f"Process exited with code {self.process.returncode}"
+                )
             self.finished_signal.emit()
         except Exception as e:
             self.error_signal.emit(str(e))
 
     def parse_model_info(self, line):
         # Parse output for model information
         if "llama_model_loader: loaded meta data with" in line:
             parts = line.split()
-            self.model_info['kv_pairs'] = parts[6]
-            self.model_info['tensors'] = parts[9]
+            self.model_info["kv_pairs"] = parts[6]
+            self.model_info["tensors"] = parts[9]
         elif "general.architecture" in line:
-            self.model_info['architecture'] = line.split('=')[-1].strip()
+            self.model_info["architecture"] = line.split("=")[-1].strip()
         elif line.startswith("llama_model_loader: - kv"):
-            key = line.split(':')[2].strip()
-            value = line.split('=')[-1].strip()
-            self.model_info.setdefault('kv_data', {})[key] = value
+            key = line.split(":")[2].strip()
+            value = line.split("=")[-1].strip()
+            self.model_info.setdefault("kv_data", {})[key] = value
         elif line.startswith("llama_model_loader: - type"):
-            parts = line.split(':')
+            parts = line.split(":")
             if len(parts) > 1:
                 quant_type = parts[1].strip()
                 tensors = parts[2].strip().split()[0]
-                self.model_info.setdefault('quantization_type', []).append(f"{quant_type}: {tensors} tensors")
+                self.model_info.setdefault("quantization_type", []).append(
+                    f"{quant_type}: {tensors} tensors"
+                )
 
     def terminate(self):
+        # Terminate the subprocess if it's still running
         if self.process:
             os.kill(self.process.pid, signal.SIGTERM)
             self.process.wait(timeout=5)
             if self.process.poll() is None:
                 os.kill(self.process.pid, signal.SIGKILL)
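
Aside: the reformatted run() is an instance of the usual stream-and-log pattern for long-running subprocesses: merge stderr into stdout, read the pipe line by line so signals fire while the process runs, and flush after each write. A minimal sketch with a stand-in command (AutoGGUF runs llama.cpp binaries instead):

    import subprocess
    import sys

    process = subprocess.Popen(
        [sys.executable, "-c", "print('step 1'); print('step 2')"],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr so one loop sees everything
        text=True,
    )
    with open("example.log", "w") as log:
        for line in process.stdout:
            log.write(line.strip() + "\n")
            log.flush()  # keep the log current while the process runs
    process.wait()
    print("exit code:", process.returncode)
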
@@ -14,7 +14,7 @@
 from datetime import datetime
 
 class TaskListItem(QWidget):
-    def __init__(self, task_name, log_file, parent=None):
+    def __init__(self, task_name, log_file, show_progress_bar=True, parent=None):
         super().__init__(parent)
         self.task_name = task_name
         self.log_file = log_file
@@ -27,6 +27,14 @@ def __init__(self, task_name, log_file, parent=None):
         layout.addWidget(self.task_label)
         layout.addWidget(self.progress_bar)
         layout.addWidget(self.status_label)
+
+        # Hide progress bar if show_progress_bar is False
+        self.progress_bar.setVisible(show_progress_bar)
+
+        # Use indeterminate progress bar if not showing percentage
+        if not show_progress_bar:
+            self.progress_bar.setRange(0, 0)
+
         self.progress_timer = QTimer(self)
         self.progress_timer.timeout.connect(self.update_progress)
         self.progress_value = 0
@@ -35,15 +43,17 @@ def update_status(self, status):
         self.status = status
         self.status_label.setText(status)
         if status == "In Progress":
-            self.progress_bar.setRange(0, 100)
-            self.progress_timer.start(100)
+            # Only start timer if showing percentage progress
+            if self.progress_bar.isVisible():
+                self.progress_bar.setRange(0, 100)
+                self.progress_timer.start(100)
         elif status == "Completed":
             self.progress_timer.stop()
             self.progress_bar.setValue(100)
         elif status == "Canceled":
             self.progress_timer.stop()
             self.progress_bar.setValue(0)
 
     def set_error(self):
         self.status = "Error"
         self.status_label.setText("Error")
@@ -51,7 +61,12 @@ def set_error(self):
         self.progress_bar.setRange(0, 100)
         self.progress_timer.stop()
 
-    def update_progress(self):
-        self.progress_value = (self.progress_value + 1) % 101
-        self.progress_bar.setValue(self.progress_value)
+    def update_progress(self, value=None):
+        if value is not None:
+            # Update progress bar with specific value
+            self.progress_value = value
+            self.progress_bar.setValue(self.progress_value)
+        else:
+            # Increment progress bar for indeterminate progress
+            self.progress_value = (self.progress_value + 1) % 101
+            self.progress_bar.setValue(self.progress_value)
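
Aside: the show_progress_bar=False path leans on a Qt convention: a QProgressBar whose range is (0, 0) renders as an indeterminate busy indicator rather than a percentage, which suits tasks like LoRA conversion whose output has no parsable counter. A tiny standalone sketch (requires PyQt6):

    import sys
    from PyQt6.QtWidgets import QApplication, QProgressBar

    app = QApplication(sys.argv)
    bar = QProgressBar()
    bar.setRange(0, 0)  # (0, 0) switches the bar to indeterminate mode
    bar.show()
    app.exec()
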
(File diff suppressed because it is too large.)

@@ -11,15 +11,24 @@
 import json
 from math import prod
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Sequence,
+    SupportsIndex,
+    cast,
+)
 
 import torch
 
 if TYPE_CHECKING:
     from torch import Tensor
 
-if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
+if "NO_LOCAL_GGUF" not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
 import gguf
 
 # reuse model definitions from convert_hf_to_gguf.py
@@ -33,6 +42,7 @@ class PartialLoraTensor:
     A: Tensor | None = None
     B: Tensor | None = None
 
+
 # magic to support tensor shape modifications and splitting
 class LoraTorchTensor:
     _lora_A: Tensor  # (n_rank, row_size)
@@ -57,7 +67,9 @@ def __getitem__(
         indices: (
             SupportsIndex
             | slice
-            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
+            | tuple[
+                SupportsIndex | slice | Tensor, ...
+            ]  # TODO: add ellipsis in the type signature
         ),
     ) -> LoraTorchTensor:
         shape = self.shape
@@ -90,7 +102,10 @@ def __getitem__(
             )
 
         if len(indices) < len(shape):
-            indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+            indices = (
+                *indices,
+                *(slice(None, None) for _ in range(len(indices), len(shape))),
+            )
 
         # TODO: make sure this is correct
         indices_A = (
@@ -138,7 +153,9 @@ def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
         n_elems = prod(orig_shape)
         n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
         assert n_elems % n_new_elems == 0
-        new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
+        new_shape = (
+            *(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),
+        )
 
         if new_shape[-1] != orig_shape[-1]:
             raise NotImplementedError  # can't reshape the row size trivially
@@ -164,7 +181,9 @@ def permute(self, *dims: int) -> LoraTorchTensor:
             assert all(dim == 1 for dim in self._lora_A.shape[:-2])
             return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
         if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
-            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+            return LoraTorchTensor(
+                self._lora_B.permute(*dims), self._lora_A.permute(*dims)
+            )
         else:
             # TODO: compose the above two
             raise NotImplementedError
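
Aside: the reformatted 2-D branch works because the LoRA product W = B @ A satisfies W.T = A.T @ B.T, so transposing both factors and swapping their roles reproduces the transposed product. A quick numeric check:

    import torch

    A = torch.randn(4, 8)  # lora_A: (n_rank, row_size)
    B = torch.randn(6, 4)  # lora_B: (col_size, n_rank)
    W = B @ A              # effective LoRA delta, shape (6, 8)
    assert torch.allclose(W.T, A.T @ B.T)
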
@@ -179,7 +198,9 @@ def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
         return self.transpose(axis0, axis1)
 
     def to(self, *args, **kwargs):
-        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+        return LoraTorchTensor(
+            self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)
+        )
 
     @classmethod
     def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
@@ -226,50 +247,64 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = base_name.replace(".lora_B.weight", ".weight")
     return base_name
 
 
 def pyinstaller_include():
     # PyInstaller import
     pass
 
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
-        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
+        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file"
+    )
     parser.add_argument(
-        "--outfile", type=Path,
+        "--outfile",
+        type=Path,
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        "--outtype",
+        type=str,
+        choices=["f32", "f16", "bf16", "q8_0", "auto"],
+        default="f16",
         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
-        "--bigendian", action="store_true",
+        "--bigendian",
+        action="store_true",
         help="model is executed on big endian machine",
     )
     parser.add_argument(
-        "--no-lazy", action="store_true",
+        "--no-lazy",
+        action="store_true",
         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
     )
     parser.add_argument(
-        "--verbose", action="store_true",
+        "--verbose",
+        action="store_true",
        help="increase output verbosity",
     )
     parser.add_argument(
-        "--dry-run", action="store_true",
+        "--dry-run",
+        action="store_true",
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
+        "--base",
+        type=Path,
+        required=True,
         help="directory containing base model file",
     )
     parser.add_argument(
-        "lora_path", type=Path,
+        "lora_path",
+        type=Path,
         help="directory containing LoRA adapter file",
     )
 
     return parser.parse_args()
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -318,7 +353,9 @@ class LoraModel(model_class):
 
     lora_alpha: float
 
-    def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+    def __init__(
+        self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs
+    ):
 
         super().__init__(*args, **kwargs)
 
@@ -330,7 +367,9 @@ def set_type(self):
         self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
 
     def set_gguf_parameters(self):
-        self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
+        self.gguf_writer.add_float32(
+            gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha
+        )
         super().set_gguf_parameters()
 
     def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
@@ -345,7 +384,9 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             if not is_lora_a and not is_lora_b:
                 if ".base_layer.weight" in name:
                     continue
-                logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                logger.error(
+                    f"Unexpected name '{name}': Not a lora_A or lora_B tensor"
+                )
                 sys.exit(1)
 
             if base_name in tensor_map:
@@ -362,9 +403,14 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
         for name, tensor in tensor_map.items():
             assert tensor.A is not None
             assert tensor.B is not None
-            yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
+            yield (
+                name,
+                cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)),
+            )
 
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
         dest = super().modify_tensors(data_torch, name, bid)
         for dest_name, dest_data in dest:
             assert isinstance(dest_data, LoraTorchTensor)

@@ -224,6 +224,7 @@ def __init__(self):
         self.LORA_CONVERSION_FROM_TO = ""
         self.GENERATING_IMATRIX_FOR = ""
         self.MODEL_PATH_REQUIRED_FOR_IMATRIX = ""
+        self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = ""
 
 class _English(_Localization):
     def __init__(self):
@@ -450,6 +451,7 @@ def __init__(self):
         self.LORA_CONVERSION_FROM_TO = "LoRA Conversion from {} to {}"
         self.GENERATING_IMATRIX_FOR = "Generating IMatrix for {}"
         self.MODEL_PATH_REQUIRED_FOR_IMATRIX = "Model path is required for IMatrix generation."
+        self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = "No asset selected for CUDA check"
 
 class _French:
     # French localization
@@ -5231,7 +5233,7 @@ def set_language(lang_code):
     global ADDING_LORA_ADAPTER, DELETING_LORA_ADAPTER, LORA_FILES, SELECT_LORA_ADAPTER_FILE, STARTING_LORA_EXPORT
     global OUTPUT_TYPE, SELECT_OUTPUT_TYPE, GGUF_AND_BIN_FILES, BASE_MODEL, SELECT_BASE_MODEL_FILE
     global BASE_MODEL_PATH_REQUIRED, BROWSING_FOR_BASE_MODEL_FILE, SELECT_BASE_MODEL_FOLDER, BROWSING_FOR_BASE_MODEL_FOLDER
-    global LORA_CONVERSION_FROM_TO, GENERATING_IMATRIX_FOR, MODEL_PATH_REQUIRED_FOR_IMATRIX
+    global LORA_CONVERSION_FROM_TO, GENERATING_IMATRIX_FOR, MODEL_PATH_REQUIRED_FOR_IMATRIX, NO_ASSET_SELECTED_FOR_CUDA_CHECK
 
     loc = _languages.get(lang_code, _English)()
     english_loc = _English()  # Create an instance of English localization for fallback
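
Aside: giving the new NO_ASSET_SELECTED_FOR_CUDA_CHECK key an empty default in the base class is what lets set_language fall back to English for untranslated strings. A simplified sketch of that fallback (the classes here are stand-ins for the real localization classes):

    class _Localization:
        def __init__(self):
            self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = ""

    class _English(_Localization):
        def __init__(self):
            super().__init__()
            self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = "No asset selected for CUDA check"

    class _French(_Localization):
        pass  # key not yet translated, stays ""

    def resolve(loc, english_loc, key):
        value = getattr(loc, key, "")
        return value if value else getattr(english_loc, key)

    print(resolve(_French(), _English(), "NO_ASSET_SELECTED_FOR_CUDA_CHECK"))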