add assets fix, code format, progress bar fix

BuildTools 2024-08-04 15:07:24 -07:00
parent 3e8d7b1415
commit 10911b71a0
7 changed files with 1239 additions and 457 deletions

View File

@@ -12,6 +12,7 @@
 import platform
 import requests
 import zipfile
+import re
 from datetime import datetime
 from imports_and_globals import ensure_directory, open_file_safe
 from DownloadThread import DownloadThread
@@ -144,10 +145,6 @@ def __init__(self):
         self.quant_type = QComboBox()
         self.quant_type.addItems(
             [
-                "Q4_0",
-                "Q4_1",
-                "Q5_0",
-                "Q5_1",
                 "IQ2_XXS",
                 "IQ2_XS",
                 "IQ2_S",
@@ -174,11 +171,15 @@ def __init__(self):
                 "Q5_K_M",
                 "Q6_K",
                 "Q8_0",
+                "Q4_0",
+                "Q4_1",
+                "Q5_0",
+                "Q5_1",
                 "Q4_0_4_4",
                 "Q4_0_4_8",
                 "Q4_0_8_8",
-                "F16",
                 "BF16",
+                "F16",
                 "F32",
                 "COPY",
             ]
@@ -796,7 +797,7 @@ def export_lora(self):
         thread = QuantizationThread(command, backend_path, log_file)
         self.quant_threads.append(thread)
-        task_item = TaskListItem(EXPORTING_LORA, log_file)
+        task_item = TaskListItem(EXPORTING_LORA, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)
@@ -888,7 +889,7 @@ def convert_lora(self):
         task_name = LORA_CONVERSION_FROM_TO.format(
             os.path.basename(lora_input_path), os.path.basename(lora_output_path)
         )
-        task_item = TaskListItem(task_name, log_file)
+        task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)
@@ -967,13 +968,14 @@ def refresh_releases(self):
             response = requests.get(
                 "https://api.github.com/repos/ggerganov/llama.cpp/releases"
             )
+            response.raise_for_status()  # Raise an exception for bad status codes
             releases = response.json()
             self.release_combo.clear()
             for release in releases:
                 self.release_combo.addItem(release["tag_name"], userData=release)
             self.release_combo.currentIndexChanged.connect(self.update_assets)
             self.update_assets()
-        except Exception as e:
+        except requests.exceptions.RequestException as e:
             self.show_error(ERROR_FETCHING_RELEASES.format(str(e)))

     def update_assets(self):
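For reference, a minimal standalone sketch of the error handling introduced above: raise_for_status() turns 4xx/5xx responses into an HTTPError, which is a subclass of RequestException, so the single except clause covers connection failures, timeouts, and bad status codes alike (the helper name is illustrative).

import requests

def fetch_release_tags():
    # Any network failure or bad HTTP status ends up in the one except clause below
    try:
        response = requests.get(
            "https://api.github.com/repos/ggerganov/llama.cpp/releases"
        )
        response.raise_for_status()  # 4xx/5xx -> HTTPError, a RequestException subclass
        return [release["tag_name"] for release in response.json()]
    except requests.exceptions.RequestException as e:
        print(f"Error fetching releases: {e}")
        return []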
@@ -981,20 +983,13 @@ def update_assets(self):
         self.asset_combo.clear()
         release = self.release_combo.currentData()
         if release:
-            if "assets" in release:
             for asset in release["assets"]:
                 self.asset_combo.addItem(asset["name"], userData=asset)
-            else:
-                self.show_error(NO_ASSETS_FOUND_FOR_RELEASE.format(release["tag_name"]))
         self.update_cuda_option()

-    def update_cuda_option(self):
-        self.logger.debug(UPDATING_CUDA_OPTIONS)
-        asset = self.asset_combo.currentData()
-        is_cuda = asset and "cudart" in asset["name"].lower()
-        self.cuda_extract_checkbox.setVisible(is_cuda)
-        self.cuda_backend_label.setVisible(is_cuda)
-        self.backend_combo_cuda.setVisible(is_cuda)
-        if is_cuda:
-            self.update_cuda_backends()
-
     def download_llama_cpp(self):
         self.logger.info(STARTING_LLAMACPP_DOWNLOAD)
         asset = self.asset_combo.currentData()
@@ -1017,6 +1012,25 @@ def download_llama_cpp(self):
         self.download_button.setEnabled(False)
         self.download_progress.setValue(0)

+    def update_cuda_option(self):
+        self.logger.debug(UPDATING_CUDA_OPTIONS)
+        asset = self.asset_combo.currentData()
+
+        # Handle the case where asset is None
+        if asset is None:
+            self.logger.warning(NO_ASSET_SELECTED_FOR_CUDA_CHECK)
+            self.cuda_extract_checkbox.setVisible(False)
+            self.cuda_backend_label.setVisible(False)
+            self.backend_combo_cuda.setVisible(False)
+            return  # Exit the function early
+
+        is_cuda = asset and "cudart" in asset["name"].lower()
+        self.cuda_extract_checkbox.setVisible(is_cuda)
+        self.cuda_backend_label.setVisible(is_cuda)
+        self.backend_combo_cuda.setVisible(is_cuda)
+        if is_cuda:
+            self.update_cuda_backends()
+
     def update_cuda_backends(self):
         self.logger.debug(UPDATING_CUDA_BACKENDS)
         self.backend_combo_cuda.clear()
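The early return added above guards against QComboBox.currentData() returning None when nothing is selected; a minimal sketch of the same guard pattern, with illustrative names:

def update_dependent_widgets(combo, widgets):
    # currentData() returns None when the combo box is empty or has no selection
    asset = combo.currentData()
    if asset is None:
        for widget in widgets:
            widget.setVisible(False)
        return  # nothing selected, bail out early
    is_cuda = "cudart" in asset["name"].lower()
    for widget in widgets:
        widget.setVisible(is_cuda)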
@@ -1385,6 +1399,8 @@ def quantize_model(self):
         self.task_list.addItem(list_item)
         self.task_list.setItemWidget(list_item, task_item)
+        # Connect the output signal to the new progress parsing function
+        thread.output_signal.connect(lambda line: self.parse_progress(line, task_item))
         thread.status_signal.connect(task_item.update_status)
         thread.finished_signal.connect(lambda: self.task_finished(thread))
         thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
@@ -1401,6 +1417,15 @@ def update_model_info(self, model_info):
         # TODO: Do something with this
         pass

+    def parse_progress(self, line, task_item):
+        # Parses the output line for progress information and updates the task item.
+        match = re.search(r"\[(\d+)/(\d+)\]", line)
+        if match:
+            current = int(match.group(1))
+            total = int(match.group(2))
+            progress = int((current / total) * 100)
+            task_item.update_progress(progress)
+
     def task_finished(self, thread):
         self.logger.info(TASK_FINISHED.format(thread.log_file))
         if thread in self.quant_threads:
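For reference, a small standalone sketch of the progress parsing added above; the sample log line is illustrative and only assumes that llama.cpp prints a "[current/total]" counter somewhere in the line:

import re

def parse_progress_percent(line):
    # Extract a "[current/total]" counter and convert it to a 0-100 percentage
    match = re.search(r"\[(\d+)/(\d+)\]", line)
    if match:
        current, total = int(match.group(1)), int(match.group(2))
        return int((current / total) * 100)
    return None

print(parse_progress_percent("[42/291] blk.4.attn_k.weight - quantizing"))  # 14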
@@ -1508,7 +1533,7 @@ def generate_imatrix(self):
         task_name = GENERATING_IMATRIX_FOR.format(
             os.path.basename(self.imatrix_model.text())
         )
-        task_item = TaskListItem(task_name, log_file)
+        task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
         list_item = QListWidgetItem(self.task_list)
         list_item.setSizeHint(task_item.sizeHint())
         self.task_list.addItem(list_item)

View File

@@ -2,6 +2,7 @@
 from PyQt6.QtCore import pyqtSignal, QRegularExpression
 from PyQt6.QtGui import QDoubleValidator, QIntValidator, QRegularExpressionValidator


 class KVOverrideEntry(QWidget):
     deleted = pyqtSignal(QWidget)
@@ -13,7 +14,7 @@ def __init__(self, parent=None):
         self.key_input = QLineEdit()
         self.key_input.setPlaceholderText("Key")
         # Set validator for key input (letters and dots only)
-        key_validator = QRegularExpressionValidator(QRegularExpression(r'[A-Za-z.]+'))
+        key_validator = QRegularExpressionValidator(QRegularExpression(r"[A-Za-z.]+"))
         self.key_input.setValidator(key_validator)
         layout.addWidget(self.key_input)

View File

@@ -16,6 +16,7 @@
 from imports_and_globals import open_file_safe


 class QuantizationThread(QThread):
+    # Define custom signals for communication with the main thread
     output_signal = pyqtSignal(str)
     status_signal = pyqtSignal(str)
     finished_signal = pyqtSignal()
@@ -32,48 +33,62 @@ def __init__(self, command, cwd, log_file):
     def run(self):
         try:
-            self.process = subprocess.Popen(self.command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
-                                            text=True, cwd=self.cwd)
-            with open_file_safe(self.log_file, 'w') as log:
+            # Start the subprocess
+            self.process = subprocess.Popen(
+                self.command,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.STDOUT,
+                text=True,
+                cwd=self.cwd,
+            )
+            # Open log file and process output
+            with open_file_safe(self.log_file, "w") as log:
                 for line in self.process.stdout:
                     line = line.strip()
                     self.output_signal.emit(line)
-                    log.write(line + '\n')
+                    log.write(line + "\n")
                     log.flush()
                     self.status_signal.emit("In Progress")
                     self.parse_model_info(line)
+
+            # Wait for process to complete
             self.process.wait()
             if self.process.returncode == 0:
                 self.status_signal.emit("Completed")
                 self.model_info_signal.emit(self.model_info)
             else:
-                self.error_signal.emit(f"Process exited with code {self.process.returncode}")
+                self.error_signal.emit(
+                    f"Process exited with code {self.process.returncode}"
+                )
             self.finished_signal.emit()
         except Exception as e:
             self.error_signal.emit(str(e))

     def parse_model_info(self, line):
+        # Parse output for model information
         if "llama_model_loader: loaded meta data with" in line:
             parts = line.split()
-            self.model_info['kv_pairs'] = parts[6]
-            self.model_info['tensors'] = parts[9]
+            self.model_info["kv_pairs"] = parts[6]
+            self.model_info["tensors"] = parts[9]
         elif "general.architecture" in line:
-            self.model_info['architecture'] = line.split('=')[-1].strip()
+            self.model_info["architecture"] = line.split("=")[-1].strip()
         elif line.startswith("llama_model_loader: - kv"):
-            key = line.split(':')[2].strip()
-            value = line.split('=')[-1].strip()
-            self.model_info.setdefault('kv_data', {})[key] = value
+            key = line.split(":")[2].strip()
+            value = line.split("=")[-1].strip()
+            self.model_info.setdefault("kv_data", {})[key] = value
         elif line.startswith("llama_model_loader: - type"):
-            parts = line.split(':')
+            parts = line.split(":")
             if len(parts) > 1:
                 quant_type = parts[1].strip()
                 tensors = parts[2].strip().split()[0]
-                self.model_info.setdefault('quantization_type', []).append(f"{quant_type}: {tensors} tensors")
+                self.model_info.setdefault("quantization_type", []).append(
+                    f"{quant_type}: {tensors} tensors"
+                )

     def terminate(self):
+        # Terminate the subprocess if it's still running
         if self.process:
             os.kill(self.process.pid, signal.SIGTERM)
             self.process.wait(timeout=5)
             if self.process.poll() is None:
                 os.kill(self.process.pid, signal.SIGKILL)
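The reformatted run() method above follows the usual pattern for streaming a subprocess's merged output line by line while logging it; a minimal standalone sketch of that pattern (the command and file name are illustrative):

import subprocess

def stream_command(command, log_path):
    # Merge stderr into stdout and decode as text so lines can be iterated directly
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    with open(log_path, "w") as log:
        for line in process.stdout:
            log.write(line.strip() + "\n")
            log.flush()
    return process.wait()  # exit code; 0 means success

# Illustrative usage:
# exit_code = stream_command(["echo", "hello"], "run.log")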

View File

@@ -14,7 +14,7 @@
 from datetime import datetime


 class TaskListItem(QWidget):
-    def __init__(self, task_name, log_file, parent=None):
+    def __init__(self, task_name, log_file, show_progress_bar=True, parent=None):
         super().__init__(parent)
         self.task_name = task_name
         self.log_file = log_file
@@ -27,6 +27,14 @@ def __init__(self, task_name, log_file, parent=None):
         layout.addWidget(self.task_label)
         layout.addWidget(self.progress_bar)
         layout.addWidget(self.status_label)

+        # Hide progress bar if show_progress_bar is False
+        self.progress_bar.setVisible(show_progress_bar)
+
+        # Use indeterminate progress bar if not showing percentage
+        if not show_progress_bar:
+            self.progress_bar.setRange(0, 0)
+
         self.progress_timer = QTimer(self)
         self.progress_timer.timeout.connect(self.update_progress)
         self.progress_value = 0
@@ -35,6 +43,8 @@ def update_status(self, status):
         self.status = status
         self.status_label.setText(status)
         if status == "In Progress":
+            # Only start timer if showing percentage progress
+            if self.progress_bar.isVisible():
                 self.progress_bar.setRange(0, 100)
                 self.progress_timer.start(100)
         elif status == "Completed":
@@ -51,7 +61,12 @@ def set_error(self):
         self.progress_bar.setRange(0, 100)
         self.progress_timer.stop()

-    def update_progress(self):
+    def update_progress(self, value=None):
+        if value is not None:
+            # Update progress bar with specific value
+            self.progress_value = value
+            self.progress_bar.setValue(self.progress_value)
+        else:
+            # Increment progress bar for indeterminate progress
             self.progress_value = (self.progress_value + 1) % 101
             self.progress_bar.setValue(self.progress_value)
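Taken together with the parse_progress change above, the intent is: tasks created with show_progress_bar=False get a hidden, indeterminate bar, while quantization tasks keep a percentage bar fed by update_progress(value). A minimal sketch of the two QProgressBar modes (the helper name is illustrative):

from PyQt6.QtWidgets import QProgressBar

def configure_progress_bar(bar: QProgressBar, show_progress_bar: bool):
    # A (0, 0) range puts QProgressBar into indeterminate "busy" mode;
    # a (0, 100) range shows explicit percentages via setValue()
    bar.setVisible(show_progress_bar)
    if show_progress_bar:
        bar.setRange(0, 100)
        bar.setValue(0)
    else:
        bar.setRange(0, 0)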

File diff suppressed because it is too large

View File

@@ -11,15 +11,24 @@
 import json
 from math import prod
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    Sequence,
+    SupportsIndex,
+    cast,
+)

 import torch

 if TYPE_CHECKING:
     from torch import Tensor

-if 'NO_LOCAL_GGUF' not in os.environ:
-    sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
+if "NO_LOCAL_GGUF" not in os.environ:
+    sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))

 import gguf

 # reuse model definitions from convert_hf_to_gguf.py
@@ -33,6 +42,7 @@ class PartialLoraTensor:
     A: Tensor | None = None
     B: Tensor | None = None


 # magic to support tensor shape modifications and splitting
 class LoraTorchTensor:
     _lora_A: Tensor  # (n_rank, row_size)
@@ -57,7 +67,9 @@ def __getitem__(
         indices: (
             SupportsIndex
             | slice
-            | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
+            | tuple[
+                SupportsIndex | slice | Tensor, ...
+            ]  # TODO: add ellipsis in the type signature
         ),
     ) -> LoraTorchTensor:
         shape = self.shape
@@ -90,7 +102,10 @@ def __getitem__(
             )

             if len(indices) < len(shape):
-                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))
+                indices = (
+                    *indices,
+                    *(slice(None, None) for _ in range(len(indices), len(shape))),
+                )

             # TODO: make sure this is correct
             indices_A = (
@@ -138,7 +153,9 @@ def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
         n_elems = prod(orig_shape)
         n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
         assert n_elems % n_new_elems == 0
-        new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
+        new_shape = (
+            *(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),
+        )

         if new_shape[-1] != orig_shape[-1]:
             raise NotImplementedError  # can't reshape the row size trivially
@@ -164,7 +181,9 @@ def permute(self, *dims: int) -> LoraTorchTensor:
             assert all(dim == 1 for dim in self._lora_A.shape[:-2])
             return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
         if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
-            return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+            return LoraTorchTensor(
+                self._lora_B.permute(*dims), self._lora_A.permute(*dims)
+            )
         else:
             # TODO: compose the above two
             raise NotImplementedError
@@ -179,7 +198,9 @@ def swapaxes(self, axis0: int, axis1: int) -> LoraTorchTensor:
         return self.transpose(axis0, axis1)

     def to(self, *args, **kwargs):
-        return LoraTorchTensor(self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs))
+        return LoraTorchTensor(
+            self._lora_A.to(*args, **kwargs), self._lora_B.to(*args, **kwargs)
+        )

     @classmethod
     def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
@@ -226,50 +247,64 @@ def get_base_tensor_name(lora_tensor_name: str) -> str:
     base_name = base_name.replace(".lora_B.weight", ".weight")
     return base_name


 def pyinstaller_include():
     # PyInstaller import
     pass


 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
-        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file")
+        description="Convert a huggingface PEFT LoRA adapter to a GGML compatible file"
+    )
     parser.add_argument(
-        "--outfile", type=Path,
+        "--outfile",
+        type=Path,
         help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
     )
     parser.add_argument(
-        "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+        "--outtype",
+        type=str,
+        choices=["f32", "f16", "bf16", "q8_0", "auto"],
+        default="f16",
         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
     )
     parser.add_argument(
-        "--bigendian", action="store_true",
+        "--bigendian",
+        action="store_true",
         help="model is executed on big endian machine",
     )
     parser.add_argument(
-        "--no-lazy", action="store_true",
+        "--no-lazy",
+        action="store_true",
         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
     )
     parser.add_argument(
-        "--verbose", action="store_true",
+        "--verbose",
+        action="store_true",
         help="increase output verbosity",
     )
     parser.add_argument(
-        "--dry-run", action="store_true",
+        "--dry-run",
+        action="store_true",
         help="only print out what will be done, without writing any new files",
     )
     parser.add_argument(
-        "--base", type=Path, required=True,
+        "--base",
+        type=Path,
+        required=True,
         help="directory containing base model file",
     )
     parser.add_argument(
-        "lora_path", type=Path,
+        "lora_path",
+        type=Path,
         help="directory containing LoRA adapter file",
     )

     return parser.parse_args()


-if __name__ == '__main__':
+if __name__ == "__main__":
     args = parse_args()
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
@@ -318,7 +353,9 @@ class LoraModel(model_class):
         lora_alpha: float

-        def __init__(self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs):
+        def __init__(
+            self, *args, dir_lora_model: Path, lora_alpha: float, **kwargs
+        ):

             super().__init__(*args, **kwargs)
@@ -330,7 +367,9 @@ def set_type(self):
             self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")

         def set_gguf_parameters(self):
-            self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
+            self.gguf_writer.add_float32(
+                gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha
+            )
             super().set_gguf_parameters()

         def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
@@ -345,7 +384,9 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
                 if not is_lora_a and not is_lora_b:
                     if ".base_layer.weight" in name:
                         continue
-                    logger.error(f"Unexpected name '{name}': Not a lora_A or lora_B tensor")
+                    logger.error(
+                        f"Unexpected name '{name}': Not a lora_A or lora_B tensor"
+                    )
                     sys.exit(1)

                 if base_name in tensor_map:
@@ -362,9 +403,14 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
             for name, tensor in tensor_map.items():
                 assert tensor.A is not None
                 assert tensor.B is not None
-                yield (name, cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)))
+                yield (
+                    name,
+                    cast(torch.Tensor, LoraTorchTensor(tensor.A, tensor.B)),
+                )

-        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        def modify_tensors(
+            self, data_torch: Tensor, name: str, bid: int | None
+        ) -> Iterable[tuple[str, Tensor]]:
             dest = super().modify_tensors(data_torch, name, bid)
             for dest_name, dest_data in dest:
                 assert isinstance(dest_data, LoraTorchTensor)

View File

@@ -224,6 +224,7 @@ def __init__(self):
         self.LORA_CONVERSION_FROM_TO = ""
         self.GENERATING_IMATRIX_FOR = ""
         self.MODEL_PATH_REQUIRED_FOR_IMATRIX = ""
+        self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = ""


 class _English(_Localization):
     def __init__(self):
@@ -450,6 +451,7 @@ def __init__(self):
         self.LORA_CONVERSION_FROM_TO = "LoRA Conversion from {} to {}"
         self.GENERATING_IMATRIX_FOR = "Generating IMatrix for {}"
         self.MODEL_PATH_REQUIRED_FOR_IMATRIX = "Model path is required for IMatrix generation."
+        self.NO_ASSET_SELECTED_FOR_CUDA_CHECK = "No asset selected for CUDA check"


 class _French:
     # French localization
@@ -5231,7 +5233,7 @@ def set_language(lang_code):
     global ADDING_LORA_ADAPTER, DELETING_LORA_ADAPTER, LORA_FILES, SELECT_LORA_ADAPTER_FILE, STARTING_LORA_EXPORT
     global OUTPUT_TYPE, SELECT_OUTPUT_TYPE, GGUF_AND_BIN_FILES, BASE_MODEL, SELECT_BASE_MODEL_FILE
     global BASE_MODEL_PATH_REQUIRED, BROWSING_FOR_BASE_MODEL_FILE, SELECT_BASE_MODEL_FOLDER, BROWSING_FOR_BASE_MODEL_FOLDER
-    global LORA_CONVERSION_FROM_TO, GENERATING_IMATRIX_FOR, MODEL_PATH_REQUIRED_FOR_IMATRIX
+    global LORA_CONVERSION_FROM_TO, GENERATING_IMATRIX_FOR, MODEL_PATH_REQUIRED_FOR_IMATRIX, NO_ASSET_SELECTED_FOR_CUDA_CHECK

     loc = _languages.get(lang_code, _English)()
     english_loc = _English()  # Create an instance of English localization for fallback