refactor: add type hints

This commit is contained in:
BuildTools 2024-08-22 21:56:37 -07:00
parent d4be39a22c
commit 6e424462ab
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
17 changed files with 135 additions and 132 deletions

View File

@ -5,6 +5,7 @@
from functools import partial from functools import partial
from datetime import datetime from datetime import datetime
from typing import Tuple, Dict
from dotenv import load_dotenv from dotenv import load_dotenv
from PySide6.QtCore import * from PySide6.QtCore import *
from PySide6.QtGui import * from PySide6.QtGui import *
@ -33,7 +34,8 @@
class AutoGGUF(QMainWindow): class AutoGGUF(QMainWindow):
def __init__(self, args):
def __init__(self, args: List[str]) -> None:
super().__init__() super().__init__()
self.logger = Logger("AutoGGUF", "logs") self.logger = Logger("AutoGGUF", "logs")
@ -785,7 +787,7 @@ def __init__(self, args):
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE) self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
def load_plugins(self): def load_plugins(self) -> Dict[str, Dict[str, Any]]:
plugins = {} plugins = {}
plugin_dir = "plugins" plugin_dir = "plugins"
@ -844,7 +846,7 @@ def load_plugins(self):
return plugins return plugins
def apply_plugins(self): def apply_plugins(self) -> None:
if not self.plugins: if not self.plugins:
self.logger.info(NO_PLUGINS_LOADED) self.logger.info(NO_PLUGINS_LOADED)
return return
@ -859,7 +861,7 @@ def apply_plugins(self):
if hasattr(plugin_instance, "init") and callable(plugin_instance.init): if hasattr(plugin_instance, "init") and callable(plugin_instance.init):
plugin_instance.init(self) plugin_instance.init(self)
def check_for_updates(self): def check_for_updates(self) -> None:
try: try:
response = requests.get( response = requests.get(
"https://api.github.com/repos/leafspark/AutoGGUF/releases/latest" "https://api.github.com/repos/leafspark/AutoGGUF/releases/latest"
@ -874,7 +876,7 @@ def check_for_updates(self):
except requests.exceptions.RequestException as e: except requests.exceptions.RequestException as e:
self.logger.warning(f"{ERROR_CHECKING_FOR_UPDATES} {e}") self.logger.warning(f"{ERROR_CHECKING_FOR_UPDATES} {e}")
def prompt_for_update(self, release): def prompt_for_update(self, release) -> None:
update_message = QMessageBox() update_message = QMessageBox()
update_message.setIcon(QMessageBox.Information) update_message.setIcon(QMessageBox.Information)
update_message.setWindowTitle(UPDATE_AVAILABLE) update_message.setWindowTitle(UPDATE_AVAILABLE)
@ -887,7 +889,7 @@ def prompt_for_update(self, release):
if update_message.exec() == QMessageBox.StandardButton.Yes: if update_message.exec() == QMessageBox.StandardButton.Yes:
QDesktopServices.openUrl(QUrl(release["html_url"])) QDesktopServices.openUrl(QUrl(release["html_url"]))
def keyPressEvent(self, event): def keyPressEvent(self, event) -> None:
if event.modifiers() == Qt.ControlModifier: if event.modifiers() == Qt.ControlModifier:
if ( if (
event.key() == Qt.Key_Equal event.key() == Qt.Key_Equal
@ -899,7 +901,7 @@ def keyPressEvent(self, event):
self.reset_size() self.reset_size()
super().keyPressEvent(event) super().keyPressEvent(event)
def resize_window(self, larger): def resize_window(self, larger) -> None:
factor = 1.1 if larger else 1 / 1.1 factor = 1.1 if larger else 1 / 1.1
current_width = self.width() current_width = self.width()
current_height = self.height() current_height = self.height()
@ -907,10 +909,10 @@ def resize_window(self, larger):
new_height = int(current_height * factor) new_height = int(current_height * factor)
self.resize(new_width, new_height) self.resize(new_width, new_height)
def reset_size(self): def reset_size(self) -> None:
self.resize(self.default_width, self.default_height) self.resize(self.default_width, self.default_height)
def parse_resolution(self): def parse_resolution(self) -> Tuple[int, int]:
res = os.environ.get("AUTOGGUF_RESOLUTION", "1650x1100") res = os.environ.get("AUTOGGUF_RESOLUTION", "1650x1100")
try: try:
width, height = map(int, res.split("x")) width, height = map(int, res.split("x"))
@ -920,14 +922,14 @@ def parse_resolution(self):
except (ValueError, AttributeError): except (ValueError, AttributeError):
return 1650, 1100 return 1650, 1100
def resizeEvent(self, event): def resizeEvent(self, event) -> None:
super().resizeEvent(event) super().resizeEvent(event)
path = QPainterPath() path = QPainterPath()
path.addRoundedRect(self.rect(), 10, 10) path.addRoundedRect(self.rect(), 10, 10)
mask = QRegion(path.toFillPolygon().toPolygon()) mask = QRegion(path.toFillPolygon().toPolygon())
self.setMask(mask) self.setMask(mask)
def refresh_backends(self): def refresh_backends(self) -> None:
self.logger.info(REFRESHING_BACKENDS) self.logger.info(REFRESHING_BACKENDS)
llama_bin = os.path.abspath("llama_bin") llama_bin = os.path.abspath("llama_bin")
os.makedirs(llama_bin, exist_ok=True) os.makedirs(llama_bin, exist_ok=True)
@ -951,7 +953,7 @@ def refresh_backends(self):
self.backend_combo.setEnabled(False) self.backend_combo.setEnabled(False)
self.logger.info(FOUND_VALID_BACKENDS.format(len(valid_backends))) self.logger.info(FOUND_VALID_BACKENDS.format(len(valid_backends)))
def save_task_preset(self, task_item): def save_task_preset(self, task_item) -> None:
self.logger.info(SAVING_TASK_PRESET.format(task_item.task_name)) self.logger.info(SAVING_TASK_PRESET.format(task_item.task_name))
for thread in self.quant_threads: for thread in self.quant_threads:
if thread.log_file == task_item.log_file: if thread.log_file == task_item.log_file:
@ -971,7 +973,7 @@ def save_task_preset(self, task_item):
) )
break break
def browse_base_model(self): def browse_base_model(self) -> None:
self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message
base_model_folder = QFileDialog.getExistingDirectory( base_model_folder = QFileDialog.getExistingDirectory(
self, SELECT_BASE_MODEL_FOLDER self, SELECT_BASE_MODEL_FOLDER
@ -979,13 +981,13 @@ def browse_base_model(self):
if base_model_folder: if base_model_folder:
self.base_model_path.setText(os.path.abspath(base_model_folder)) self.base_model_path.setText(os.path.abspath(base_model_folder))
def browse_hf_model_input(self): def browse_hf_model_input(self) -> None:
self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY) self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY) model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir: if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir)) self.hf_model_input.setText(os.path.abspath(model_dir))
def browse_hf_outfile(self): def browse_hf_outfile(self) -> None:
self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT) self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT)
outfile, _ = QFileDialog.getSaveFileName( outfile, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES self, SELECT_OUTPUT_FILE, "", GGUF_FILES
@ -993,7 +995,7 @@ def browse_hf_outfile(self):
if outfile: if outfile:
self.hf_outfile.setText(os.path.abspath(outfile)) self.hf_outfile.setText(os.path.abspath(outfile))
def convert_hf_to_gguf(self): def convert_hf_to_gguf(self) -> None:
self.logger.info(STARTING_HF_TO_GGUF_CONVERSION) self.logger.info(STARTING_HF_TO_GGUF_CONVERSION)
try: try:
model_dir = self.hf_model_input.text() model_dir = self.hf_model_input.text()
@ -1063,7 +1065,7 @@ def convert_hf_to_gguf(self):
show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e))) show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED) self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED)
def restart_task(self, task_item): def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name)) self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads: for thread in self.quant_threads:
if thread.log_file == task_item.log_file: if thread.log_file == task_item.log_file:
@ -1083,7 +1085,7 @@ def restart_task(self, task_item):
task_item.update_status(IN_PROGRESS) task_item.update_status(IN_PROGRESS)
break break
def lora_conversion_finished(self, thread, input_path, output_path): def lora_conversion_finished(self, thread, input_path, output_path) -> None:
self.logger.info(LORA_CONVERSION_FINISHED) self.logger.info(LORA_CONVERSION_FINISHED)
if thread in self.quant_threads: if thread in self.quant_threads:
self.quant_threads.remove(thread) self.quant_threads.remove(thread)
@ -1099,7 +1101,7 @@ def lora_conversion_finished(self, thread, input_path, output_path):
except Exception as e: except Exception as e:
self.logger.error(ERROR_MOVING_LORA_FILE.format(str(e))) self.logger.error(ERROR_MOVING_LORA_FILE.format(str(e)))
def download_finished(self, extract_dir): def download_finished(self, extract_dir) -> None:
self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir)) self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir))
self.download_button.setEnabled(True) self.download_button.setEnabled(True)
self.download_progress.setValue(100) self.download_progress.setValue(100)
@ -1136,7 +1138,7 @@ def download_finished(self, extract_dir):
if index >= 0: if index >= 0:
self.backend_combo.setCurrentIndex(index) self.backend_combo.setCurrentIndex(index)
def extract_cuda_files(self, extract_dir, destination): def extract_cuda_files(self, extract_dir, destination) -> None:
self.logger.info(EXTRACTING_CUDA_FILES.format(extract_dir, destination)) self.logger.info(EXTRACTING_CUDA_FILES.format(extract_dir, destination))
for root, dirs, files in os.walk(extract_dir): for root, dirs, files in os.walk(extract_dir):
for file in files: for file in files:
@ -1145,7 +1147,7 @@ def extract_cuda_files(self, extract_dir, destination):
dest_path = os.path.join(destination, file) dest_path = os.path.join(destination, file)
shutil.copy2(source_path, dest_path) shutil.copy2(source_path, dest_path)
def download_error(self, error_message): def download_error(self, error_message) -> None:
self.logger.error(DOWNLOAD_ERROR.format(error_message)) self.logger.error(DOWNLOAD_ERROR.format(error_message))
self.download_button.setEnabled(True) self.download_button.setEnabled(True)
self.download_progress.setValue(0) self.download_progress.setValue(0)
@ -1158,7 +1160,7 @@ def download_error(self, error_message):
if os.path.exists(partial_file): if os.path.exists(partial_file):
os.remove(partial_file) os.remove(partial_file)
def show_task_context_menu(self, position): def show_task_context_menu(self, position) -> None:
self.logger.debug(SHOWING_TASK_CONTEXT_MENU) self.logger.debug(SHOWING_TASK_CONTEXT_MENU)
item = self.task_list.itemAt(position) item = self.task_list.itemAt(position)
if item is not None: if item is not None:
@ -1185,7 +1187,7 @@ def show_task_context_menu(self, position):
context_menu.exec(self.task_list.viewport().mapToGlobal(position)) context_menu.exec(self.task_list.viewport().mapToGlobal(position))
def show_task_properties(self, item): def show_task_properties(self, item) -> None:
self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text())) self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item) task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads: for thread in self.quant_threads:
@ -1194,12 +1196,12 @@ def show_task_properties(self, item):
model_info_dialog.exec() model_info_dialog.exec()
break break
def toggle_gpu_offload_auto(self, state): def toggle_gpu_offload_auto(self, state) -> None:
is_auto = state == Qt.CheckState.Checked is_auto = state == Qt.CheckState.Checked
self.gpu_offload_slider.setEnabled(not is_auto) self.gpu_offload_slider.setEnabled(not is_auto)
self.gpu_offload_spinbox.setEnabled(not is_auto) self.gpu_offload_spinbox.setEnabled(not is_auto)
def cancel_task_by_item(self, item): def cancel_task_by_item(self, item) -> None:
task_item = self.task_list.itemWidget(item) task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads: for thread in self.quant_threads:
if thread.log_file == task_item.log_file: if thread.log_file == task_item.log_file:
@ -1208,11 +1210,11 @@ def cancel_task_by_item(self, item):
self.quant_threads.remove(thread) self.quant_threads.remove(thread)
break break
def cancel_task(self, item): def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text())) self.logger.info(CANCELLING_TASK.format(item.text()))
self.cancel_task_by_item(item) self.cancel_task_by_item(item)
def delete_task(self, item): def delete_task(self, item) -> None:
self.logger.info(DELETING_TASK.format(item.text())) self.logger.info(DELETING_TASK.format(item.text()))
# Cancel the task first # Cancel the task first
@ -1233,12 +1235,12 @@ def delete_task(self, item):
if task_item: if task_item:
task_item.deleteLater() task_item.deleteLater()
def create_label(self, text, tooltip): def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text) label = QLabel(text)
label.setToolTip(tooltip) label.setToolTip(tooltip)
return label return label
def verify_gguf(self, file_path): def verify_gguf(self, file_path) -> bool:
try: try:
with open(file_path, "rb") as f: with open(file_path, "rb") as f:
magic = f.read(4) magic = f.read(4)
@ -1246,7 +1248,7 @@ def verify_gguf(self, file_path):
except Exception: except Exception:
return False return False
def load_models(self): def load_models(self) -> None:
self.logger.info(LOADING_MODELS) self.logger.info(LOADING_MODELS)
models_dir = self.models_input.text() models_dir = self.models_input.text()
ensure_directory(models_dir) ensure_directory(models_dir)
@ -1322,7 +1324,7 @@ def load_models(self):
CONCATENATED_FILES_FOUND.format(len(concatenated_models)) CONCATENATED_FILES_FOUND.format(len(concatenated_models))
) )
def add_model_to_tree(self, model): def add_model_to_tree(self, model) -> QTreeWidgetItem:
item = QTreeWidgetItem(self.model_tree) item = QTreeWidgetItem(self.model_tree)
item.setText(0, model) item.setText(0, model)
if hasattr(self, "imported_models") and model in [ if hasattr(self, "imported_models") and model in [
@ -1337,7 +1339,7 @@ def add_model_to_tree(self, model):
item.setData(0, Qt.ItemDataRole.UserRole, model) item.setData(0, Qt.ItemDataRole.UserRole, model)
return item return item
def validate_quantization_inputs(self): def validate_quantization_inputs(self) -> None:
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS) self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
errors = [] errors = []
if not self.backend_combo.currentData(): if not self.backend_combo.currentData():
@ -1354,7 +1356,7 @@ def validate_quantization_inputs(self):
if errors: if errors:
raise ValueError("\n".join(errors)) raise ValueError("\n".join(errors))
def add_kv_override(self, override_string=None): def add_kv_override(self, override_string=None) -> None:
entry = KVOverrideEntry() entry = KVOverrideEntry()
entry.deleted.connect(self.remove_kv_override) entry.deleted.connect(self.remove_kv_override)
if override_string: if override_string:
@ -1366,12 +1368,12 @@ def add_kv_override(self, override_string=None):
self.kv_override_layout.addWidget(entry) self.kv_override_layout.addWidget(entry)
self.kv_override_entries.append(entry) self.kv_override_entries.append(entry)
def remove_kv_override(self, entry): def remove_kv_override(self, entry) -> None:
self.kv_override_layout.removeWidget(entry) self.kv_override_layout.removeWidget(entry)
self.kv_override_entries.remove(entry) self.kv_override_entries.remove(entry)
entry.deleteLater() entry.deleteLater()
def quantize_model(self): def quantize_model(self) -> None:
self.logger.info(STARTING_MODEL_QUANTIZATION) self.logger.info(STARTING_MODEL_QUANTIZATION)
try: try:
self.validate_quantization_inputs() self.validate_quantization_inputs()
@ -1539,7 +1541,7 @@ def quantize_model(self):
except Exception as e: except Exception as e:
show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e))) show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e)))
def parse_progress(self, line, task_item): def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item. # Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*\].*", line) match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*\].*", line)
if match: if match:
@ -1548,13 +1550,13 @@ def parse_progress(self, line, task_item):
progress = int((current / total) * 100) progress = int((current / total) * 100)
task_item.update_progress(progress) task_item.update_progress(progress)
def task_finished(self, thread, task_item): def task_finished(self, thread, task_item) -> None:
self.logger.info(TASK_FINISHED.format(thread.log_file)) self.logger.info(TASK_FINISHED.format(thread.log_file))
if thread in self.quant_threads: if thread in self.quant_threads:
self.quant_threads.remove(thread) self.quant_threads.remove(thread)
task_item.update_status(COMPLETED) task_item.update_status(COMPLETED)
def show_task_details(self, item): def show_task_details(self, item) -> None:
self.logger.debug(SHOWING_TASK_DETAILS_FOR.format(item.text())) self.logger.debug(SHOWING_TASK_DETAILS_FOR.format(item.text()))
task_item = self.task_list.itemWidget(item) task_item = self.task_list.itemWidget(item)
if task_item: if task_item:
@ -1582,7 +1584,7 @@ def show_task_details(self, item):
log_dialog.exec() log_dialog.exec()
def import_model(self): def import_model(self) -> None:
self.logger.info(IMPORTING_MODEL) self.logger.info(IMPORTING_MODEL)
file_path, _ = QFileDialog.getOpenFileName( file_path, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES
@ -1609,13 +1611,13 @@ def import_model(self):
self.load_models() self.load_models()
self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name)) self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))
def browse_imatrix_datafile(self): def browse_imatrix_datafile(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE) self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES) datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
if datafile: if datafile:
self.imatrix_datafile.setText(os.path.abspath(datafile)) self.imatrix_datafile.setText(os.path.abspath(datafile))
def browse_imatrix_model(self): def browse_imatrix_model(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE) self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName( model_file, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_FILE, "", GGUF_FILES self, SELECT_MODEL_FILE, "", GGUF_FILES
@ -1623,7 +1625,7 @@ def browse_imatrix_model(self):
if model_file: if model_file:
self.imatrix_model.setText(os.path.abspath(model_file)) self.imatrix_model.setText(os.path.abspath(model_file))
def browse_imatrix_output(self): def browse_imatrix_output(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE) self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName( output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", DAT_FILES self, SELECT_OUTPUT_FILE, "", DAT_FILES
@ -1631,7 +1633,7 @@ def browse_imatrix_output(self):
if output_file: if output_file:
self.imatrix_output.setText(os.path.abspath(output_file)) self.imatrix_output.setText(os.path.abspath(output_file))
def generate_imatrix(self): def generate_imatrix(self) -> None:
self.logger.info(STARTING_IMATRIX_GENERATION) self.logger.info(STARTING_IMATRIX_GENERATION)
try: try:
backend_path = self.backend_combo.currentData() backend_path = self.backend_combo.currentData()
@ -1692,7 +1694,7 @@ def generate_imatrix(self):
show_error(self.logger, ERROR_STARTING_IMATRIX_GENERATION.format(str(e))) show_error(self.logger, ERROR_STARTING_IMATRIX_GENERATION.format(str(e)))
self.logger.info(IMATRIX_GENERATION_TASK_STARTED) self.logger.info(IMATRIX_GENERATION_TASK_STARTED)
def closeEvent(self, event: QCloseEvent): def closeEvent(self, event: QCloseEvent) -> None:
self.logger.info(APPLICATION_CLOSING) self.logger.info(APPLICATION_CLOSING)
if self.quant_threads: if self.quant_threads:
reply = QMessageBox.question( reply = QMessageBox.question(

View File

@ -1,12 +1,9 @@
from PySide6.QtCore import QPoint from PySide6.QtCore import QPoint
from PySide6.QtGui import QPixmap
from PySide6.QtWidgets import QHBoxLayout, QLabel, QMenuBar, QPushButton, QWidget from PySide6.QtWidgets import QHBoxLayout, QLabel, QMenuBar, QPushButton, QWidget
from imports_and_globals import resource_path
class CustomTitleBar(QWidget): class CustomTitleBar(QWidget):
def __init__(self, parent=None): def __init__(self, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
self.parent = parent self.parent = parent
layout = QHBoxLayout(self) layout = QHBoxLayout(self)
@ -55,11 +52,11 @@ def __init__(self, parent=None):
self.start = QPoint(0, 0) self.start = QPoint(0, 0)
self.pressing = False self.pressing = False
def mousePressEvent(self, event): def mousePressEvent(self, event) -> None:
self.start = self.mapToGlobal(event.pos()) self.start = self.mapToGlobal(event.pos())
self.pressing = True self.pressing = True
def mouseMoveEvent(self, event): def mouseMoveEvent(self, event) -> None:
if self.pressing: if self.pressing:
end = self.mapToGlobal(event.pos()) end = self.mapToGlobal(event.pos())
movement = end - self.start movement = end - self.start
@ -71,5 +68,5 @@ def mouseMoveEvent(self, event):
) )
self.start = end self.start = end
def mouseReleaseEvent(self, event): def mouseReleaseEvent(self, event) -> None:
self.pressing = False self.pressing = False

View File

@ -10,12 +10,12 @@ class DownloadThread(QThread):
finished_signal = Signal(str) finished_signal = Signal(str)
error_signal = Signal(str) error_signal = Signal(str)
def __init__(self, url, save_path): def __init__(self, url, save_path) -> None:
super().__init__() super().__init__()
self.url = url self.url = url
self.save_path = save_path self.save_path = save_path
def run(self): def run(self) -> None:
try: try:
response = requests.get(self.url, stream=True) response = requests.get(self.url, stream=True)
response.raise_for_status() response.raise_for_status()

View File

@ -28,7 +28,7 @@
class SimpleGraph(QGraphicsView): class SimpleGraph(QGraphicsView):
def __init__(self, title, parent=None): def __init__(self, title, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
self.setScene(QGraphicsScene(self)) self.setScene(QGraphicsScene(self))
self.setRenderHint(QPainter.RenderHint.Antialiasing) self.setRenderHint(QPainter.RenderHint.Antialiasing)
@ -37,7 +37,7 @@ def __init__(self, title, parent=None):
self.title = title self.title = title
self.data = [] self.data = []
def update_data(self, data): def update_data(self, data) -> None:
self.data = data self.data = data
self.scene().clear() self.scene().clear()
if not self.data: if not self.data:
@ -65,13 +65,13 @@ def update_data(self, data):
line.setPen(path) line.setPen(path)
self.scene().addItem(line) self.scene().addItem(line)
def resizeEvent(self, event): def resizeEvent(self, event) -> None:
super().resizeEvent(event) super().resizeEvent(event)
self.update_data(self.data) self.update_data(self.data)
class GPUMonitor(QWidget): class GPUMonitor(QWidget):
def __init__(self, parent=None): def __init__(self, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
self.setMinimumHeight(30) self.setMinimumHeight(30)
self.setMaximumHeight(30) self.setMaximumHeight(30)
@ -125,17 +125,17 @@ def __init__(self, parent=None):
if not self.handles: if not self.handles:
self.gpu_label.setText(NO_GPU_DETECTED) self.gpu_label.setText(NO_GPU_DETECTED)
def check_for_amd_gpu(self): def check_for_amd_gpu(self) -> None:
# This is a placeholder. Implementing AMD GPU detection would require # This is a placeholder. Implementing AMD GPU detection would require
# platform-specific methods or additional libraries. # platform-specific methods or additional libraries.
self.gpu_label.setText(AMD_GPU_NOT_SUPPORTED) self.gpu_label.setText(AMD_GPU_NOT_SUPPORTED)
def change_gpu(self, index): def change_gpu(self, index) -> None:
self.current_gpu = index self.current_gpu = index
self.gpu_data.clear() self.gpu_data.clear()
self.vram_data.clear() self.vram_data.clear()
def update_gpu_info(self): def update_gpu_info(self) -> None:
if self.handles: if self.handles:
try: try:
handle = self.handles[self.current_gpu] handle = self.handles[self.current_gpu]
@ -165,11 +165,11 @@ def update_gpu_info(self):
self.gpu_bar.setValue(0) self.gpu_bar.setValue(0)
self.gpu_label.setText(GPU_USAGE_FORMAT.format(0, 0, 0, 0)) self.gpu_label.setText(GPU_USAGE_FORMAT.format(0, 0, 0, 0))
def mouseDoubleClickEvent(self, event): def mouseDoubleClickEvent(self, event) -> None:
if self.handles: if self.handles:
self.show_detailed_stats() self.show_detailed_stats()
def show_detailed_stats(self): def show_detailed_stats(self) -> None:
dialog = QDialog(self) dialog = QDialog(self)
dialog.setWindowTitle(GPU_DETAILS) dialog.setWindowTitle(GPU_DETAILS)
dialog.setMinimumSize(800, 600) dialog.setMinimumSize(800, 600)
@ -194,7 +194,7 @@ def show_detailed_stats(self):
gpu_graph = SimpleGraph(GPU_USAGE_OVER_TIME) gpu_graph = SimpleGraph(GPU_USAGE_OVER_TIME)
vram_graph = SimpleGraph(VRAM_USAGE_OVER_TIME) vram_graph = SimpleGraph(VRAM_USAGE_OVER_TIME)
def update_graph_data(): def update_graph_data() -> None:
gpu_graph.update_data(self.gpu_data) gpu_graph.update_data(self.gpu_data)
vram_graph.update_data(self.vram_data) vram_graph.update_data(self.vram_data)
@ -207,7 +207,7 @@ def update_graph_data():
dialog.exec() dialog.exec()
def closeEvent(self, event): def closeEvent(self, event) -> None:
if self.handles: if self.handles:
pynvml.nvmlShutdown() pynvml.nvmlShutdown()
super().closeEvent(event) super().closeEvent(event)

View File

@ -11,7 +11,7 @@
class KVOverrideEntry(QWidget): class KVOverrideEntry(QWidget):
deleted = Signal(QWidget) deleted = Signal(QWidget)
def __init__(self, parent=None): def __init__(self, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
layout = QHBoxLayout(self) layout = QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0) layout.setContentsMargins(0, 0, 0, 0)
@ -42,12 +42,12 @@ def __init__(self, parent=None):
# Initialize validator # Initialize validator
self.update_validator(self.type_combo.currentText()) self.update_validator(self.type_combo.currentText())
def delete_clicked(self): def delete_clicked(self) -> None:
self.deleted.emit(self) self.deleted.emit(self)
def get_override_string( def get_override_string(
self, model_name=None, quant_type=None, output_path=None self, model_name=None, quant_type=None, output_path=None
): # Add arguments ) -> str: # Add arguments
key = self.key_input.text() key = self.key_input.text()
type_ = self.type_combo.currentText() type_ = self.type_combo.currentText()
value = self.value_input.text() value = self.value_input.text()
@ -79,11 +79,11 @@ def get_override_string(
return f"{key}={type_}:{value}" return f"{key}={type_}:{value}"
def get_raw_override_string(self): def get_raw_override_string(self) -> str:
# Return the raw override string with placeholders intact # Return the raw override string with placeholders intact
return f"{self.key_input.text()}={self.type_combo.currentText()}:{self.value_input.text()}" return f"{self.key_input.text()}={self.type_combo.currentText()}:{self.value_input.text()}"
def update_validator(self, type_): def update_validator(self, type_) -> None:
if type_ == "int": if type_ == "int":
self.value_input.setValidator(QIntValidator()) self.value_input.setValidator(QIntValidator())
elif type_ == "float": elif type_ == "float":

View File

@ -6325,7 +6325,7 @@ def __init__(self):
# fmt: on # fmt: on
def set_language(lang_code): def set_language(lang_code) -> None:
# Globals # Globals
global WINDOW_TITLE, RAM_USAGE, CPU_USAGE, BACKEND, REFRESH_BACKENDS, MODELS_PATH, OUTPUT_PATH, LOGS_PATH global WINDOW_TITLE, RAM_USAGE, CPU_USAGE, BACKEND, REFRESH_BACKENDS, MODELS_PATH, OUTPUT_PATH, LOGS_PATH
global BROWSE, AVAILABLE_MODELS, QUANTIZATION_TYPE, ALLOW_REQUANTIZE, LEAVE_OUTPUT_TENSOR, PURE, IMATRIX global BROWSE, AVAILABLE_MODELS, QUANTIZATION_TYPE, ALLOW_REQUANTIZE, LEAVE_OUTPUT_TENSOR, PURE, IMATRIX

View File

@ -5,7 +5,7 @@
class Logger: class Logger:
def __init__(self, name, log_dir): def __init__(self, name, log_dir) -> None:
self.logger = logging.getLogger(name) self.logger = logging.getLogger(name)
self.logger.setLevel(logging.DEBUG) self.logger.setLevel(logging.DEBUG)
@ -34,17 +34,17 @@ def __init__(self, name, log_dir):
self.logger.addHandler(console_handler) self.logger.addHandler(console_handler)
self.logger.addHandler(file_handler) self.logger.addHandler(file_handler)
def debug(self, message): def debug(self, message) -> None:
self.logger.debug(message) self.logger.debug(message)
def info(self, message): def info(self, message) -> None:
self.logger.info(message) self.logger.info(message)
def warning(self, message): def warning(self, message) -> None:
self.logger.warning(message) self.logger.warning(message)
def error(self, message): def error(self, message) -> None:
self.logger.error(message) self.logger.error(message)
def critical(self, message): def critical(self, message) -> None:
self.logger.critical(message) self.logger.critical(message)

View File

@ -2,7 +2,7 @@
class ModelInfoDialog(QDialog): class ModelInfoDialog(QDialog):
def __init__(self, model_info, parent=None): def __init__(self, model_info, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
self.setWindowTitle("Model Information") self.setWindowTitle("Model Information")
self.setGeometry(200, 200, 600, 400) self.setGeometry(200, 200, 600, 400)
@ -21,7 +21,7 @@ def __init__(self, model_info, parent=None):
self.setLayout(layout) self.setLayout(layout)
def format_model_info(self, model_info): def format_model_info(self, model_info) -> str:
html = "<h2>Model Information</h2>" html = "<h2>Model Information</h2>"
html += f"<p><b>Architecture:</b> {model_info.get('architecture', 'N/A')}</p>" html += f"<p><b>Architecture:</b> {model_info.get('architecture', 'N/A')}</p>"
html += f"<p><b>Quantization Type:</b> {model_info.get('quantization_type', 'N/A')}</p>" html += f"<p><b>Quantization Type:</b> {model_info.get('quantization_type', 'N/A')}</p>"

View File

@ -15,7 +15,7 @@ class QuantizationThread(QThread):
error_signal = Signal(str) error_signal = Signal(str)
model_info_signal = Signal(dict) model_info_signal = Signal(dict)
def __init__(self, command, cwd, log_file): def __init__(self, command, cwd, log_file) -> None:
super().__init__() super().__init__()
self.command = command self.command = command
self.cwd = cwd self.cwd = cwd
@ -23,7 +23,7 @@ def __init__(self, command, cwd, log_file):
self.process = None self.process = None
self.model_info = {} self.model_info = {}
def run(self): def run(self) -> None:
try: try:
# Start the subprocess # Start the subprocess
self.process = subprocess.Popen( self.process = subprocess.Popen(
@ -56,7 +56,7 @@ def run(self):
except Exception as e: except Exception as e:
self.error_signal.emit(str(e)) self.error_signal.emit(str(e))
def parse_model_info(self, line): def parse_model_info(self, line) -> None:
# Parse output for model information # Parse output for model information
if "llama_model_loader: loaded meta data with" in line: if "llama_model_loader: loaded meta data with" in line:
parts = line.split() parts = line.split()
@ -77,7 +77,7 @@ def parse_model_info(self, line):
f"{quant_type}: {tensors} tensors" f"{quant_type}: {tensors} tensors"
) )
def terminate(self): def terminate(self) -> None:
# Terminate the subprocess if it's still running # Terminate the subprocess if it's still running
if self.process: if self.process:
os.kill(self.process.pid, signal.SIGTERM) os.kill(self.process.pid, signal.SIGTERM)

View File

@ -3,7 +3,7 @@
class TaskListItem(QWidget): class TaskListItem(QWidget):
def __init__(self, task_name, log_file, show_progress_bar=True, parent=None): def __init__(self, task_name, log_file, show_progress_bar=True, parent=None) -> None:
super().__init__(parent) super().__init__(parent)
self.task_name = task_name self.task_name = task_name
self.log_file = log_file self.log_file = log_file
@ -28,7 +28,7 @@ def __init__(self, task_name, log_file, show_progress_bar=True, parent=None):
self.progress_timer.timeout.connect(self.update_progress) self.progress_timer.timeout.connect(self.update_progress)
self.progress_value = 0 self.progress_value = 0
def update_status(self, status): def update_status(self, status) -> None:
self.status = status self.status = status
self.status_label.setText(status) self.status_label.setText(status)
if status == "In Progress": if status == "In Progress":
@ -43,14 +43,14 @@ def update_status(self, status):
self.progress_timer.stop() self.progress_timer.stop()
self.progress_bar.setValue(0) self.progress_bar.setValue(0)
def set_error(self): def set_error(self) -> None:
self.status = "Error" self.status = "Error"
self.status_label.setText("Error") self.status_label.setText("Error")
self.status_label.setStyleSheet("color: red;") self.status_label.setStyleSheet("color: red;")
self.progress_bar.setRange(0, 100) self.progress_bar.setRange(0, 100)
self.progress_timer.stop() self.progress_timer.stop()
def update_progress(self, value=None): def update_progress(self, value=None) -> None:
if value is not None: if value is not None:
# Update progress bar with specific value # Update progress bar with specific value
self.progress_value = value self.progress_value = value

View File

@ -1,13 +1,13 @@
from PySide6.QtWidgets import QMessageBox from PySide6.QtWidgets import QMessageBox
from Localizations import * from Localizations import ERROR_MESSAGE, ERROR, TASK_ERROR
def show_error(logger, message): def show_error(logger, message) -> None:
logger.error(ERROR_MESSAGE.format(message)) logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(None, ERROR, message) QMessageBox.critical(None, ERROR, message)
def handle_error(logger, error_message, task_item): def handle_error(logger, error_message, task_item) -> None:
logger.error(TASK_ERROR.format(error_message)) logger.error(TASK_ERROR.format(error_message))
show_error(logger, error_message) show_error(logger, error_message)
task_item.update_status(ERROR) task_item.update_status(ERROR)

View File

@ -1,5 +1,7 @@
import os import os
import sys import sys
from typing import LiteralString, TextIO, Union
import psutil import psutil
import subprocess import subprocess
import time import time
@ -41,7 +43,7 @@
from Localizations import * from Localizations import *
def show_about(self): def show_about(self) -> None:
about_text = ( about_text = (
"AutoGGUF\n\n" "AutoGGUF\n\n"
f"Version: {AUTOGGUF_VERSION}\n\n" f"Version: {AUTOGGUF_VERSION}\n\n"
@ -50,12 +52,12 @@ def show_about(self):
QMessageBox.about(self, "About AutoGGUF", about_text) QMessageBox.about(self, "About AutoGGUF", about_text)
def ensure_directory(path): def ensure_directory(path) -> None:
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
def open_file_safe(file_path, mode="r"): def open_file_safe(file_path, mode="r") -> TextIO:
encodings = ["utf-8", "latin-1", "ascii", "utf-16"] encodings = ["utf-8", "latin-1", "ascii", "utf-16"]
for encoding in encodings: for encoding in encodings:
try: try:
@ -67,7 +69,7 @@ def open_file_safe(file_path, mode="r"):
) )
def resource_path(relative_path): def resource_path(relative_path) -> Union[LiteralString, str, bytes]:
if hasattr(sys, "_MEIPASS"): if hasattr(sys, "_MEIPASS"):
# PyInstaller path # PyInstaller path
base_path = sys._MEIPASS base_path = sys._MEIPASS

View File

@ -16,7 +16,7 @@
from Localizations import * from Localizations import *
def export_lora(self): def export_lora(self) -> None:
self.logger.info(STARTING_LORA_EXPORT) self.logger.info(STARTING_LORA_EXPORT)
try: try:
model_path = self.export_lora_model.text() model_path = self.export_lora_model.text()
@ -98,7 +98,7 @@ def export_lora(self):
show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e))) show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e)))
def delete_lora_adapter_item(self, adapter_widget): def delete_lora_adapter_item(self, adapter_widget) -> None:
self.logger.info(DELETING_LORA_ADAPTER) self.logger.info(DELETING_LORA_ADAPTER)
# Find the QListWidgetItem containing the adapter_widget # Find the QListWidgetItem containing the adapter_widget
for i in range(self.export_lora_adapters.count()): for i in range(self.export_lora_adapters.count()):
@ -108,14 +108,14 @@ def delete_lora_adapter_item(self, adapter_widget):
break break
def browse_export_lora_model(self): def browse_export_lora_model(self) -> None:
self.logger.info(BROWSING_FOR_EXPORT_LORA_MODEL_FILE) self.logger.info(BROWSING_FOR_EXPORT_LORA_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES) model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES)
if model_file: if model_file:
self.export_lora_model.setText(os.path.abspath(model_file)) self.export_lora_model.setText(os.path.abspath(model_file))
def browse_export_lora_output(self): def browse_export_lora_output(self) -> None:
self.logger.info(BROWSING_FOR_EXPORT_LORA_OUTPUT_FILE) self.logger.info(BROWSING_FOR_EXPORT_LORA_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName( output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES self, SELECT_OUTPUT_FILE, "", GGUF_FILES
@ -124,7 +124,7 @@ def browse_export_lora_output(self):
self.export_lora_output.setText(os.path.abspath(output_file)) self.export_lora_output.setText(os.path.abspath(output_file))
def add_lora_adapter(self): def add_lora_adapter(self) -> None:
self.logger.info(ADDING_LORA_ADAPTER) self.logger.info(ADDING_LORA_ADAPTER)
adapter_path, _ = QFileDialog.getOpenFileName( adapter_path, _ = QFileDialog.getOpenFileName(
self, SELECT_LORA_ADAPTER_FILE, "", LORA_FILES self, SELECT_LORA_ADAPTER_FILE, "", LORA_FILES
@ -154,7 +154,7 @@ def add_lora_adapter(self):
self.export_lora_adapters.setItemWidget(list_item, adapter_widget) self.export_lora_adapters.setItemWidget(list_item, adapter_widget)
def convert_lora(self): def convert_lora(self) -> None:
self.logger.info(STARTING_LORA_CONVERSION) self.logger.info(STARTING_LORA_CONVERSION)
try: try:
lora_input_path = self.lora_input.text() lora_input_path = self.lora_input.text()

View File

@ -5,30 +5,30 @@
from PySide6.QtCore import QTimer from PySide6.QtCore import QTimer
from PySide6.QtWidgets import QApplication from PySide6.QtWidgets import QApplication
from AutoGGUF import AutoGGUF from AutoGGUF import AutoGGUF
from flask import Flask, jsonify from flask import Flask, Response, jsonify
server = Flask(__name__) server = Flask(__name__)
def main(): def main() -> None:
@server.route("/v1/models", methods=["GET"]) @server.route("/v1/models", methods=["GET"])
def models(): def models() -> Response:
if window: if window:
return jsonify({"models": window.get_models_data()}) return jsonify({"models": window.get_models_data()})
return jsonify({"models": []}) return jsonify({"models": []})
@server.route("/v1/tasks", methods=["GET"]) @server.route("/v1/tasks", methods=["GET"])
def tasks(): def tasks() -> Response:
if window: if window:
return jsonify({"tasks": window.get_tasks_data()}) return jsonify({"tasks": window.get_tasks_data()})
return jsonify({"tasks": []}) return jsonify({"tasks": []})
@server.route("/v1/health", methods=["GET"]) @server.route("/v1/health", methods=["GET"])
def ping(): def ping() -> Response:
return jsonify({"status": "alive"}) return jsonify({"status": "alive"})
@server.route("/v1/backends", methods=["GET"]) @server.route("/v1/backends", methods=["GET"])
def get_backends(): def get_backends() -> Response:
backends = [] backends = []
for i in range(window.backend_combo.count()): for i in range(window.backend_combo.count()):
backends.append( backends.append(
@ -40,7 +40,7 @@ def get_backends():
return jsonify({"backends": backends}) return jsonify({"backends": backends})
@server.route("/v1/plugins", methods=["GET"]) @server.route("/v1/plugins", methods=["GET"])
def get_plugins(): def get_plugins() -> Response:
if window: if window:
return jsonify( return jsonify(
{ {
@ -57,7 +57,7 @@ def get_plugins():
) )
return jsonify({"plugins": []}) return jsonify({"plugins": []})
def run_flask(): def run_flask() -> None:
if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled": if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled":
server.run( server.run(
host="0.0.0.0", host="0.0.0.0",

View File

@ -4,7 +4,7 @@
from Localizations import * from Localizations import *
def save_preset(self): def save_preset(self) -> None:
self.logger.info(SAVING_PRESET) self.logger.info(SAVING_PRESET)
preset = { preset = {
"quant_types": [item.text() for item in self.quant_type.selectedItems()], "quant_types": [item.text() for item in self.quant_type.selectedItems()],
@ -33,7 +33,7 @@ def save_preset(self):
self.logger.info(PRESET_SAVED_TO.format(file_name)) self.logger.info(PRESET_SAVED_TO.format(file_name))
def load_preset(self): def load_preset(self) -> None:
self.logger.info(LOADING_PRESET) self.logger.info(LOADING_PRESET)
file_name, _ = QFileDialog.getOpenFileName(self, LOAD_PRESET, "", JSON_FILES) file_name, _ = QFileDialog.getOpenFileName(self, LOAD_PRESET, "", JSON_FILES)
if file_name: if file_name:

View File

@ -5,12 +5,12 @@
from error_handling import show_error from error_handling import show_error
def update_model_info(logger, self, model_info): def update_model_info(logger, self, model_info) -> None:
logger.debug(UPDATING_MODEL_INFO.format(model_info)) logger.debug(UPDATING_MODEL_INFO.format(model_info))
pass pass
def update_system_info(self): def update_system_info(self) -> None:
ram = psutil.virtual_memory() ram = psutil.virtual_memory()
cpu = psutil.cpu_percent() cpu = psutil.cpu_percent()
@ -28,7 +28,7 @@ def update_system_info(self):
self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu)) self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu))
def animate_bar(self, bar, target_value): def animate_bar(self, bar, target_value) -> None:
current_value = bar.value() current_value = bar.value()
difference = target_value - current_value difference = target_value - current_value
@ -42,7 +42,7 @@ def animate_bar(self, bar, target_value):
timer.start(10) # Adjust the interval for animation speed timer.start(10) # Adjust the interval for animation speed
def _animate_step(bar, target_value, step, timer): def _animate_step(bar, target_value, step, timer) -> None:
current_value = bar.value() current_value = bar.value()
new_value = current_value + step new_value = current_value + step
@ -55,11 +55,11 @@ def _animate_step(bar, target_value, step, timer):
bar.setValue(new_value) bar.setValue(new_value)
def update_download_progress(self, progress): def update_download_progress(self, progress) -> None:
self.download_progress.setValue(progress) self.download_progress.setValue(progress)
def update_cuda_backends(self): def update_cuda_backends(self) -> None:
self.logger.debug(UPDATING_CUDA_BACKENDS) self.logger.debug(UPDATING_CUDA_BACKENDS)
self.backend_combo_cuda.clear() self.backend_combo_cuda.clear()
llama_bin = os.path.abspath("llama_bin") llama_bin = os.path.abspath("llama_bin")
@ -77,23 +77,23 @@ def update_cuda_backends(self):
self.backend_combo_cuda.setEnabled(True) self.backend_combo_cuda.setEnabled(True)
def update_threads_spinbox(self, value): def update_threads_spinbox(self, value) -> None:
self.threads_spinbox.setValue(value) self.threads_spinbox.setValue(value)
def update_threads_slider(self, value): def update_threads_slider(self, value) -> None:
self.threads_slider.setValue(value) self.threads_slider.setValue(value)
def update_gpu_offload_spinbox(self, value): def update_gpu_offload_spinbox(self, value) -> None:
self.gpu_offload_spinbox.setValue(value) self.gpu_offload_spinbox.setValue(value)
def update_gpu_offload_slider(self, value): def update_gpu_offload_slider(self, value) -> None:
self.gpu_offload_slider.setValue(value) self.gpu_offload_slider.setValue(value)
def update_cuda_option(self): def update_cuda_option(self) -> None:
self.logger.debug(UPDATING_CUDA_OPTIONS) self.logger.debug(UPDATING_CUDA_OPTIONS)
asset = self.asset_combo.currentData() asset = self.asset_combo.currentData()
@ -113,7 +113,7 @@ def update_cuda_option(self):
self.update_cuda_backends() self.update_cuda_backends()
def update_assets(self): def update_assets(self) -> None:
self.logger.debug(UPDATING_ASSET_LIST) self.logger.debug(UPDATING_ASSET_LIST)
self.asset_combo.clear() self.asset_combo.clear()
release = self.release_combo.currentData() release = self.release_combo.currentData()
@ -128,6 +128,6 @@ def update_assets(self):
self.update_cuda_option() self.update_cuda_option()
def update_base_model_visibility(self, index): def update_base_model_visibility(self, index) -> None:
is_gguf = self.lora_output_type_combo.itemText(index) == "GGUF" is_gguf = self.lora_output_type_combo.itemText(index) == "GGUF"
self.base_model_wrapper.setVisible(is_gguf) self.base_model_wrapper.setVisible(is_gguf)

View File

@ -1,3 +1,5 @@
from typing import Any, Dict, List, Union
from PySide6.QtCore import Qt from PySide6.QtCore import Qt
from PySide6.QtWidgets import QFileDialog from PySide6.QtWidgets import QFileDialog
@ -9,7 +11,7 @@
from imports_and_globals import ensure_directory from imports_and_globals import ensure_directory
def get_models_data(self): def get_models_data(self) -> list[dict[str, Union[str, Any]]]:
models = [] models = []
root = self.model_tree.invisibleRootItem() root = self.model_tree.invisibleRootItem()
child_count = root.childCount() child_count = root.childCount()
@ -22,7 +24,7 @@ def get_models_data(self):
return models return models
def get_tasks_data(self): def get_tasks_data(self) -> list[dict[str, Union[int, Any]]]:
tasks = [] tasks = []
for i in range(self.task_list.count()): for i in range(self.task_list.count()):
item = self.task_list.item(i) item = self.task_list.item(i)
@ -43,7 +45,7 @@ def get_tasks_data(self):
return tasks return tasks
def browse_models(self): def browse_models(self) -> None:
self.logger.info(BROWSING_FOR_MODELS_DIRECTORY) self.logger.info(BROWSING_FOR_MODELS_DIRECTORY)
models_path = QFileDialog.getExistingDirectory(self, SELECT_MODELS_DIRECTORY) models_path = QFileDialog.getExistingDirectory(self, SELECT_MODELS_DIRECTORY)
if models_path: if models_path:
@ -52,7 +54,7 @@ def browse_models(self):
self.load_models() self.load_models()
def browse_output(self): def browse_output(self) -> None:
self.logger.info(BROWSING_FOR_OUTPUT_DIRECTORY) self.logger.info(BROWSING_FOR_OUTPUT_DIRECTORY)
output_path = QFileDialog.getExistingDirectory(self, SELECT_OUTPUT_DIRECTORY) output_path = QFileDialog.getExistingDirectory(self, SELECT_OUTPUT_DIRECTORY)
if output_path: if output_path:
@ -60,7 +62,7 @@ def browse_output(self):
ensure_directory(output_path) ensure_directory(output_path)
def browse_logs(self): def browse_logs(self) -> None:
self.logger.info(BROWSING_FOR_LOGS_DIRECTORY) self.logger.info(BROWSING_FOR_LOGS_DIRECTORY)
logs_path = QFileDialog.getExistingDirectory(self, SELECT_LOGS_DIRECTORY) logs_path = QFileDialog.getExistingDirectory(self, SELECT_LOGS_DIRECTORY)
if logs_path: if logs_path:
@ -68,7 +70,7 @@ def browse_logs(self):
ensure_directory(logs_path) ensure_directory(logs_path)
def browse_imatrix(self): def browse_imatrix(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_FILE) self.logger.info(BROWSING_FOR_IMATRIX_FILE)
imatrix_file, _ = QFileDialog.getOpenFileName( imatrix_file, _ = QFileDialog.getOpenFileName(
self, SELECT_IMATRIX_FILE, "", DAT_FILES self, SELECT_IMATRIX_FILE, "", DAT_FILES
@ -77,7 +79,7 @@ def browse_imatrix(self):
self.imatrix.setText(os.path.abspath(imatrix_file)) self.imatrix.setText(os.path.abspath(imatrix_file))
def browse_lora_input(self): def browse_lora_input(self) -> None:
self.logger.info(BROWSING_FOR_LORA_INPUT_DIRECTORY) self.logger.info(BROWSING_FOR_LORA_INPUT_DIRECTORY)
lora_input_path = QFileDialog.getExistingDirectory( lora_input_path = QFileDialog.getExistingDirectory(
self, SELECT_LORA_INPUT_DIRECTORY self, SELECT_LORA_INPUT_DIRECTORY
@ -87,7 +89,7 @@ def browse_lora_input(self):
ensure_directory(lora_input_path) ensure_directory(lora_input_path)
def browse_lora_output(self): def browse_lora_output(self) -> None:
self.logger.info(BROWSING_FOR_LORA_OUTPUT_FILE) self.logger.info(BROWSING_FOR_LORA_OUTPUT_FILE)
lora_output_file, _ = QFileDialog.getSaveFileName( lora_output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_LORA_OUTPUT_FILE, "", GGUF_AND_BIN_FILES self, SELECT_LORA_OUTPUT_FILE, "", GGUF_AND_BIN_FILES
@ -96,7 +98,7 @@ def browse_lora_output(self):
self.lora_output.setText(os.path.abspath(lora_output_file)) self.lora_output.setText(os.path.abspath(lora_output_file))
def download_llama_cpp(self): def download_llama_cpp(self) -> None:
self.logger.info(STARTING_LLAMACPP_DOWNLOAD) self.logger.info(STARTING_LLAMACPP_DOWNLOAD)
asset = self.asset_combo.currentData() asset = self.asset_combo.currentData()
if not asset: if not asset:
@ -118,7 +120,7 @@ def download_llama_cpp(self):
self.download_progress.setValue(0) self.download_progress.setValue(0)
def refresh_releases(self): def refresh_releases(self) -> None:
self.logger.info(REFRESHING_LLAMACPP_RELEASES) self.logger.info(REFRESHING_LLAMACPP_RELEASES)
try: try:
response = requests.get( response = requests.get(