refactor: move functions into classes

- move functions into existing classes and files
- move AutoFP8 dialog out of a function and into __init__
This commit is contained in:
BuildTools 2024-09-09 15:11:03 -07:00
parent e46c6260ee
commit be38e35d99
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
6 changed files with 195 additions and 173 deletions

View File

@ -28,6 +28,7 @@
open_file_safe,
resource_path,
show_about,
load_dotenv,
)
@ -48,7 +49,7 @@ def __init__(self, args: List[str]) -> None:
self.setGeometry(100, 100, width, height)
self.setWindowFlag(Qt.FramelessWindowHint)
self.load_dotenv() # Loads the .env file
load_dotenv(self) # Loads the .env file
# Configuration
self.model_dir_name = os.environ.get("AUTOGGUF_MODEL_DIR_NAME", "models")
@ -117,6 +118,18 @@ def __init__(self, args: List[str]) -> None:
self.delete_lora_adapter_item = partial(
lora_conversion.delete_lora_adapter_item, self
)
self.lora_conversion_finished = partial(
lora_conversion.lora_conversion_finished, self
)
self.parse_progress = partial(QuantizationThread.parse_progress, self)
self.create_label = partial(ui_update.create_label, self)
self.browse_imatrix_datafile = ui_update.browse_imatrix_datafile.__get__(self)
self.browse_imatrix_model = ui_update.browse_imatrix_model.__get__(self)
self.browse_imatrix_output = ui_update.browse_imatrix_output.__get__(self)
self.restart_task = partial(TaskListItem.restart_task, self)
self.browse_hf_outfile = ui_update.browse_hf_outfile.__get__(self)
self.browse_hf_model_input = ui_update.browse_hf_model_input.__get__(self)
self.browse_base_model = ui_update.browse_base_model.__get__(self)
# Set up main widget and layout
main_widget = QWidget()
@ -154,11 +167,56 @@ def __init__(self, args: List[str]) -> None:
about_action.triggered.connect(self.show_about)
help_menu.addAction(about_action)
# AutoFP8 Window
self.fp8_dialog = QDialog(self)
self.fp8_dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC)
self.fp8_dialog.setFixedWidth(500)
self.fp8_layout = QVBoxLayout()
# Input path
input_layout = QHBoxLayout()
self.fp8_input = QLineEdit()
input_button = QPushButton(BROWSE)
input_button.clicked.connect(
lambda: self.fp8_input.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
input_layout.addWidget(QLabel(INPUT_MODEL))
input_layout.addWidget(self.fp8_input)
input_layout.addWidget(input_button)
self.fp8_layout.addLayout(input_layout)
# Output path
output_layout = QHBoxLayout()
self.fp8_output = QLineEdit()
output_button = QPushButton(BROWSE)
output_button.clicked.connect(
lambda: self.fp8_output.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
output_layout.addWidget(QLabel(OUTPUT))
output_layout.addWidget(self.fp8_output)
output_layout.addWidget(output_button)
self.fp8_layout.addLayout(output_layout)
# Quantize button
quantize_button = QPushButton(QUANTIZE)
quantize_button.clicked.connect(
lambda: self.quantize_to_fp8_dynamic(
self.fp8_input.text(), self.fp8_output.text()
)
)
self.fp8_layout.addWidget(quantize_button)
self.fp8_dialog.setLayout(self.fp8_layout)
# Tools menu
tools_menu = self.menubar.addMenu("&Tools")
autofp8_action = QAction("&AutoFP8", self)
autofp8_action.setShortcut(QKeySequence("Shift+Q"))
autofp8_action.triggered.connect(self.show_autofp8_window)
autofp8_action.triggered.connect(self.fp8_dialog.exec)
tools_menu.addAction(autofp8_action)
# Content widget
@ -744,18 +802,16 @@ def __init__(self, args: List[str]) -> None:
self.hf_no_lazy = QCheckBox(NO_LAZY_EVALUATION)
hf_to_gguf_layout.addRow(self.hf_no_lazy)
self.hf_verbose = QCheckBox(VERBOSE)
hf_to_gguf_layout.addRow(self.hf_verbose)
self.hf_dry_run = QCheckBox(DRY_RUN)
hf_to_gguf_layout.addRow(self.hf_dry_run)
self.hf_model_name = QLineEdit()
hf_to_gguf_layout.addRow(MODEL_NAME, self.hf_model_name)
self.hf_verbose = QCheckBox(VERBOSE)
hf_to_gguf_layout.addRow(self.hf_verbose)
self.hf_split_max_size = QLineEdit()
hf_to_gguf_layout.addRow(SPLIT_MAX_SIZE, self.hf_split_max_size)
self.hf_dry_run = QCheckBox(DRY_RUN)
hf_to_gguf_layout.addRow(self.hf_dry_run)
hf_to_gguf_convert_button = QPushButton(CONVERT_HF_TO_GGUF)
hf_to_gguf_convert_button.clicked.connect(self.convert_hf_to_gguf)
hf_to_gguf_layout.addRow(hf_to_gguf_convert_button)
@ -812,41 +868,6 @@ def __init__(self, args: List[str]) -> None:
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed()))
def load_dotenv(self):
if not os.path.isfile(".env"):
self.logger.warning(DOTENV_FILE_NOT_FOUND)
return
try:
with open(".env") as f:
for line in f:
# Strip leading/trailing whitespace
line = line.strip()
# Ignore comments and empty lines
if not line or line.startswith("#"):
continue
# Match key-value pairs (unquoted and quoted values)
match = re.match(r"^([^=]+)=(.*)$", line)
if not match:
self.logger.warning(COULD_NOT_PARSE_LINE.format(line))
continue
key, value = match.groups()
# Remove any surrounding quotes from the value
if value.startswith(("'", '"')) and value.endswith(("'", '"')):
value = value[1:-1]
# Decode escape sequences
value = bytes(value, "utf-8").decode("unicode_escape")
# Set the environment variable
os.environ[key.strip()] = value.strip()
except Exception as e:
self.logger.error(ERROR_LOADING_DOTENV.format(e))
def load_plugins(self) -> Dict[str, Dict[str, Any]]:
plugins = {}
plugin_dir = "plugins"
@ -1038,28 +1059,6 @@ def save_task_preset(self, task_item) -> None:
)
break
def browse_base_model(self) -> None:
self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message
base_model_folder = QFileDialog.getExistingDirectory(
self, SELECT_BASE_MODEL_FOLDER
)
if base_model_folder:
self.base_model_path.setText(os.path.abspath(base_model_folder))
def browse_hf_model_input(self) -> None:
self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir))
def browse_hf_outfile(self) -> None:
self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT)
outfile, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES
)
if outfile:
self.hf_outfile.setText(os.path.abspath(outfile))
def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
self.logger.info(
QUANTIZING_TO_WITH_AUTOFP8.format(os.path.basename(model_dir), output_dir)
@ -1107,52 +1106,6 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
show_error(self.logger, f"{ERROR_STARTING_AUTOFP8_QUANTIZATION}: {e}")
self.logger.info(AUTOFP8_QUANTIZATION_TASK_STARTED)
def show_autofp8_window(self):
dialog = QDialog(self)
dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC)
dialog.setFixedWidth(500)
layout = QVBoxLayout()
# Input path
input_layout = QHBoxLayout()
self.fp8_input = QLineEdit()
input_button = QPushButton(BROWSE)
input_button.clicked.connect(
lambda: self.fp8_input.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
input_layout.addWidget(QLabel(INPUT_MODEL))
input_layout.addWidget(self.fp8_input)
input_layout.addWidget(input_button)
layout.addLayout(input_layout)
# Output path
output_layout = QHBoxLayout()
self.fp8_output = QLineEdit()
output_button = QPushButton(BROWSE)
output_button.clicked.connect(
lambda: self.fp8_output.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
output_layout.addWidget(QLabel(OUTPUT))
output_layout.addWidget(self.fp8_output)
output_layout.addWidget(output_button)
layout.addLayout(output_layout)
# Quantize button
quantize_button = QPushButton(QUANTIZE)
quantize_button.clicked.connect(
lambda: self.quantize_to_fp8_dynamic(
self.fp8_input.text(), self.fp8_output.text()
)
)
layout.addWidget(quantize_button)
dialog.setLayout(layout)
dialog.exec()
def convert_hf_to_gguf(self) -> None:
self.logger.info(STARTING_HF_TO_GGUF_CONVERSION)
try:
@ -1229,31 +1182,6 @@ def convert_hf_to_gguf(self) -> None:
show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED)
def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
new_thread = QuantizationThread(
thread.command, thread.cwd, thread.log_file
)
self.quant_threads.append(new_thread)
new_thread.status_signal.connect(task_item.update_status)
new_thread.finished_signal.connect(
lambda: self.task_finished(new_thread, task_item)
)
new_thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
new_thread.model_info_signal.connect(self.update_model_info)
new_thread.start()
task_item.update_status(IN_PROGRESS)
break
def lora_conversion_finished(self, thread) -> None:
self.logger.info(LORA_CONVERSION_FINISHED)
if thread in self.quant_threads:
self.quant_threads.remove(thread)
def download_finished(self, extract_dir) -> None:
self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir))
self.download_button.setEnabled(True)
@ -1313,11 +1241,6 @@ def download_error(self, error_message) -> None:
if os.path.exists(partial_file):
os.remove(partial_file)
def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text)
label.setToolTip(tooltip)
return label
def verify_gguf(self, file_path) -> bool:
try:
with open(file_path, "rb") as f:
@ -1613,15 +1536,6 @@ def quantize_model(self) -> None:
except Exception as e:
show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e)))
def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line)
if match:
current = int(match.group(1))
total = int(match.group(2))
progress = int((current / total) * 100)
task_item.update_progress(progress)
def task_finished(self, thread, task_item) -> None:
self.logger.info(TASK_FINISHED.format(thread.log_file))
if thread in self.quant_threads:
@ -1681,28 +1595,6 @@ def import_model(self) -> None:
self.load_models()
self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))
def browse_imatrix_datafile(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
if datafile:
self.imatrix_datafile.setText(os.path.abspath(datafile))
def browse_imatrix_model(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_FILE, "", GGUF_FILES
)
if model_file:
self.imatrix_model.setText(os.path.abspath(model_file))
def browse_imatrix_output(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", DAT_FILES
)
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))
def generate_imatrix(self) -> None:
self.logger.info(STARTING_IMATRIX_GENERATION)
try:

