feat(ui): support shift clicking to get quantization command

- support shift clicking Quantize Model button to get quantize command
- clean up imports in AutoGGUF.py and add localization keys
- use str() for getting log_dir_name
- remove legacy validate_quantization_inputs() function
- add return_command parameter to quantize_model() function
This commit is contained in:
BuildTools 2024-11-12 19:41:59 -08:00
parent 6aaefb2ccb
commit 749f3215ec
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
2 changed files with 27 additions and 22 deletions

View File

@ -5,7 +5,7 @@
import urllib.request
from datetime import datetime
from functools import partial, wraps
from typing import List
from typing import Any, List, Union
from PySide6.QtCore import *
from PySide6.QtGui import *
@ -71,7 +71,7 @@ def __init__(self, args: List[str]) -> None:
self.parse_resolution = ui_update.parse_resolution.__get__(self)
self.log_dir_name = os.environ.get("AUTOGGUF_LOG_DIR_NAME", "logs")
self.log_dir_name = str(os.environ.get("AUTOGGUF_LOG_DIR_NAME", "logs"))
width, height = self.parse_resolution()
self.logger = Logger("AutoGGUF", self.log_dir_name)
@ -775,7 +775,7 @@ def __init__(self, args: List[str]) -> None:
# Quantize button layout
quantize_layout = QHBoxLayout()
quantize_button = QPushButton(QUANTIZE_MODEL)
quantize_button.clicked.connect(self.quantize_model)
quantize_button.clicked[bool].connect(self.quantize_model_handler)
save_preset_button = QPushButton(SAVE_PRESET)
save_preset_button.clicked.connect(self.save_preset)
load_preset_button = QPushButton(LOAD_PRESET)
@ -1101,6 +1101,20 @@ def __init__(self, args: List[str]) -> None:
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed()))
def quantize_model_handler(self) -> None:
if QApplication.keyboardModifiers() == Qt.ShiftModifier and self.quantize_model(
return_command=True
):
QApplication.clipboard().setText(self.quantize_model(return_command=True))
QMessageBox.information(
None,
INFO,
f"{COPIED_COMMAND_TO_CLIPBOARD} "
+ f"<code style='font-family: monospace; white-space: pre;'>{self.quantize_model(return_command=True)}</code>",
)
else:
self.quantize_model()
def resizeEvent(self, event) -> None:
super().resizeEvent(event)
path = QPainterPath()
@ -1254,23 +1268,6 @@ def download_finished(self, extract_dir) -> None:
if index >= 0:
self.backend_combo.setCurrentIndex(index)
def validate_quantization_inputs(self) -> None:
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
errors = []
if not self.backend_combo.currentData():
errors.append(NO_BACKEND_SELECTED)
if not self.models_input.text():
errors.append(MODELS_PATH_REQUIRED)
if not self.output_input.text():
errors.append(OUTPUT_PATH_REQUIRED)
if not self.logs_input.text():
errors.append(LOGS_PATH_REQUIRED)
if not self.model_tree.currentItem():
errors.append(NO_MODEL_SELECTED)
if errors:
raise ValueError("\n".join(errors))
def load_models(self) -> None:
self.logger.info(LOADING_MODELS)
models_dir = self.models_input.text()
@ -1698,10 +1695,9 @@ def merge_gguf(self, model_dir: str, output_dir: str) -> None:
show_error(self.logger, "Error starting merge GGUF task: {}".format(e))
self.logger.info("Split GGUF task finished.")
def quantize_model(self) -> None:
def quantize_model(self, return_command=False) -> str:
self.logger.info(STARTING_MODEL_QUANTIZATION)
try:
self.validate_quantization_inputs()
selected_item = self.model_tree.currentItem()
if not selected_item:
raise ValueError(NO_MODEL_SELECTED)
@ -1822,6 +1818,12 @@ def quantize_model(self) -> None:
if self.extra_arguments.text():
command.extend(self.extra_arguments.text().split())
if return_command:
self.logger.info(
f"{QUANTIZATION_COMMAND}: {str(' '.join(command))}"
)
return str(" ".join(command))
logs_path = self.logs_input.text()
ensure_directory(logs_path)

View File

@ -454,6 +454,9 @@ def __init__(self):
self.EXTRA_COMMAND_ARGUMENTS = "Additional command-line arguments"
self.INFO = "Info"
self.COPIED_COMMAND_TO_CLIPBOARD = "Copied command to clipboard:"
class _French(_Localization):
def __init__(self):