From 6e424462ab57072d0045c0c1cadf9f1c65f4336f Mon Sep 17 00:00:00 2001 From: BuildTools Date: Thu, 22 Aug 2024 21:56:37 -0700 Subject: [PATCH] refactor: add type hints --- src/AutoGGUF.py | 90 +++++++++++++++++++------------------- src/CustomTitleBar.py | 11 ++--- src/DownloadThread.py | 4 +- src/GPUMonitor.py | 22 +++++----- src/KVOverrideEntry.py | 10 ++--- src/Localizations.py | 2 +- src/Logger.py | 12 ++--- src/ModelInfoDialog.py | 4 +- src/QuantizationThread.py | 8 ++-- src/TaskListItem.py | 8 ++-- src/error_handling.py | 6 +-- src/imports_and_globals.py | 10 +++-- src/lora_conversion.py | 12 ++--- src/main.py | 16 +++---- src/presets.py | 4 +- src/ui_update.py | 26 +++++------ src/utils.py | 22 +++++----- 17 files changed, 135 insertions(+), 132 deletions(-) diff --git a/src/AutoGGUF.py b/src/AutoGGUF.py index 8382e0c..c377996 100644 --- a/src/AutoGGUF.py +++ b/src/AutoGGUF.py @@ -5,6 +5,7 @@ from functools import partial from datetime import datetime +from typing import Tuple, Dict from dotenv import load_dotenv from PySide6.QtCore import * from PySide6.QtGui import * @@ -33,7 +34,8 @@ class AutoGGUF(QMainWindow): - def __init__(self, args): + + def __init__(self, args: List[str]) -> None: super().__init__() self.logger = Logger("AutoGGUF", "logs") @@ -785,7 +787,7 @@ def __init__(self, args): self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE) - def load_plugins(self): + def load_plugins(self) -> Dict[str, Dict[str, Any]]: plugins = {} plugin_dir = "plugins" @@ -844,7 +846,7 @@ def load_plugins(self): return plugins - def apply_plugins(self): + def apply_plugins(self) -> None: if not self.plugins: self.logger.info(NO_PLUGINS_LOADED) return @@ -859,7 +861,7 @@ def apply_plugins(self): if hasattr(plugin_instance, "init") and callable(plugin_instance.init): plugin_instance.init(self) - def check_for_updates(self): + def check_for_updates(self) -> None: try: response = requests.get( "https://api.github.com/repos/leafspark/AutoGGUF/releases/latest" @@ -874,7 +876,7 @@ def check_for_updates(self): except requests.exceptions.RequestException as e: self.logger.warning(f"{ERROR_CHECKING_FOR_UPDATES} {e}") - def prompt_for_update(self, release): + def prompt_for_update(self, release) -> None: update_message = QMessageBox() update_message.setIcon(QMessageBox.Information) update_message.setWindowTitle(UPDATE_AVAILABLE) @@ -887,7 +889,7 @@ def prompt_for_update(self, release): if update_message.exec() == QMessageBox.StandardButton.Yes: QDesktopServices.openUrl(QUrl(release["html_url"])) - def keyPressEvent(self, event): + def keyPressEvent(self, event) -> None: if event.modifiers() == Qt.ControlModifier: if ( event.key() == Qt.Key_Equal @@ -899,7 +901,7 @@ def keyPressEvent(self, event): self.reset_size() super().keyPressEvent(event) - def resize_window(self, larger): + def resize_window(self, larger) -> None: factor = 1.1 if larger else 1 / 1.1 current_width = self.width() current_height = self.height() @@ -907,10 +909,10 @@ def resize_window(self, larger): new_height = int(current_height * factor) self.resize(new_width, new_height) - def reset_size(self): + def reset_size(self) -> None: self.resize(self.default_width, self.default_height) - def parse_resolution(self): + def parse_resolution(self) -> Tuple[int, int]: res = os.environ.get("AUTOGGUF_RESOLUTION", "1650x1100") try: width, height = map(int, res.split("x")) @@ -920,14 +922,14 @@ def parse_resolution(self): except (ValueError, AttributeError): return 1650, 1100 - def resizeEvent(self, event): + def resizeEvent(self, event) -> None: super().resizeEvent(event) path = QPainterPath() path.addRoundedRect(self.rect(), 10, 10) mask = QRegion(path.toFillPolygon().toPolygon()) self.setMask(mask) - def refresh_backends(self): + def refresh_backends(self) -> None: self.logger.info(REFRESHING_BACKENDS) llama_bin = os.path.abspath("llama_bin") os.makedirs(llama_bin, exist_ok=True) @@ -951,7 +953,7 @@ def refresh_backends(self): self.backend_combo.setEnabled(False) self.logger.info(FOUND_VALID_BACKENDS.format(len(valid_backends))) - def save_task_preset(self, task_item): + def save_task_preset(self, task_item) -> None: self.logger.info(SAVING_TASK_PRESET.format(task_item.task_name)) for thread in self.quant_threads: if thread.log_file == task_item.log_file: @@ -971,7 +973,7 @@ def save_task_preset(self, task_item): ) break - def browse_base_model(self): + def browse_base_model(self) -> None: self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message base_model_folder = QFileDialog.getExistingDirectory( self, SELECT_BASE_MODEL_FOLDER @@ -979,13 +981,13 @@ def browse_base_model(self): if base_model_folder: self.base_model_path.setText(os.path.abspath(base_model_folder)) - def browse_hf_model_input(self): + def browse_hf_model_input(self) -> None: self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY) model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY) if model_dir: self.hf_model_input.setText(os.path.abspath(model_dir)) - def browse_hf_outfile(self): + def browse_hf_outfile(self) -> None: self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT) outfile, _ = QFileDialog.getSaveFileName( self, SELECT_OUTPUT_FILE, "", GGUF_FILES @@ -993,7 +995,7 @@ def browse_hf_outfile(self): if outfile: self.hf_outfile.setText(os.path.abspath(outfile)) - def convert_hf_to_gguf(self): + def convert_hf_to_gguf(self) -> None: self.logger.info(STARTING_HF_TO_GGUF_CONVERSION) try: model_dir = self.hf_model_input.text() @@ -1063,7 +1065,7 @@ def convert_hf_to_gguf(self): show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e))) self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED) - def restart_task(self, task_item): + def restart_task(self, task_item) -> None: self.logger.info(RESTARTING_TASK.format(task_item.task_name)) for thread in self.quant_threads: if thread.log_file == task_item.log_file: @@ -1083,7 +1085,7 @@ def restart_task(self, task_item): task_item.update_status(IN_PROGRESS) break - def lora_conversion_finished(self, thread, input_path, output_path): + def lora_conversion_finished(self, thread, input_path, output_path) -> None: self.logger.info(LORA_CONVERSION_FINISHED) if thread in self.quant_threads: self.quant_threads.remove(thread) @@ -1099,7 +1101,7 @@ def lora_conversion_finished(self, thread, input_path, output_path): except Exception as e: self.logger.error(ERROR_MOVING_LORA_FILE.format(str(e))) - def download_finished(self, extract_dir): + def download_finished(self, extract_dir) -> None: self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir)) self.download_button.setEnabled(True) self.download_progress.setValue(100) @@ -1136,7 +1138,7 @@ def download_finished(self, extract_dir): if index >= 0: self.backend_combo.setCurrentIndex(index) - def extract_cuda_files(self, extract_dir, destination): + def extract_cuda_files(self, extract_dir, destination) -> None: self.logger.info(EXTRACTING_CUDA_FILES.format(extract_dir, destination)) for root, dirs, files in os.walk(extract_dir): for file in files: @@ -1145,7 +1147,7 @@ def extract_cuda_files(self, extract_dir, destination): dest_path = os.path.join(destination, file) shutil.copy2(source_path, dest_path) - def download_error(self, error_message): + def download_error(self, error_message) -> None: self.logger.error(DOWNLOAD_ERROR.format(error_message)) self.download_button.setEnabled(True) self.download_progress.setValue(0) @@ -1158,7 +1160,7 @@ def download_error(self, error_message): if os.path.exists(partial_file): os.remove(partial_file) - def show_task_context_menu(self, position): + def show_task_context_menu(self, position) -> None: self.logger.debug(SHOWING_TASK_CONTEXT_MENU) item = self.task_list.itemAt(position) if item is not None: @@ -1185,7 +1187,7 @@ def show_task_context_menu(self, position): context_menu.exec(self.task_list.viewport().mapToGlobal(position)) - def show_task_properties(self, item): + def show_task_properties(self, item) -> None: self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text())) task_item = self.task_list.itemWidget(item) for thread in self.quant_threads: @@ -1194,12 +1196,12 @@ def show_task_properties(self, item): model_info_dialog.exec() break - def toggle_gpu_offload_auto(self, state): + def toggle_gpu_offload_auto(self, state) -> None: is_auto = state == Qt.CheckState.Checked self.gpu_offload_slider.setEnabled(not is_auto) self.gpu_offload_spinbox.setEnabled(not is_auto) - def cancel_task_by_item(self, item): + def cancel_task_by_item(self, item) -> None: task_item = self.task_list.itemWidget(item) for thread in self.quant_threads: if thread.log_file == task_item.log_file: @@ -1208,11 +1210,11 @@ def cancel_task_by_item(self, item): self.quant_threads.remove(thread) break - def cancel_task(self, item): + def cancel_task(self, item) -> None: self.logger.info(CANCELLING_TASK.format(item.text())) self.cancel_task_by_item(item) - def delete_task(self, item): + def delete_task(self, item) -> None: self.logger.info(DELETING_TASK.format(item.text())) # Cancel the task first @@ -1233,12 +1235,12 @@ def delete_task(self, item): if task_item: task_item.deleteLater() - def create_label(self, text, tooltip): + def create_label(self, text, tooltip) -> QLabel: label = QLabel(text) label.setToolTip(tooltip) return label - def verify_gguf(self, file_path): + def verify_gguf(self, file_path) -> bool: try: with open(file_path, "rb") as f: magic = f.read(4) @@ -1246,7 +1248,7 @@ def verify_gguf(self, file_path): except Exception: return False - def load_models(self): + def load_models(self) -> None: self.logger.info(LOADING_MODELS) models_dir = self.models_input.text() ensure_directory(models_dir) @@ -1322,7 +1324,7 @@ def load_models(self): CONCATENATED_FILES_FOUND.format(len(concatenated_models)) ) - def add_model_to_tree(self, model): + def add_model_to_tree(self, model) -> QTreeWidgetItem: item = QTreeWidgetItem(self.model_tree) item.setText(0, model) if hasattr(self, "imported_models") and model in [ @@ -1337,7 +1339,7 @@ def add_model_to_tree(self, model): item.setData(0, Qt.ItemDataRole.UserRole, model) return item - def validate_quantization_inputs(self): + def validate_quantization_inputs(self) -> None: self.logger.debug(VALIDATING_QUANTIZATION_INPUTS) errors = [] if not self.backend_combo.currentData(): @@ -1354,7 +1356,7 @@ def validate_quantization_inputs(self): if errors: raise ValueError("\n".join(errors)) - def add_kv_override(self, override_string=None): + def add_kv_override(self, override_string=None) -> None: entry = KVOverrideEntry() entry.deleted.connect(self.remove_kv_override) if override_string: @@ -1366,12 +1368,12 @@ def add_kv_override(self, override_string=None): self.kv_override_layout.addWidget(entry) self.kv_override_entries.append(entry) - def remove_kv_override(self, entry): + def remove_kv_override(self, entry) -> None: self.kv_override_layout.removeWidget(entry) self.kv_override_entries.remove(entry) entry.deleteLater() - def quantize_model(self): + def quantize_model(self) -> None: self.logger.info(STARTING_MODEL_QUANTIZATION) try: self.validate_quantization_inputs() @@ -1539,7 +1541,7 @@ def quantize_model(self): except Exception as e: show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e))) - def parse_progress(self, line, task_item): + def parse_progress(self, line, task_item) -> None: # Parses the output line for progress information and updates the task item. match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*\].*", line) if match: @@ -1548,13 +1550,13 @@ def parse_progress(self, line, task_item): progress = int((current / total) * 100) task_item.update_progress(progress) - def task_finished(self, thread, task_item): + def task_finished(self, thread, task_item) -> None: self.logger.info(TASK_FINISHED.format(thread.log_file)) if thread in self.quant_threads: self.quant_threads.remove(thread) task_item.update_status(COMPLETED) - def show_task_details(self, item): + def show_task_details(self, item) -> None: self.logger.debug(SHOWING_TASK_DETAILS_FOR.format(item.text())) task_item = self.task_list.itemWidget(item) if task_item: @@ -1582,7 +1584,7 @@ def show_task_details(self, item): log_dialog.exec() - def import_model(self): + def import_model(self) -> None: self.logger.info(IMPORTING_MODEL) file_path, _ = QFileDialog.getOpenFileName( self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES @@ -1609,13 +1611,13 @@ def import_model(self): self.load_models() self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name)) - def browse_imatrix_datafile(self): + def browse_imatrix_datafile(self) -> None: self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE) datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES) if datafile: self.imatrix_datafile.setText(os.path.abspath(datafile)) - def browse_imatrix_model(self): + def browse_imatrix_model(self) -> None: self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE) model_file, _ = QFileDialog.getOpenFileName( self, SELECT_MODEL_FILE, "", GGUF_FILES @@ -1623,7 +1625,7 @@ def browse_imatrix_model(self): if model_file: self.imatrix_model.setText(os.path.abspath(model_file)) - def browse_imatrix_output(self): + def browse_imatrix_output(self) -> None: self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE) output_file, _ = QFileDialog.getSaveFileName( self, SELECT_OUTPUT_FILE, "", DAT_FILES @@ -1631,7 +1633,7 @@ def browse_imatrix_output(self): if output_file: self.imatrix_output.setText(os.path.abspath(output_file)) - def generate_imatrix(self): + def generate_imatrix(self) -> None: self.logger.info(STARTING_IMATRIX_GENERATION) try: backend_path = self.backend_combo.currentData() @@ -1692,7 +1694,7 @@ def generate_imatrix(self): show_error(self.logger, ERROR_STARTING_IMATRIX_GENERATION.format(str(e))) self.logger.info(IMATRIX_GENERATION_TASK_STARTED) - def closeEvent(self, event: QCloseEvent): + def closeEvent(self, event: QCloseEvent) -> None: self.logger.info(APPLICATION_CLOSING) if self.quant_threads: reply = QMessageBox.question( diff --git a/src/CustomTitleBar.py b/src/CustomTitleBar.py index 46398eb..a8b4ac2 100644 --- a/src/CustomTitleBar.py +++ b/src/CustomTitleBar.py @@ -1,12 +1,9 @@ from PySide6.QtCore import QPoint -from PySide6.QtGui import QPixmap from PySide6.QtWidgets import QHBoxLayout, QLabel, QMenuBar, QPushButton, QWidget -from imports_and_globals import resource_path - class CustomTitleBar(QWidget): - def __init__(self, parent=None): + def __init__(self, parent=None) -> None: super().__init__(parent) self.parent = parent layout = QHBoxLayout(self) @@ -55,11 +52,11 @@ def __init__(self, parent=None): self.start = QPoint(0, 0) self.pressing = False - def mousePressEvent(self, event): + def mousePressEvent(self, event) -> None: self.start = self.mapToGlobal(event.pos()) self.pressing = True - def mouseMoveEvent(self, event): + def mouseMoveEvent(self, event) -> None: if self.pressing: end = self.mapToGlobal(event.pos()) movement = end - self.start @@ -71,5 +68,5 @@ def mouseMoveEvent(self, event): ) self.start = end - def mouseReleaseEvent(self, event): + def mouseReleaseEvent(self, event) -> None: self.pressing = False diff --git a/src/DownloadThread.py b/src/DownloadThread.py index d27403a..210a4ac 100644 --- a/src/DownloadThread.py +++ b/src/DownloadThread.py @@ -10,12 +10,12 @@ class DownloadThread(QThread): finished_signal = Signal(str) error_signal = Signal(str) - def __init__(self, url, save_path): + def __init__(self, url, save_path) -> None: super().__init__() self.url = url self.save_path = save_path - def run(self): + def run(self) -> None: try: response = requests.get(self.url, stream=True) response.raise_for_status() diff --git a/src/GPUMonitor.py b/src/GPUMonitor.py index 274cc4d..6046c2f 100644 --- a/src/GPUMonitor.py +++ b/src/GPUMonitor.py @@ -28,7 +28,7 @@ class SimpleGraph(QGraphicsView): - def __init__(self, title, parent=None): + def __init__(self, title, parent=None) -> None: super().__init__(parent) self.setScene(QGraphicsScene(self)) self.setRenderHint(QPainter.RenderHint.Antialiasing) @@ -37,7 +37,7 @@ def __init__(self, title, parent=None): self.title = title self.data = [] - def update_data(self, data): + def update_data(self, data) -> None: self.data = data self.scene().clear() if not self.data: @@ -65,13 +65,13 @@ def update_data(self, data): line.setPen(path) self.scene().addItem(line) - def resizeEvent(self, event): + def resizeEvent(self, event) -> None: super().resizeEvent(event) self.update_data(self.data) class GPUMonitor(QWidget): - def __init__(self, parent=None): + def __init__(self, parent=None) -> None: super().__init__(parent) self.setMinimumHeight(30) self.setMaximumHeight(30) @@ -125,17 +125,17 @@ def __init__(self, parent=None): if not self.handles: self.gpu_label.setText(NO_GPU_DETECTED) - def check_for_amd_gpu(self): + def check_for_amd_gpu(self) -> None: # This is a placeholder. Implementing AMD GPU detection would require # platform-specific methods or additional libraries. self.gpu_label.setText(AMD_GPU_NOT_SUPPORTED) - def change_gpu(self, index): + def change_gpu(self, index) -> None: self.current_gpu = index self.gpu_data.clear() self.vram_data.clear() - def update_gpu_info(self): + def update_gpu_info(self) -> None: if self.handles: try: handle = self.handles[self.current_gpu] @@ -165,11 +165,11 @@ def update_gpu_info(self): self.gpu_bar.setValue(0) self.gpu_label.setText(GPU_USAGE_FORMAT.format(0, 0, 0, 0)) - def mouseDoubleClickEvent(self, event): + def mouseDoubleClickEvent(self, event) -> None: if self.handles: self.show_detailed_stats() - def show_detailed_stats(self): + def show_detailed_stats(self) -> None: dialog = QDialog(self) dialog.setWindowTitle(GPU_DETAILS) dialog.setMinimumSize(800, 600) @@ -194,7 +194,7 @@ def show_detailed_stats(self): gpu_graph = SimpleGraph(GPU_USAGE_OVER_TIME) vram_graph = SimpleGraph(VRAM_USAGE_OVER_TIME) - def update_graph_data(): + def update_graph_data() -> None: gpu_graph.update_data(self.gpu_data) vram_graph.update_data(self.vram_data) @@ -207,7 +207,7 @@ def update_graph_data(): dialog.exec() - def closeEvent(self, event): + def closeEvent(self, event) -> None: if self.handles: pynvml.nvmlShutdown() super().closeEvent(event) diff --git a/src/KVOverrideEntry.py b/src/KVOverrideEntry.py index a5cfb7a..325f1aa 100644 --- a/src/KVOverrideEntry.py +++ b/src/KVOverrideEntry.py @@ -11,7 +11,7 @@ class KVOverrideEntry(QWidget): deleted = Signal(QWidget) - def __init__(self, parent=None): + def __init__(self, parent=None) -> None: super().__init__(parent) layout = QHBoxLayout(self) layout.setContentsMargins(0, 0, 0, 0) @@ -42,12 +42,12 @@ def __init__(self, parent=None): # Initialize validator self.update_validator(self.type_combo.currentText()) - def delete_clicked(self): + def delete_clicked(self) -> None: self.deleted.emit(self) def get_override_string( self, model_name=None, quant_type=None, output_path=None - ): # Add arguments + ) -> str: # Add arguments key = self.key_input.text() type_ = self.type_combo.currentText() value = self.value_input.text() @@ -79,11 +79,11 @@ def get_override_string( return f"{key}={type_}:{value}" - def get_raw_override_string(self): + def get_raw_override_string(self) -> str: # Return the raw override string with placeholders intact return f"{self.key_input.text()}={self.type_combo.currentText()}:{self.value_input.text()}" - def update_validator(self, type_): + def update_validator(self, type_) -> None: if type_ == "int": self.value_input.setValidator(QIntValidator()) elif type_ == "float": diff --git a/src/Localizations.py b/src/Localizations.py index 5df4063..c1259de 100644 --- a/src/Localizations.py +++ b/src/Localizations.py @@ -6325,7 +6325,7 @@ def __init__(self): # fmt: on -def set_language(lang_code): +def set_language(lang_code) -> None: # Globals global WINDOW_TITLE, RAM_USAGE, CPU_USAGE, BACKEND, REFRESH_BACKENDS, MODELS_PATH, OUTPUT_PATH, LOGS_PATH global BROWSE, AVAILABLE_MODELS, QUANTIZATION_TYPE, ALLOW_REQUANTIZE, LEAVE_OUTPUT_TENSOR, PURE, IMATRIX diff --git a/src/Logger.py b/src/Logger.py index 30b8d75..5793585 100644 --- a/src/Logger.py +++ b/src/Logger.py @@ -5,7 +5,7 @@ class Logger: - def __init__(self, name, log_dir): + def __init__(self, name, log_dir) -> None: self.logger = logging.getLogger(name) self.logger.setLevel(logging.DEBUG) @@ -34,17 +34,17 @@ def __init__(self, name, log_dir): self.logger.addHandler(console_handler) self.logger.addHandler(file_handler) - def debug(self, message): + def debug(self, message) -> None: self.logger.debug(message) - def info(self, message): + def info(self, message) -> None: self.logger.info(message) - def warning(self, message): + def warning(self, message) -> None: self.logger.warning(message) - def error(self, message): + def error(self, message) -> None: self.logger.error(message) - def critical(self, message): + def critical(self, message) -> None: self.logger.critical(message) diff --git a/src/ModelInfoDialog.py b/src/ModelInfoDialog.py index 161c1d5..849abe4 100644 --- a/src/ModelInfoDialog.py +++ b/src/ModelInfoDialog.py @@ -2,7 +2,7 @@ class ModelInfoDialog(QDialog): - def __init__(self, model_info, parent=None): + def __init__(self, model_info, parent=None) -> None: super().__init__(parent) self.setWindowTitle("Model Information") self.setGeometry(200, 200, 600, 400) @@ -21,7 +21,7 @@ def __init__(self, model_info, parent=None): self.setLayout(layout) - def format_model_info(self, model_info): + def format_model_info(self, model_info) -> str: html = "