View File

@ -1,4 +1,5 @@
import os
import re
import signal
import subprocess
@ -78,6 +79,15 @@ def parse_model_info(self, line) -> None:
f"{quant_type}: {tensors} tensors"
)
def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line)
if match:
current = int(match.group(1))
total = int(match.group(2))
progress = int((current / total) * 100)
task_item.update_progress(progress)
def terminate(self) -> None:
# Terminate the subprocess if it's still running
if self.process:

View File

@ -162,3 +162,23 @@ def update_progress(self, value=None) -> None:
else:
# Set progress bar to zero for indeterminate progress
self.progress_bar.setValue(0)
def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
new_thread = QuantizationThread(
thread.command, thread.cwd, thread.log_file
)
self.quant_threads.append(new_thread)
new_thread.status_signal.connect(task_item.update_status)
new_thread.finished_signal.connect(
lambda: self.task_finished(new_thread, task_item)
)
new_thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
new_thread.model_info_signal.connect(self.update_model_info)
new_thread.start()
task_item.update_status(IN_PROGRESS)
break

View File

@ -1,3 +1,5 @@
import os
import re
import sys
from typing import TextIO, Union
@ -5,7 +7,48 @@
QMessageBox,
)
from Localizations import *
from Localizations import (
DOTENV_FILE_NOT_FOUND,
COULD_NOT_PARSE_LINE,
ERROR_LOADING_DOTENV,
AUTOGGUF_VERSION,
)
def load_dotenv(self) -> None:
if not os.path.isfile(".env"):
self.logger.warning(DOTENV_FILE_NOT_FOUND)
return
try:
with open(".env") as f:
for line in f:
# Strip leading/trailing whitespace
line = line.strip()
# Ignore comments and empty lines
if not line or line.startswith("#"):
continue
# Match key-value pairs (unquoted and quoted values)
match = re.match(r"^([^=]+)=(.*)$", line)
if not match:
self.logger.warning(COULD_NOT_PARSE_LINE.format(line))
continue
key, value = match.groups()
# Remove any surrounding quotes from the value
if value.startswith(("'", '"')) and value.endswith(("'", '"')):
value = value[1:-1]
# Decode escape sequences
value = bytes(value, "utf-8").decode("unicode_escape")
# Set the environment variable
os.environ[key.strip()] = value.strip()
except Exception as e:
self.logger.error(ERROR_LOADING_DOTENV.format(e))
def show_about(self) -> None:

