refactor: move get helper functions to utils.py

- move get_models_data and get_tasks_data to utils.py from AutoGGUF.py
This commit is contained in:
BuildTools 2024-08-22 17:08:45 -07:00
parent 4f2c8057e1
commit a97a545a28
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
2 changed files with 37 additions and 32 deletions

View File

@ -80,6 +80,8 @@ def __init__(self, args):
self.browse_output = utils.browse_output.__get__(self)
self.browse_logs = utils.browse_logs.__get__(self)
self.browse_imatrix = utils.browse_imatrix.__get__(self)
self.get_models_data = utils.get_models_data.__get__(self)
self.get_tasks_data = utils.get_tasks_data.__get__(self)
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
self.update_gpu_offload_spinbox = partial(
@ -1549,38 +1551,6 @@ def browse_imatrix_output(self):
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))
def get_models_data(self):
models = []
root = self.model_tree.invisibleRootItem()
child_count = root.childCount()
for i in range(child_count):
item = root.child(i)
model_name = item.text(0)
model_type = "sharded" if "sharded" in model_name.lower() else "single"
model_path = item.data(0, Qt.ItemDataRole.UserRole)
models.append({"name": model_name, "type": model_type, "path": model_path})
return models
def get_tasks_data(self):
tasks = []
for i in range(self.task_list.count()):
item = self.task_list.item(i)
task_widget = self.task_list.itemWidget(item)
if task_widget:
tasks.append(
{
"name": task_widget.task_name,
"status": task_widget.status,
"progress": (
task_widget.progress_bar.value()
if hasattr(task_widget, "progress_bar")
else 0
),
"log_file": task_widget.log_file,
}
)
return tasks
def generate_imatrix(self):
self.logger.info(STARTING_IMATRIX_GENERATION)
try:

View File

@ -1,3 +1,4 @@
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QFileDialog
from error_handling import show_error
@ -8,6 +9,40 @@
from imports_and_globals import ensure_directory
def get_models_data(self):
models = []
root = self.model_tree.invisibleRootItem()
child_count = root.childCount()
for i in range(child_count):
item = root.child(i)
model_name = item.text(0)
model_type = "sharded" if "sharded" in model_name.lower() else "single"
model_path = item.data(0, Qt.ItemDataRole.UserRole)
models.append({"name": model_name, "type": model_type, "path": model_path})
return models
def get_tasks_data(self):
tasks = []
for i in range(self.task_list.count()):
item = self.task_list.item(i)
task_widget = self.task_list.itemWidget(item)
if task_widget:
tasks.append(
{
"name": task_widget.task_name,
"status": task_widget.status,
"progress": (
task_widget.progress_bar.value()
if hasattr(task_widget, "progress_bar")
else 0
),
"log_file": task_widget.log_file,
}
)
return tasks
def browse_models(self):
self.logger.info(BROWSING_FOR_MODELS_DIRECTORY)
models_path = QFileDialog.getExistingDirectory(self, SELECT_MODELS_DIRECTORY)