Model Information

" html += f"

Architecture: {model_info.get('architecture', 'N/A')}

" html += f"

Quantization Type: {model_info.get('quantization_type', 'N/A')}

" diff --git a/src/QuantizationThread.py b/src/QuantizationThread.py index b2f650d..fb90665 100644 --- a/src/QuantizationThread.py +++ b/src/QuantizationThread.py @@ -15,7 +15,7 @@ class QuantizationThread(QThread): error_signal = Signal(str) model_info_signal = Signal(dict) - def __init__(self, command, cwd, log_file): + def __init__(self, command, cwd, log_file) -> None: super().__init__() self.command = command self.cwd = cwd @@ -23,7 +23,7 @@ def __init__(self, command, cwd, log_file): self.process = None self.model_info = {} - def run(self): + def run(self) -> None: try: # Start the subprocess self.process = subprocess.Popen( @@ -56,7 +56,7 @@ def run(self): except Exception as e: self.error_signal.emit(str(e)) - def parse_model_info(self, line): + def parse_model_info(self, line) -> None: # Parse output for model information if "llama_model_loader: loaded meta data with" in line: parts = line.split() @@ -77,7 +77,7 @@ def parse_model_info(self, line): f"{quant_type}: {tensors} tensors" ) - def terminate(self): + def terminate(self) -> None: # Terminate the subprocess if it's still running if self.process: os.kill(self.process.pid, signal.SIGTERM) diff --git a/src/TaskListItem.py b/src/TaskListItem.py index dfe5aaa..0fe6990 100644 --- a/src/TaskListItem.py +++ b/src/TaskListItem.py @@ -3,7 +3,7 @@ class TaskListItem(QWidget): - def __init__(self, task_name, log_file, show_progress_bar=True, parent=None): + def __init__(self, task_name, log_file, show_progress_bar=True, parent=None) -> None: super().__init__(parent) self.task_name = task_name self.log_file = log_file @@ -28,7 +28,7 @@ def __init__(self, task_name, log_file, show_progress_bar=True, parent=None): self.progress_timer.timeout.connect(self.update_progress) self.progress_value = 0 - def update_status(self, status): + def update_status(self, status) -> None: self.status = status self.status_label.setText(status) if status == "In Progress": @@ -43,14 +43,14 @@ def update_status(self, status): self.progress_timer.stop() self.progress_bar.setValue(0) - def set_error(self): + def set_error(self) -> None: self.status = "Error" self.status_label.setText("Error") self.status_label.setStyleSheet("color: red;") self.progress_bar.setRange(0, 100) self.progress_timer.stop() - def update_progress(self, value=None): + def update_progress(self, value=None) -> None: if value is not None: # Update progress bar with specific value self.progress_value = value diff --git a/src/error_handling.py b/src/error_handling.py index 7b36bbc..0206cde 100644 --- a/src/error_handling.py +++ b/src/error_handling.py @@ -1,13 +1,13 @@ from PySide6.QtWidgets import QMessageBox -from Localizations import * +from Localizations import ERROR_MESSAGE, ERROR, TASK_ERROR -def show_error(logger, message): +def show_error(logger, message) -> None: logger.error(ERROR_MESSAGE.format(message)) QMessageBox.critical(None, ERROR, message) -def handle_error(logger, error_message, task_item): +def handle_error(logger, error_message, task_item) -> None: logger.error(TASK_ERROR.format(error_message)) show_error(logger, error_message) task_item.update_status(ERROR) diff --git a/src/imports_and_globals.py b/src/imports_and_globals.py index e9f3681..eec3e61 100644 --- a/src/imports_and_globals.py +++ b/src/imports_and_globals.py @@ -1,5 +1,7 @@ import os import sys +from typing import LiteralString, TextIO, Union + import psutil import subprocess import time @@ -41,7 +43,7 @@ from Localizations import * -def show_about(self): +def show_about(self) -> None: about_text = ( "AutoGGUF\n\n" f"Version: {AUTOGGUF_VERSION}\n\n" @@ -50,12 +52,12 @@ def show_about(self): QMessageBox.about(self, "About AutoGGUF", about_text) -def ensure_directory(path): +def ensure_directory(path) -> None: if not os.path.exists(path): os.makedirs(path) -def open_file_safe(file_path, mode="r"): +def open_file_safe(file_path, mode="r") -> TextIO: encodings = ["utf-8", "latin-1", "ascii", "utf-16"] for encoding in encodings: try: @@ -67,7 +69,7 @@ def open_file_safe(file_path, mode="r"): ) -def resource_path(relative_path): +def resource_path(relative_path) -> Union[LiteralString, str, bytes]: if hasattr(sys, "_MEIPASS"): # PyInstaller path base_path = sys._MEIPASS diff --git a/src/lora_conversion.py b/src/lora_conversion.py index 4337fbb..aa5cc90 100644 --- a/src/lora_conversion.py +++ b/src/lora_conversion.py @@ -16,7 +16,7 @@ from Localizations import * -def export_lora(self): +def export_lora(self) -> None: self.logger.info(STARTING_LORA_EXPORT) try: model_path = self.export_lora_model.text() @@ -98,7 +98,7 @@ def export_lora(self): show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e))) -def delete_lora_adapter_item(self, adapter_widget): +def delete_lora_adapter_item(self, adapter_widget) -> None: self.logger.info(DELETING_LORA_ADAPTER) # Find the QListWidgetItem containing the adapter_widget for i in range(self.export_lora_adapters.count()): @@ -108,14 +108,14 @@ def delete_lora_adapter_item(self, adapter_widget): break -def browse_export_lora_model(self): +def browse_export_lora_model(self) -> None: self.logger.info(BROWSING_FOR_EXPORT_LORA_MODEL_FILE) model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES) if model_file: self.export_lora_model.setText(os.path.abspath(model_file)) -def browse_export_lora_output(self): +def browse_export_lora_output(self) -> None: self.logger.info(BROWSING_FOR_EXPORT_LORA_OUTPUT_FILE) output_file, _ = QFileDialog.getSaveFileName( self, SELECT_OUTPUT_FILE, "", GGUF_FILES @@ -124,7 +124,7 @@ def browse_export_lora_output(self): self.export_lora_output.setText(os.path.abspath(output_file)) -def add_lora_adapter(self): +def add_lora_adapter(self) -> None: self.logger.info(ADDING_LORA_ADAPTER) adapter_path, _ = QFileDialog.getOpenFileName( self, SELECT_LORA_ADAPTER_FILE, "", LORA_FILES @@ -154,7 +154,7 @@ def add_lora_adapter(self): self.export_lora_adapters.setItemWidget(list_item, adapter_widget) -def convert_lora(self): +def convert_lora(self) -> None: self.logger.info(STARTING_LORA_CONVERSION) try: lora_input_path = self.lora_input.text() diff --git a/src/main.py b/src/main.py index ae4d8d5..5c93d15 100644 --- a/src/main.py +++ b/src/main.py @@ -5,30 +5,30 @@ from PySide6.QtCore import QTimer from PySide6.QtWidgets import QApplication from AutoGGUF import AutoGGUF -from flask import Flask, jsonify +from flask import Flask, Response, jsonify server = Flask(__name__) -def main(): +def main() -> None: @server.route("/v1/models", methods=["GET"]) - def models(): + def models() -> Response: if window: return jsonify({"models": window.get_models_data()}) return jsonify({"models": []}) @server.route("/v1/tasks", methods=["GET"]) - def tasks(): + def tasks() -> Response: if window: return jsonify({"tasks": window.get_tasks_data()}) return jsonify({"tasks": []}) @server.route("/v1/health", methods=["GET"]) - def ping(): + def ping() -> Response: return jsonify({"status": "alive"}) @server.route("/v1/backends", methods=["GET"]) - def get_backends(): + def get_backends() -> Response: backends = [] for i in range(window.backend_combo.count()): backends.append( @@ -40,7 +40,7 @@ def get_backends(): return jsonify({"backends": backends}) @server.route("/v1/plugins", methods=["GET"]) - def get_plugins(): + def get_plugins() -> Response: if window: return jsonify( { @@ -57,7 +57,7 @@ def get_plugins(): ) return jsonify({"plugins": []}) - def run_flask(): + def run_flask() -> None: if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled": server.run( host="0.0.0.0", diff --git a/src/presets.py b/src/presets.py index 3acdffb..c203349 100644 --- a/src/presets.py +++ b/src/presets.py @@ -4,7 +4,7 @@ from Localizations import * -def save_preset(self): +def save_preset(self) -> None: self.logger.info(SAVING_PRESET) preset = { "quant_types": [item.text() for item in self.quant_type.selectedItems()], @@ -33,7 +33,7 @@ def save_preset(self): self.logger.info(PRESET_SAVED_TO.format(file_name)) -def load_preset(self): +def load_preset(self) -> None: self.logger.info(LOADING_PRESET) file_name, _ = QFileDialog.getOpenFileName(self, LOAD_PRESET, "", JSON_FILES) if file_name: diff --git a/src/ui_update.py b/src/ui_update.py index cf9e996..bf165c7 100644 --- a/src/ui_update.py +++ b/src/ui_update.py @@ -5,12 +5,12 @@ from error_handling import show_error -def update_model_info(logger, self, model_info): +def update_model_info(logger, self, model_info) -> None: logger.debug(UPDATING_MODEL_INFO.format(model_info)) pass -def update_system_info(self): +def update_system_info(self) -> None: ram = psutil.virtual_memory() cpu = psutil.cpu_percent() @@ -28,7 +28,7 @@ def update_system_info(self): self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu)) -def animate_bar(self, bar, target_value): +def animate_bar(self, bar, target_value) -> None: current_value = bar.value() difference = target_value - current_value @@ -42,7 +42,7 @@ def animate_bar(self, bar, target_value): timer.start(10) # Adjust the interval for animation speed -def _animate_step(bar, target_value, step, timer): +def _animate_step(bar, target_value, step, timer) -> None: current_value = bar.value() new_value = current_value + step @@ -55,11 +55,11 @@ def _animate_step(bar, target_value, step, timer): bar.setValue(new_value) -def update_download_progress(self, progress): +def update_download_progress(self, progress) -> None: self.download_progress.setValue(progress) -def update_cuda_backends(self): +def update_cuda_backends(self) -> None: self.logger.debug(UPDATING_CUDA_BACKENDS) self.backend_combo_cuda.clear() llama_bin = os.path.abspath("llama_bin") @@ -77,23 +77,23 @@ def update_cuda_backends(self): self.backend_combo_cuda.setEnabled(True) -def update_threads_spinbox(self, value): +def update_threads_spinbox(self, value) -> None: self.threads_spinbox.setValue(value) -def update_threads_slider(self, value): +def update_threads_slider(self, value) -> None: self.threads_slider.setValue(value) -def update_gpu_offload_spinbox(self, value): +def update_gpu_offload_spinbox(self, value) -> None: self.gpu_offload_spinbox.setValue(value) -def update_gpu_offload_slider(self, value): +def update_gpu_offload_slider(self, value) -> None: self.gpu_offload_slider.setValue(value) -def update_cuda_option(self): +def update_cuda_option(self) -> None: self.logger.debug(UPDATING_CUDA_OPTIONS) asset = self.asset_combo.currentData() @@ -113,7 +113,7 @@ def update_cuda_option(self): self.update_cuda_backends() -def update_assets(self): +def update_assets(self) -> None: self.logger.debug(UPDATING_ASSET_LIST) self.asset_combo.clear() release = self.release_combo.currentData() @@ -128,6 +128,6 @@ def update_assets(self): self.update_cuda_option() -def update_base_model_visibility(self, index): +def update_base_model_visibility(self, index) -> None: is_gguf = self.lora_output_type_combo.itemText(index) == "GGUF" self.base_model_wrapper.setVisible(is_gguf) diff --git a/src/utils.py b/src/utils.py index ae36822..b4cf27a 100644 --- a/src/utils.py +++ b/src/utils.py @@ -1,3 +1,5 @@ +from typing import Any, Dict, List, Union + from PySide6.QtCore import Qt from PySide6.QtWidgets import QFileDialog @@ -9,7 +11,7 @@ from imports_and_globals import ensure_directory -def get_models_data(self): +def get_models_data(self) -> list[dict[str, Union[str, Any]]]: models = [] root = self.model_tree.invisibleRootItem() child_count = root.childCount() @@ -22,7 +24,7 @@ def get_models_data(self): return models -def get_tasks_data(self): +def get_tasks_data(self) -> list[dict[str, Union[int, Any]]]: tasks = [] for i in range(self.task_list.count()): item = self.task_list.item(i) @@ -43,7 +45,7 @@ def get_tasks_data(self): return tasks -def browse_models(self): +def browse_models(self) -> None: self.logger.info(BROWSING_FOR_MODELS_DIRECTORY) models_path = QFileDialog.getExistingDirectory(self, SELECT_MODELS_DIRECTORY) if models_path: @@ -52,7 +54,7 @@ def browse_models(self): self.load_models() -def browse_output(self): +def browse_output(self) -> None: self.logger.info(BROWSING_FOR_OUTPUT_DIRECTORY) output_path = QFileDialog.getExistingDirectory(self, SELECT_OUTPUT_DIRECTORY) if output_path: @@ -60,7 +62,7 @@ def browse_output(self): ensure_directory(output_path) -def browse_logs(self): +def browse_logs(self) -> None: self.logger.info(BROWSING_FOR_LOGS_DIRECTORY) logs_path = QFileDialog.getExistingDirectory(self, SELECT_LOGS_DIRECTORY) if logs_path: @@ -68,7 +70,7 @@ def browse_logs(self): ensure_directory(logs_path) -def browse_imatrix(self): +def browse_imatrix(self) -> None: self.logger.info(BROWSING_FOR_IMATRIX_FILE) imatrix_file, _ = QFileDialog.getOpenFileName( self, SELECT_IMATRIX_FILE, "", DAT_FILES @@ -77,7 +79,7 @@ def browse_imatrix(self): self.imatrix.setText(os.path.abspath(imatrix_file)) -def browse_lora_input(self): +def browse_lora_input(self) -> None: self.logger.info(BROWSING_FOR_LORA_INPUT_DIRECTORY) lora_input_path = QFileDialog.getExistingDirectory( self, SELECT_LORA_INPUT_DIRECTORY @@ -87,7 +89,7 @@ def browse_lora_input(self): ensure_directory(lora_input_path) -def browse_lora_output(self): +def browse_lora_output(self) -> None: self.logger.info(BROWSING_FOR_LORA_OUTPUT_FILE) lora_output_file, _ = QFileDialog.getSaveFileName( self, SELECT_LORA_OUTPUT_FILE, "", GGUF_AND_BIN_FILES @@ -96,7 +98,7 @@ def browse_lora_output(self): self.lora_output.setText(os.path.abspath(lora_output_file)) -def download_llama_cpp(self): +def download_llama_cpp(self) -> None: self.logger.info(STARTING_LLAMACPP_DOWNLOAD) asset = self.asset_combo.currentData() if not asset: @@ -118,7 +120,7 @@ def download_llama_cpp(self): self.download_progress.setValue(0) -def refresh_releases(self): +def refresh_releases(self) -> None: self.logger.info(REFRESHING_LLAMACPP_RELEASES) try: response = requests.get(