View File

@ -98,6 +98,12 @@ def export_lora(self) -> None:
show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e)))
def lora_conversion_finished(self, thread) -> None:
self.logger.info(LORA_CONVERSION_FINISHED)
if thread in self.quant_threads:
self.quant_threads.remove(thread)
def delete_lora_adapter_item(self, adapter_widget) -> None:
self.logger.info(DELETING_LORA_ADAPTER)
# Find the QListWidgetItem containing the adapter_widget

View File

@ -1,11 +1,62 @@
from PySide6.QtCore import QTimer
from PySide6.QtGui import Qt
from PySide6.QtWidgets import QFileDialog, QLabel
from Localizations import *
import psutil
from error_handling import show_error
def browse_base_model(self) -> None:
self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message
base_model_folder = QFileDialog.getExistingDirectory(self, SELECT_BASE_MODEL_FOLDER)
if base_model_folder:
self.base_model_path.setText(os.path.abspath(base_model_folder))
def browse_hf_model_input(self) -> None:
self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir))
def browse_hf_outfile(self) -> None:
self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT)
outfile, _ = QFileDialog.getSaveFileName(self, SELECT_OUTPUT_FILE, "", GGUF_FILES)
if outfile:
self.hf_outfile.setText(os.path.abspath(outfile))
def browse_imatrix_datafile(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
if datafile:
self.imatrix_datafile.setText(os.path.abspath(datafile))
def browse_imatrix_model(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES)
if model_file:
self.imatrix_model.setText(os.path.abspath(model_file))
def browse_imatrix_output(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", DAT_FILES
)
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))
def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text)
label.setToolTip(tooltip)
return label
def toggle_gpu_offload_auto(self, state) -> None:
is_auto = state == Qt.CheckState.Checked
self.gpu_offload_slider.setEnabled(not is_auto)