refactor: move functions out of AutoGGUF.py

- relocate functions into utils.py and TaskListItem.py
This commit is contained in:
BuildTools 2024-08-29 15:01:13 -07:00
parent e307a4d3b5
commit 6583412b76
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
3 changed files with 97 additions and 80 deletions

View File

@ -1,36 +1,34 @@
import importlib
import json import json
import re import re
import shutil import shutil
import importlib
from functools import partial
from datetime import datetime from datetime import datetime
from typing import Tuple, Dict, List, Any from functools import partial
from dotenv import load_dotenv from typing import Any, Dict, List, Tuple
import requests
from PySide6.QtCore import * from PySide6.QtCore import *
from PySide6.QtGui import * from PySide6.QtGui import *
from PySide6.QtWidgets import * from PySide6.QtWidgets import *
from dotenv import load_dotenv
from GPUMonitor import GPUMonitor import lora_conversion
from KVOverrideEntry import KVOverrideEntry import presets
from Logger import Logger import ui_update
from ModelInfoDialog import ModelInfoDialog import utils
from CustomTitleBar import CustomTitleBar from CustomTitleBar import CustomTitleBar
from error_handling import show_error, handle_error from GPUMonitor import GPUMonitor
from TaskListItem import TaskListItem from Localizations import *
from Logger import Logger
from QuantizationThread import QuantizationThread from QuantizationThread import QuantizationThread
from TaskListItem import TaskListItem
from error_handling import handle_error, show_error
from imports_and_globals import ( from imports_and_globals import (
ensure_directory,
open_file_safe, open_file_safe,
resource_path, resource_path,
show_about, show_about,
ensure_directory,
) )
from Localizations import *
import presets
import ui_update
import lora_conversion
import utils
import requests
class AutoGGUF(QMainWindow): class AutoGGUF(QMainWindow):
@ -93,8 +91,13 @@ def __init__(self, args: List[str]) -> None:
self.browse_imatrix = utils.browse_imatrix.__get__(self) self.browse_imatrix = utils.browse_imatrix.__get__(self)
self.get_models_data = utils.get_models_data.__get__(self) self.get_models_data = utils.get_models_data.__get__(self)
self.get_tasks_data = utils.get_tasks_data.__get__(self) self.get_tasks_data = utils.get_tasks_data.__get__(self)
self.add_kv_override = partial(utils.add_kv_override, self)
self.remove_kv_override = partial(utils.remove_kv_override, self)
self.cancel_task = partial(TaskListItem.cancel_task, self) self.cancel_task = partial(TaskListItem.cancel_task, self)
self.delete_task = partial(TaskListItem.delete_task, self) self.delete_task = partial(TaskListItem.delete_task, self)
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self) self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self) self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
self.update_threads_slider = partial(ui_update.update_threads_slider, self) self.update_threads_slider = partial(ui_update.update_threads_slider, self)
@ -1171,51 +1174,6 @@ def download_error(self, error_message) -> None:
if os.path.exists(partial_file): if os.path.exists(partial_file):
os.remove(partial_file) os.remove(partial_file)
def show_task_context_menu(self, position) -> None:
self.logger.debug(SHOWING_TASK_CONTEXT_MENU)
item = self.task_list.itemAt(position)
if item is not None:
context_menu = QMenu(self)
properties_action = QAction(PROPERTIES, self)
properties_action.triggered.connect(lambda: self.show_task_properties(item))
context_menu.addAction(properties_action)
task_item = self.task_list.itemWidget(item)
if task_item.status != COMPLETED:
cancel_action = QAction(CANCEL, self)
cancel_action.triggered.connect(lambda: self.cancel_task(item))
context_menu.addAction(cancel_action)
if task_item.status == CANCELED:
restart_action = QAction(RESTART, self)
restart_action.triggered.connect(lambda: self.restart_task(task_item))
context_menu.addAction(restart_action)
delete_action = QAction(DELETE, self)
delete_action.triggered.connect(lambda: self.delete_task(item))
context_menu.addAction(delete_action)
context_menu.exec(self.task_list.viewport().mapToGlobal(position))
def show_task_properties(self, item) -> None:
self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
model_info_dialog = ModelInfoDialog(thread.model_info, self)
model_info_dialog.exec()
break
def cancel_task_by_item(self, item) -> None:
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
thread.terminate()
task_item.update_status(CANCELED)
self.quant_threads.remove(thread)
break
def create_label(self, text, tooltip) -> QLabel: def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text) label = QLabel(text)
label.setToolTip(tooltip) label.setToolTip(tooltip)
@ -1337,23 +1295,6 @@ def validate_quantization_inputs(self) -> None:
if errors: if errors:
raise ValueError("\n".join(errors)) raise ValueError("\n".join(errors))
def add_kv_override(self, override_string=None) -> None:
entry = KVOverrideEntry()
entry.deleted.connect(self.remove_kv_override)
if override_string:
key, value = override_string.split("=")
type_, val = value.split(":")
entry.key_input.setText(key)
entry.type_combo.setCurrentText(type_)
entry.value_input.setText(val)
self.kv_override_layout.addWidget(entry)
self.kv_override_entries.append(entry)
def remove_kv_override(self, entry) -> None:
self.kv_override_layout.removeWidget(entry)
self.kv_override_entries.remove(entry)
entry.deleteLater()
def quantize_model(self) -> None: def quantize_model(self) -> None:
self.logger.info(STARTING_MODEL_QUANTIZATION) self.logger.info(STARTING_MODEL_QUANTIZATION)
try: try:

View File

@ -1,4 +1,5 @@
from PySide6.QtCore import * from PySide6.QtCore import *
from PySide6.QtGui import QAction
from PySide6.QtWidgets import * from PySide6.QtWidgets import *
from Localizations import ( from Localizations import (
@ -6,7 +7,16 @@
CANCELLING_TASK, CANCELLING_TASK,
CONFIRM_DELETION_TITLE, CONFIRM_DELETION_TITLE,
CONFIRM_DELETION, CONFIRM_DELETION,
SHOWING_TASK_CONTEXT_MENU,
CANCELED,
CANCEL,
PROPERTIES,
COMPLETED,
SHOWING_PROPERTIES_FOR_TASK,
DELETE,
RESTART,
) )
from ModelInfoDialog import ModelInfoDialog
class TaskListItem(QWidget): class TaskListItem(QWidget):
@ -37,6 +47,52 @@ def __init__(
self.progress_timer.timeout.connect(self.update_progress) self.progress_timer.timeout.connect(self.update_progress)
self.progress_value = 0 self.progress_value = 0
def show_task_context_menu(self, position) -> None:
self.logger.debug(SHOWING_TASK_CONTEXT_MENU)
item = self.task_list.itemAt(position)
if item is not None:
context_menu = QMenu(self)
properties_action = QAction(PROPERTIES, self)
properties_action.triggered.connect(lambda: self.show_task_properties(item))
context_menu.addAction(properties_action)
task_item = self.task_list.itemWidget(item)
if task_item.status != COMPLETED:
cancel_action = QAction(CANCEL, self)
cancel_action.triggered.connect(lambda: self.cancel_task(item))
context_menu.addAction(cancel_action)
if task_item.status == CANCELED:
restart_action = QAction(RESTART, self)
restart_action.triggered.connect(lambda: self.restart_task(task_item))
context_menu.addAction(restart_action)
delete_action = QAction(DELETE, self)
delete_action.triggered.connect(lambda: self.delete_task(item))
context_menu.addAction(delete_action)
context_menu.exec(self.task_list.viewport().mapToGlobal(position))
def show_task_properties(self, item) -> None:
self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
model_info_dialog = ModelInfoDialog(thread.model_info, self)
model_info_dialog.exec()
break
def cancel_task_by_item(self, item) -> None:
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
thread.terminate()
task_item.update_status(CANCELED)
self.quant_threads.remove(thread)
break
def cancel_task(self, item) -> None: def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text())) self.logger.info(CANCELLING_TASK.format(item.text()))
self.cancel_task_by_item(item) self.cancel_task_by_item(item)

View File

@ -8,6 +8,26 @@
from Localizations import * from Localizations import *
from error_handling import show_error from error_handling import show_error
from imports_and_globals import ensure_directory from imports_and_globals import ensure_directory
from KVOverrideEntry import KVOverrideEntry
def add_kv_override(self, override_string=None) -> None:
entry = KVOverrideEntry()
entry.deleted.connect(self.remove_kv_override)
if override_string:
key, value = override_string.split("=")
type_, val = value.split(":")
entry.key_input.setText(key)
entry.type_combo.setCurrentText(type_)
entry.value_input.setText(val)
self.kv_override_layout.addWidget(entry)
self.kv_override_entries.append(entry)
def remove_kv_override(self, entry) -> None:
self.kv_override_layout.removeWidget(entry)
self.kv_override_entries.remove(entry)
entry.deleteLater()
def get_models_data(self) -> list[dict[str, Union[str, Any]]]: def get_models_data(self) -> list[dict[str, Union[str, Any]]]: