From be38e35d99f509be695f9d92fcbebcbfc7ee88ff Mon Sep 17 00:00:00 2001 From: BuildTools Date: Mon, 9 Sep 2024 15:11:03 -0700 Subject: [PATCH] refactor: move functions into classes - move functions into existing classes and files - move AutoFP8 dialog out of a function and into __init__ --- src/AutoGGUF.py | 236 ++++++++++--------------------------- src/QuantizationThread.py | 10 ++ src/TaskListItem.py | 20 ++++ src/imports_and_globals.py | 45 ++++++- src/lora_conversion.py | 6 + src/ui_update.py | 51 ++++++++ 6 files changed, 195 insertions(+), 173 deletions(-) diff --git a/src/AutoGGUF.py b/src/AutoGGUF.py index afde662..f8ec809 100644 --- a/src/AutoGGUF.py +++ b/src/AutoGGUF.py @@ -28,6 +28,7 @@ open_file_safe, resource_path, show_about, + load_dotenv, ) @@ -48,7 +49,7 @@ def __init__(self, args: List[str]) -> None: self.setGeometry(100, 100, width, height) self.setWindowFlag(Qt.FramelessWindowHint) - self.load_dotenv() # Loads the .env file + load_dotenv(self) # Loads the .env file # Configuration self.model_dir_name = os.environ.get("AUTOGGUF_MODEL_DIR_NAME", "models") @@ -117,6 +118,18 @@ def __init__(self, args: List[str]) -> None: self.delete_lora_adapter_item = partial( lora_conversion.delete_lora_adapter_item, self ) + self.lora_conversion_finished = partial( + lora_conversion.lora_conversion_finished, self + ) + self.parse_progress = partial(QuantizationThread.parse_progress, self) + self.create_label = partial(ui_update.create_label, self) + self.browse_imatrix_datafile = ui_update.browse_imatrix_datafile.__get__(self) + self.browse_imatrix_model = ui_update.browse_imatrix_model.__get__(self) + self.browse_imatrix_output = ui_update.browse_imatrix_output.__get__(self) + self.restart_task = partial(TaskListItem.restart_task, self) + self.browse_hf_outfile = ui_update.browse_hf_outfile.__get__(self) + self.browse_hf_model_input = ui_update.browse_hf_model_input.__get__(self) + self.browse_base_model = ui_update.browse_base_model.__get__(self) # Set up main widget and layout main_widget = QWidget() @@ -154,11 +167,56 @@ def __init__(self, args: List[str]) -> None: about_action.triggered.connect(self.show_about) help_menu.addAction(about_action) + # AutoFP8 Window + self.fp8_dialog = QDialog(self) + self.fp8_dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC) + self.fp8_dialog.setFixedWidth(500) + self.fp8_layout = QVBoxLayout() + + # Input path + input_layout = QHBoxLayout() + self.fp8_input = QLineEdit() + input_button = QPushButton(BROWSE) + input_button.clicked.connect( + lambda: self.fp8_input.setText( + QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER) + ) + ) + input_layout.addWidget(QLabel(INPUT_MODEL)) + input_layout.addWidget(self.fp8_input) + input_layout.addWidget(input_button) + self.fp8_layout.addLayout(input_layout) + + # Output path + output_layout = QHBoxLayout() + self.fp8_output = QLineEdit() + output_button = QPushButton(BROWSE) + output_button.clicked.connect( + lambda: self.fp8_output.setText( + QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER) + ) + ) + output_layout.addWidget(QLabel(OUTPUT)) + output_layout.addWidget(self.fp8_output) + output_layout.addWidget(output_button) + self.fp8_layout.addLayout(output_layout) + + # Quantize button + quantize_button = QPushButton(QUANTIZE) + quantize_button.clicked.connect( + lambda: self.quantize_to_fp8_dynamic( + self.fp8_input.text(), self.fp8_output.text() + ) + ) + + self.fp8_layout.addWidget(quantize_button) + self.fp8_dialog.setLayout(self.fp8_layout) + # Tools menu tools_menu = self.menubar.addMenu("&Tools") autofp8_action = QAction("&AutoFP8", self) autofp8_action.setShortcut(QKeySequence("Shift+Q")) - autofp8_action.triggered.connect(self.show_autofp8_window) + autofp8_action.triggered.connect(self.fp8_dialog.exec) tools_menu.addAction(autofp8_action) # Content widget @@ -744,18 +802,16 @@ def __init__(self, args: List[str]) -> None: self.hf_no_lazy = QCheckBox(NO_LAZY_EVALUATION) hf_to_gguf_layout.addRow(self.hf_no_lazy) + self.hf_verbose = QCheckBox(VERBOSE) + hf_to_gguf_layout.addRow(self.hf_verbose) + self.hf_dry_run = QCheckBox(DRY_RUN) + hf_to_gguf_layout.addRow(self.hf_dry_run) self.hf_model_name = QLineEdit() hf_to_gguf_layout.addRow(MODEL_NAME, self.hf_model_name) - self.hf_verbose = QCheckBox(VERBOSE) - hf_to_gguf_layout.addRow(self.hf_verbose) - self.hf_split_max_size = QLineEdit() hf_to_gguf_layout.addRow(SPLIT_MAX_SIZE, self.hf_split_max_size) - self.hf_dry_run = QCheckBox(DRY_RUN) - hf_to_gguf_layout.addRow(self.hf_dry_run) - hf_to_gguf_convert_button = QPushButton(CONVERT_HF_TO_GGUF) hf_to_gguf_convert_button.clicked.connect(self.convert_hf_to_gguf) hf_to_gguf_layout.addRow(hf_to_gguf_convert_button) @@ -812,41 +868,6 @@ def __init__(self, args: List[str]) -> None: self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE) self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed())) - def load_dotenv(self): - if not os.path.isfile(".env"): - self.logger.warning(DOTENV_FILE_NOT_FOUND) - return - - try: - with open(".env") as f: - for line in f: - # Strip leading/trailing whitespace - line = line.strip() - - # Ignore comments and empty lines - if not line or line.startswith("#"): - continue - - # Match key-value pairs (unquoted and quoted values) - match = re.match(r"^([^=]+)=(.*)$", line) - if not match: - self.logger.warning(COULD_NOT_PARSE_LINE.format(line)) - continue - - key, value = match.groups() - - # Remove any surrounding quotes from the value - if value.startswith(("'", '"')) and value.endswith(("'", '"')): - value = value[1:-1] - - # Decode escape sequences - value = bytes(value, "utf-8").decode("unicode_escape") - - # Set the environment variable - os.environ[key.strip()] = value.strip() - except Exception as e: - self.logger.error(ERROR_LOADING_DOTENV.format(e)) - def load_plugins(self) -> Dict[str, Dict[str, Any]]: plugins = {} plugin_dir = "plugins" @@ -1038,28 +1059,6 @@ def save_task_preset(self, task_item) -> None: ) break - def browse_base_model(self) -> None: - self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message - base_model_folder = QFileDialog.getExistingDirectory( - self, SELECT_BASE_MODEL_FOLDER - ) - if base_model_folder: - self.base_model_path.setText(os.path.abspath(base_model_folder)) - - def browse_hf_model_input(self) -> None: - self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY) - model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY) - if model_dir: - self.hf_model_input.setText(os.path.abspath(model_dir)) - - def browse_hf_outfile(self) -> None: - self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT) - outfile, _ = QFileDialog.getSaveFileName( - self, SELECT_OUTPUT_FILE, "", GGUF_FILES - ) - if outfile: - self.hf_outfile.setText(os.path.abspath(outfile)) - def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None: self.logger.info( QUANTIZING_TO_WITH_AUTOFP8.format(os.path.basename(model_dir), output_dir) @@ -1107,52 +1106,6 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None: show_error(self.logger, f"{ERROR_STARTING_AUTOFP8_QUANTIZATION}: {e}") self.logger.info(AUTOFP8_QUANTIZATION_TASK_STARTED) - def show_autofp8_window(self): - dialog = QDialog(self) - dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC) - dialog.setFixedWidth(500) - layout = QVBoxLayout() - - # Input path - input_layout = QHBoxLayout() - self.fp8_input = QLineEdit() - input_button = QPushButton(BROWSE) - input_button.clicked.connect( - lambda: self.fp8_input.setText( - QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER) - ) - ) - input_layout.addWidget(QLabel(INPUT_MODEL)) - input_layout.addWidget(self.fp8_input) - input_layout.addWidget(input_button) - layout.addLayout(input_layout) - - # Output path - output_layout = QHBoxLayout() - self.fp8_output = QLineEdit() - output_button = QPushButton(BROWSE) - output_button.clicked.connect( - lambda: self.fp8_output.setText( - QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER) - ) - ) - output_layout.addWidget(QLabel(OUTPUT)) - output_layout.addWidget(self.fp8_output) - output_layout.addWidget(output_button) - layout.addLayout(output_layout) - - # Quantize button - quantize_button = QPushButton(QUANTIZE) - quantize_button.clicked.connect( - lambda: self.quantize_to_fp8_dynamic( - self.fp8_input.text(), self.fp8_output.text() - ) - ) - layout.addWidget(quantize_button) - - dialog.setLayout(layout) - dialog.exec() - def convert_hf_to_gguf(self) -> None: self.logger.info(STARTING_HF_TO_GGUF_CONVERSION) try: @@ -1229,31 +1182,6 @@ def convert_hf_to_gguf(self) -> None: show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e))) self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED) - def restart_task(self, task_item) -> None: - self.logger.info(RESTARTING_TASK.format(task_item.task_name)) - for thread in self.quant_threads: - if thread.log_file == task_item.log_file: - new_thread = QuantizationThread( - thread.command, thread.cwd, thread.log_file - ) - self.quant_threads.append(new_thread) - new_thread.status_signal.connect(task_item.update_status) - new_thread.finished_signal.connect( - lambda: self.task_finished(new_thread, task_item) - ) - new_thread.error_signal.connect( - lambda err: handle_error(self.logger, err, task_item) - ) - new_thread.model_info_signal.connect(self.update_model_info) - new_thread.start() - task_item.update_status(IN_PROGRESS) - break - - def lora_conversion_finished(self, thread) -> None: - self.logger.info(LORA_CONVERSION_FINISHED) - if thread in self.quant_threads: - self.quant_threads.remove(thread) - def download_finished(self, extract_dir) -> None: self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir)) self.download_button.setEnabled(True) @@ -1313,11 +1241,6 @@ def download_error(self, error_message) -> None: if os.path.exists(partial_file): os.remove(partial_file) - def create_label(self, text, tooltip) -> QLabel: - label = QLabel(text) - label.setToolTip(tooltip) - return label - def verify_gguf(self, file_path) -> bool: try: with open(file_path, "rb") as f: @@ -1613,15 +1536,6 @@ def quantize_model(self) -> None: except Exception as e: show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e))) - def parse_progress(self, line, task_item) -> None: - # Parses the output line for progress information and updates the task item. - match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line) - if match: - current = int(match.group(1)) - total = int(match.group(2)) - progress = int((current / total) * 100) - task_item.update_progress(progress) - def task_finished(self, thread, task_item) -> None: self.logger.info(TASK_FINISHED.format(thread.log_file)) if thread in self.quant_threads: @@ -1681,28 +1595,6 @@ def import_model(self) -> None: self.load_models() self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name)) - def browse_imatrix_datafile(self) -> None: - self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE) - datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES) - if datafile: - self.imatrix_datafile.setText(os.path.abspath(datafile)) - - def browse_imatrix_model(self) -> None: - self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE) - model_file, _ = QFileDialog.getOpenFileName( - self, SELECT_MODEL_FILE, "", GGUF_FILES - ) - if model_file: - self.imatrix_model.setText(os.path.abspath(model_file)) - - def browse_imatrix_output(self) -> None: - self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE) - output_file, _ = QFileDialog.getSaveFileName( - self, SELECT_OUTPUT_FILE, "", DAT_FILES - ) - if output_file: - self.imatrix_output.setText(os.path.abspath(output_file)) - def generate_imatrix(self) -> None: self.logger.info(STARTING_IMATRIX_GENERATION) try: diff --git a/src/QuantizationThread.py b/src/QuantizationThread.py index bbc4540..d196936 100644 --- a/src/QuantizationThread.py +++ b/src/QuantizationThread.py @@ -1,4 +1,5 @@ import os +import re import signal import subprocess @@ -78,6 +79,15 @@ def parse_model_info(self, line) -> None: f"{quant_type}: {tensors} tensors" ) + def parse_progress(self, line, task_item) -> None: + # Parses the output line for progress information and updates the task item. + match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line) + if match: + current = int(match.group(1)) + total = int(match.group(2)) + progress = int((current / total) * 100) + task_item.update_progress(progress) + def terminate(self) -> None: # Terminate the subprocess if it's still running if self.process: diff --git a/src/TaskListItem.py b/src/TaskListItem.py index 28f567b..2b4bd1d 100644 --- a/src/TaskListItem.py +++ b/src/TaskListItem.py @@ -162,3 +162,23 @@ def update_progress(self, value=None) -> None: else: # Set progress bar to zero for indeterminate progress self.progress_bar.setValue(0) + + def restart_task(self, task_item) -> None: + self.logger.info(RESTARTING_TASK.format(task_item.task_name)) + for thread in self.quant_threads: + if thread.log_file == task_item.log_file: + new_thread = QuantizationThread( + thread.command, thread.cwd, thread.log_file + ) + self.quant_threads.append(new_thread) + new_thread.status_signal.connect(task_item.update_status) + new_thread.finished_signal.connect( + lambda: self.task_finished(new_thread, task_item) + ) + new_thread.error_signal.connect( + lambda err: handle_error(self.logger, err, task_item) + ) + new_thread.model_info_signal.connect(self.update_model_info) + new_thread.start() + task_item.update_status(IN_PROGRESS) + break diff --git a/src/imports_and_globals.py b/src/imports_and_globals.py index f99e8e0..b5177be 100644 --- a/src/imports_and_globals.py +++ b/src/imports_and_globals.py @@ -1,3 +1,5 @@ +import os +import re import sys from typing import TextIO, Union @@ -5,7 +7,48 @@ QMessageBox, ) -from Localizations import * +from Localizations import ( + DOTENV_FILE_NOT_FOUND, + COULD_NOT_PARSE_LINE, + ERROR_LOADING_DOTENV, + AUTOGGUF_VERSION, +) + + +def load_dotenv(self) -> None: + if not os.path.isfile(".env"): + self.logger.warning(DOTENV_FILE_NOT_FOUND) + return + + try: + with open(".env") as f: + for line in f: + # Strip leading/trailing whitespace + line = line.strip() + + # Ignore comments and empty lines + if not line or line.startswith("#"): + continue + + # Match key-value pairs (unquoted and quoted values) + match = re.match(r"^([^=]+)=(.*)$", line) + if not match: + self.logger.warning(COULD_NOT_PARSE_LINE.format(line)) + continue + + key, value = match.groups() + + # Remove any surrounding quotes from the value + if value.startswith(("'", '"')) and value.endswith(("'", '"')): + value = value[1:-1] + + # Decode escape sequences + value = bytes(value, "utf-8").decode("unicode_escape") + + # Set the environment variable + os.environ[key.strip()] = value.strip() + except Exception as e: + self.logger.error(ERROR_LOADING_DOTENV.format(e)) def show_about(self) -> None: diff --git a/src/lora_conversion.py b/src/lora_conversion.py index 0691e17..4652548 100644 --- a/src/lora_conversion.py +++ b/src/lora_conversion.py @@ -98,6 +98,12 @@ def export_lora(self) -> None: show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e))) +def lora_conversion_finished(self, thread) -> None: + self.logger.info(LORA_CONVERSION_FINISHED) + if thread in self.quant_threads: + self.quant_threads.remove(thread) + + def delete_lora_adapter_item(self, adapter_widget) -> None: self.logger.info(DELETING_LORA_ADAPTER) # Find the QListWidgetItem containing the adapter_widget diff --git a/src/ui_update.py b/src/ui_update.py index bb997e2..aba9938 100644 --- a/src/ui_update.py +++ b/src/ui_update.py @@ -1,11 +1,62 @@ from PySide6.QtCore import QTimer from PySide6.QtGui import Qt +from PySide6.QtWidgets import QFileDialog, QLabel from Localizations import * import psutil from error_handling import show_error +def browse_base_model(self) -> None: + self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message + base_model_folder = QFileDialog.getExistingDirectory(self, SELECT_BASE_MODEL_FOLDER) + if base_model_folder: + self.base_model_path.setText(os.path.abspath(base_model_folder)) + + +def browse_hf_model_input(self) -> None: + self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY) + model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY) + if model_dir: + self.hf_model_input.setText(os.path.abspath(model_dir)) + + +def browse_hf_outfile(self) -> None: + self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT) + outfile, _ = QFileDialog.getSaveFileName(self, SELECT_OUTPUT_FILE, "", GGUF_FILES) + if outfile: + self.hf_outfile.setText(os.path.abspath(outfile)) + + +def browse_imatrix_datafile(self) -> None: + self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE) + datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES) + if datafile: + self.imatrix_datafile.setText(os.path.abspath(datafile)) + + +def browse_imatrix_model(self) -> None: + self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE) + model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES) + if model_file: + self.imatrix_model.setText(os.path.abspath(model_file)) + + +def browse_imatrix_output(self) -> None: + self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE) + output_file, _ = QFileDialog.getSaveFileName( + self, SELECT_OUTPUT_FILE, "", DAT_FILES + ) + if output_file: + self.imatrix_output.setText(os.path.abspath(output_file)) + + +def create_label(self, text, tooltip) -> QLabel: + label = QLabel(text) + label.setToolTip(tooltip) + return label + + def toggle_gpu_offload_auto(self, state) -> None: is_auto = state == Qt.CheckState.Checked self.gpu_offload_slider.setEnabled(not is_auto)