refactor: add type hints

This commit is contained in:
BuildTools 2024-08-22 21:56:37 -07:00
parent d4be39a22c
commit 6e424462ab
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
17 changed files with 135 additions and 132 deletions

View File

@ -5,6 +5,7 @@
from functools import partial
from datetime import datetime
from typing import Tuple, Dict
from dotenv import load_dotenv
from PySide6.QtCore import *
from PySide6.QtGui import *
@ -33,7 +34,8 @@
class AutoGGUF(QMainWindow):
def __init__(self, args):
def __init__(self, args: List[str]) -> None:
super().__init__()
self.logger = Logger("AutoGGUF", "logs")
@ -785,7 +787,7 @@ def __init__(self, args):
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
def load_plugins(self):
def load_plugins(self) -> Dict[str, Dict[str, Any]]:
plugins = {}
plugin_dir = "plugins"
@ -844,7 +846,7 @@ def load_plugins(self):
return plugins
def apply_plugins(self):
def apply_plugins(self) -> None:
if not self.plugins:
self.logger.info(NO_PLUGINS_LOADED)
return
@ -859,7 +861,7 @@ def apply_plugins(self):
if hasattr(plugin_instance, "init") and callable(plugin_instance.init):
plugin_instance.init(self)
def check_for_updates(self):
def check_for_updates(self) -> None:
try:
response = requests.get(
"https://api.github.com/repos/leafspark/AutoGGUF/releases/latest"
@ -874,7 +876,7 @@ def check_for_updates(self):
except requests.exceptions.RequestException as e:
self.logger.warning(f"{ERROR_CHECKING_FOR_UPDATES} {e}")
def prompt_for_update(self, release):
def prompt_for_update(self, release) -> None:
update_message = QMessageBox()
update_message.setIcon(QMessageBox.Information)
update_message.setWindowTitle(UPDATE_AVAILABLE)
@ -887,7 +889,7 @@ def prompt_for_update(self, release):
if update_message.exec() == QMessageBox.StandardButton.Yes:
QDesktopServices.openUrl(QUrl(release["html_url"]))
def keyPressEvent(self, event):
def keyPressEvent(self, event) -> None:
if event.modifiers() == Qt.ControlModifier:
if (
event.key() == Qt.Key_Equal
@ -899,7 +901,7 @@ def keyPressEvent(self, event):
self.reset_size()
super().keyPressEvent(event)
def resize_window(self, larger):
def resize_window(self, larger) -> None:
factor = 1.1 if larger else 1 / 1.1
current_width = self.width()
current_height = self.height()
@ -907,10 +909,10 @@ def resize_window(self, larger):
new_height = int(current_height * factor)
self.resize(new_width, new_height)
def reset_size(self):
def reset_size(self) -> None:
self.resize(self.default_width, self.default_height)
def parse_resolution(self):
def parse_resolution(self) -> Tuple[int, int]:
res = os.environ.get("AUTOGGUF_RESOLUTION", "1650x1100")
try:
width, height = map(int, res.split("x"))
@ -920,14 +922,14 @@ def parse_resolution(self):
except (ValueError, AttributeError):
return 1650, 1100
def resizeEvent(self, event):
def resizeEvent(self, event) -> None:
super().resizeEvent(event)
path = QPainterPath()
path.addRoundedRect(self.rect(), 10, 10)
mask = QRegion(path.toFillPolygon().toPolygon())
self.setMask(mask)
def refresh_backends(self):
def refresh_backends(self) -> None:
self.logger.info(REFRESHING_BACKENDS)
llama_bin = os.path.abspath("llama_bin")
os.makedirs(llama_bin, exist_ok=True)
@ -951,7 +953,7 @@ def refresh_backends(self):
self.backend_combo.setEnabled(False)
self.logger.info(FOUND_VALID_BACKENDS.format(len(valid_backends)))
def save_task_preset(self, task_item):
def save_task_preset(self, task_item) -> None:
self.logger.info(SAVING_TASK_PRESET.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
@ -971,7 +973,7 @@ def save_task_preset(self, task_item):
)
break
def browse_base_model(self):
def browse_base_model(self) -> None:
self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message
base_model_folder = QFileDialog.getExistingDirectory(
self, SELECT_BASE_MODEL_FOLDER
@ -979,13 +981,13 @@ def browse_base_model(self):
if base_model_folder:
self.base_model_path.setText(os.path.abspath(base_model_folder))
def browse_hf_model_input(self):
def browse_hf_model_input(self) -> None:
self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir))
def browse_hf_outfile(self):
def browse_hf_outfile(self) -> None:
self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT)
outfile, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES
@ -993,7 +995,7 @@ def browse_hf_outfile(self):
if outfile:
self.hf_outfile.setText(os.path.abspath(outfile))
def convert_hf_to_gguf(self):
def convert_hf_to_gguf(self) -> None:
self.logger.info(STARTING_HF_TO_GGUF_CONVERSION)
try:
model_dir = self.hf_model_input.text()
@ -1063,7 +1065,7 @@ def convert_hf_to_gguf(self):
show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED)
def restart_task(self, task_item):
def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
@ -1083,7 +1085,7 @@ def restart_task(self, task_item):
task_item.update_status(IN_PROGRESS)
break
def lora_conversion_finished(self, thread, input_path, output_path):
def lora_conversion_finished(self, thread, input_path, output_path) -> None:
self.logger.info(LORA_CONVERSION_FINISHED)
if thread in self.quant_threads:
self.quant_threads.remove(thread)
@ -1099,7 +1101,7 @@ def lora_conversion_finished(self, thread, input_path, output_path):
except Exception as e:
self.logger.error(ERROR_MOVING_LORA_FILE.format(str(e)))
def download_finished(self, extract_dir):
def download_finished(self, extract_dir) -> None:
self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir))
self.download_button.setEnabled(True)
self.download_progress.setValue(100)
@ -1136,7 +1138,7 @@ def download_finished(self, extract_dir):
if index >= 0:
self.backend_combo.setCurrentIndex(index)
def extract_cuda_files(self, extract_dir, destination):
def extract_cuda_files(self, extract_dir, destination) -> None:
self.logger.info(EXTRACTING_CUDA_FILES.format(extract_dir, destination))
for root, dirs, files in os.walk(extract_dir):
for file in files:
@ -1145,7 +1147,7 @@ def extract_cuda_files(self, extract_dir, destination):
dest_path = os.path.join(destination, file)
shutil.copy2(source_path, dest_path)
def download_error(self, error_message):
def download_error(self, error_message) -> None:
self.logger.error(DOWNLOAD_ERROR.format(error_message))
self.download_button.setEnabled(True)
self.download_progress.setValue(0)
@ -1158,7 +1160,7 @@ def download_error(self, error_message):
if os.path.exists(partial_file):
os.remove(partial_file)
def show_task_context_menu(self, position):
def show_task_context_menu(self, position) -> None:
self.logger.debug(SHOWING_TASK_CONTEXT_MENU)
item = self.task_list.itemAt(position)
if item is not None:
@ -1185,7 +1187,7 @@ def show_task_context_menu(self, position):
context_menu.exec(self.task_list.viewport().mapToGlobal(position))
def show_task_properties(self, item):
def show_task_properties(self, item) -> None:
self.logger.debug(SHOWING_PROPERTIES_FOR_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
@ -1194,12 +1196,12 @@ def show_task_properties(self, item):
model_info_dialog.exec()
break
def toggle_gpu_offload_auto(self, state):
def toggle_gpu_offload_auto(self, state) -> None:
is_auto = state == Qt.CheckState.Checked
self.gpu_offload_slider.setEnabled(not is_auto)
self.gpu_offload_spinbox.setEnabled(not is_auto)
def cancel_task_by_item(self, item):
def cancel_task_by_item(self, item) -> None:
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
@ -1208,11 +1210,11 @@ def cancel_task_by_item(self, item):
self.quant_threads.remove(thread)
break
def cancel_task(self, item):
def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text()))
self.cancel_task_by_item(item)
def delete_task(self, item):
def delete_task(self, item) -> None:
self.logger.info(DELETING_TASK.format(item.text()))
# Cancel the task first
@ -1233,12 +1235,12 @@ def delete_task(self, item):
if task_item:
task_item.deleteLater()
def create_label(self, text, tooltip):
def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text)
label.setToolTip(tooltip)
return label
def verify_gguf(self, file_path):
def verify_gguf(self, file_path) -> bool:
try:
with open(file_path, "rb") as f:
magic = f.read(4)
@ -1246,7 +1248,7 @@ def verify_gguf(self, file_path):
except Exception:
return False
def load_models(self):
def load_models(self) -> None:
self.logger.info(LOADING_MODELS)
models_dir = self.models_input.text()
ensure_directory(models_dir)
@ -1322,7 +1324,7 @@ def load_models(self):
CONCATENATED_FILES_FOUND.format(len(concatenated_models))
)
def add_model_to_tree(self, model):
def add_model_to_tree(self, model) -> QTreeWidgetItem:
item = QTreeWidgetItem(self.model_tree)
item.setText(0, model)
if hasattr(self, "imported_models") and model in [
@ -1337,7 +1339,7 @@ def add_model_to_tree(self, model):
item.setData(0, Qt.ItemDataRole.UserRole, model)
return item
def validate_quantization_inputs(self):
def validate_quantization_inputs(self) -> None:
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
errors = []
if not self.backend_combo.currentData():
@ -1354,7 +1356,7 @@ def validate_quantization_inputs(self):
if errors:
raise ValueError("\n".join(errors))
def add_kv_override(self, override_string=None):
def add_kv_override(self, override_string=None) -> None:
entry = KVOverrideEntry()
entry.deleted.connect(self.remove_kv_override)
if override_string:
@ -1366,12 +1368,12 @@ def add_kv_override(self, override_string=None):
self.kv_override_layout.addWidget(entry)
self.kv_override_entries.append(entry)
def remove_kv_override(self, entry):
def remove_kv_override(self, entry) -> None:
self.kv_override_layout.removeWidget(entry)
self.kv_override_entries.remove(entry)
entry.deleteLater()
def quantize_model(self):
def quantize_model(self) -> None:
self.logger.info(STARTING_MODEL_QUANTIZATION)
try:
self.validate_quantization_inputs()
@ -1539,7 +1541,7 @@ def quantize_model(self):
except Exception as e:
show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e)))
def parse_progress(self, line, task_item):
def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*\].*", line)
if match:
@ -1548,13 +1550,13 @@ def parse_progress(self, line, task_item):
progress = int((current / total) * 100)
task_item.update_progress(progress)
def task_finished(self, thread, task_item):
def task_finished(self, thread, task_item) -> None:
self.logger.info(TASK_FINISHED.format(thread.log_file))
if thread in self.quant_threads:
self.quant_threads.remove(thread)
task_item.update_status(COMPLETED)
def show_task_details(self, item):
def show_task_details(self, item) -> None:
self.logger.debug(SHOWING_TASK_DETAILS_FOR.format(item.text()))
task_item = self.task_list.itemWidget(item)
if task_item:
@ -1582,7 +1584,7 @@ def show_task_details(self, item):
log_dialog.exec()
def import_model(self):
def import_model(self) -> None:
self.logger.info(IMPORTING_MODEL)
file_path, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES
@ -1609,13 +1611,13 @@ def import_model(self):
self.load_models()
self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))
def browse_imatrix_datafile(self):
def browse_imatrix_datafile(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
if datafile:
self.imatrix_datafile.setText(os.path.abspath(datafile))
def browse_imatrix_model(self):
def browse_imatrix_model(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_FILE, "", GGUF_FILES
@ -1623,7 +1625,7 @@ def browse_imatrix_model(self):
if model_file:
self.imatrix_model.setText(os.path.abspath(model_file))
def browse_imatrix_output(self):
def browse_imatrix_output(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", DAT_FILES
@ -1631,7 +1633,7 @@ def browse_imatrix_output(self):
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))
def generate_imatrix(self):
def generate_imatrix(self) -> None:
self.logger.info(STARTING_IMATRIX_GENERATION)
try:
backend_path = self.backend_combo.currentData()
@ -1692,7 +1694,7 @@ def generate_imatrix(self):
show_error(self.logger, ERROR_STARTING_IMATRIX_GENERATION.format(str(e)))
self.logger.info(IMATRIX_GENERATION_TASK_STARTED)
def closeEvent(self, event: QCloseEvent):
def closeEvent(self, event: QCloseEvent) -> None:
self.logger.info(APPLICATION_CLOSING)
if self.quant_threads:
reply = QMessageBox.question(

View File

@ -1,12 +1,9 @@
from PySide6.QtCore import QPoint
from PySide6.QtGui import QPixmap
from PySide6.QtWidgets import QHBoxLayout, QLabel, QMenuBar, QPushButton, QWidget
from imports_and_globals import resource_path
class CustomTitleBar(QWidget):
def __init__(self, parent=None):
def __init__(self, parent=None) -> None:
super().__init__(parent)
self.parent = parent
layout = QHBoxLayout(self)
@ -55,11 +52,11 @@ def __init__(self, parent=None):
self.start = QPoint(0, 0)
self.pressing = False
def mousePressEvent(self, event):
def mousePressEvent(self, event) -> None:
self.start = self.mapToGlobal(event.pos())
self.pressing = True
def mouseMoveEvent(self, event):
def mouseMoveEvent(self, event) -> None:
if self.pressing:
end = self.mapToGlobal(event.pos())
movement = end - self.start
@ -71,5 +68,5 @@ def mouseMoveEvent(self, event):
)
self.start = end
def mouseReleaseEvent(self, event):
def mouseReleaseEvent(self, event) -> None:
self.pressing = False

View File

@ -10,12 +10,12 @@ class DownloadThread(QThread):
finished_signal = Signal(str)
error_signal = Signal(str)
def __init__(self, url, save_path):
def __init__(self, url, save_path) -> None:
super().__init__()
self.url = url
self.save_path = save_path
def run(self):
def run(self) -> None:
try:
response = requests.get(self.url, stream=True)
response.raise_for_status()

View File

@ -28,7 +28,7 @@
class SimpleGraph(QGraphicsView):
def __init__(self, title, parent=None):
def __init__(self, title, parent=None) -> None:
super().__init__(parent)
self.setScene(QGraphicsScene(self))
self.setRenderHint(QPainter.RenderHint.Antialiasing)
@ -37,7 +37,7 @@ def __init__(self, title, parent=None):
self.title = title
self.data = []
def update_data(self, data):
def update_data(self, data) -> None:
self.data = data
self.scene().clear()
if not self.data:
@ -65,13 +65,13 @@ def update_data(self, data):
line.setPen(path)
self.scene().addItem(line)
def resizeEvent(self, event):
def resizeEvent(self, event) -> None:
super().resizeEvent(event)
self.update_data(self.data)
class GPUMonitor(QWidget):
def __init__(self, parent=None):
def __init__(self, parent=None) -> None:
super().__init__(parent)
self.setMinimumHeight(30)
self.setMaximumHeight(30)
@ -125,17 +125,17 @@ def __init__(self, parent=None):
if not self.handles:
self.gpu_label.setText(NO_GPU_DETECTED)
def check_for_amd_gpu(self):
def check_for_amd_gpu(self) -> None:
# This is a placeholder. Implementing AMD GPU detection would require
# platform-specific methods or additional libraries.
self.gpu_label.setText(AMD_GPU_NOT_SUPPORTED)
def change_gpu(self, index):
def change_gpu(self, index) -> None:
self.current_gpu = index
self.gpu_data.clear()
self.vram_data.clear()
def update_gpu_info(self):
def update_gpu_info(self) -> None:
if self.handles:
try:
handle = self.handles[self.current_gpu]
@ -165,11 +165,11 @@ def update_gpu_info(self):
self.gpu_bar.setValue(0)
self.gpu_label.setText(GPU_USAGE_FORMAT.format(0, 0, 0, 0))
def mouseDoubleClickEvent(self, event):
def mouseDoubleClickEvent(self, event) -> None:
if self.handles:
self.show_detailed_stats()
def show_detailed_stats(self):
def show_detailed_stats(self) -> None:
dialog = QDialog(self)
dialog.setWindowTitle(GPU_DETAILS)
dialog.setMinimumSize(800, 600)
@ -194,7 +194,7 @@ def show_detailed_stats(self):
gpu_graph = SimpleGraph(GPU_USAGE_OVER_TIME)
vram_graph = SimpleGraph(VRAM_USAGE_OVER_TIME)
def update_graph_data():
def update_graph_data() -> None:
gpu_graph.update_data(self.gpu_data)
vram_graph.update_data(self.vram_data)
@ -207,7 +207,7 @@ def update_graph_data():
dialog.exec()
def closeEvent(self, event):
def closeEvent(self, event) -> None:
if self.handles:
pynvml.nvmlShutdown()
super().closeEvent(event)

View File

@ -11,7 +11,7 @@
class KVOverrideEntry(QWidget):
deleted = Signal(QWidget)
def __init__(self, parent=None):
def __init__(self, parent=None) -> None:
super().__init__(parent)
layout = QHBoxLayout(self)
layout.setContentsMargins(0, 0, 0, 0)
@ -42,12 +42,12 @@ def __init__(self, parent=None):
# Initialize validator
self.update_validator(self.type_combo.currentText())
def delete_clicked(self):
def delete_clicked(self) -> None:
self.deleted.emit(self)
def get_override_string(
self, model_name=None, quant_type=None, output_path=None
): # Add arguments
) -> str: # Add arguments
key = self.key_input.text()
type_ = self.type_combo.currentText()
value = self.value_input.text()
@ -79,11 +79,11 @@ def get_override_string(
return f"{key}={type_}:{value}"
def get_raw_override_string(self):
def get_raw_override_string(self) -> str:
# Return the raw override string with placeholders intact
return f"{self.key_input.text()}={self.type_combo.currentText()}:{self.value_input.text()}"
def update_validator(self, type_):
def update_validator(self, type_) -> None:
if type_ == "int":
self.value_input.setValidator(QIntValidator())
elif type_ == "float":

View File

@ -6325,7 +6325,7 @@ def __init__(self):
# fmt: on
def set_language(lang_code):
def set_language(lang_code) -> None:
# Globals
global WINDOW_TITLE, RAM_USAGE, CPU_USAGE, BACKEND, REFRESH_BACKENDS, MODELS_PATH, OUTPUT_PATH, LOGS_PATH
global BROWSE, AVAILABLE_MODELS, QUANTIZATION_TYPE, ALLOW_REQUANTIZE, LEAVE_OUTPUT_TENSOR, PURE, IMATRIX

View File

@ -5,7 +5,7 @@
class Logger:
def __init__(self, name, log_dir):
def __init__(self, name, log_dir) -> None:
self.logger = logging.getLogger(name)
self.logger.setLevel(logging.DEBUG)
@ -34,17 +34,17 @@ def __init__(self, name, log_dir):
self.logger.addHandler(console_handler)
self.logger.addHandler(file_handler)
def debug(self, message):
def debug(self, message) -> None:
self.logger.debug(message)
def info(self, message):
def info(self, message) -> None:
self.logger.info(message)
def warning(self, message):
def warning(self, message) -> None:
self.logger.warning(message)
def error(self, message):
def error(self, message) -> None:
self.logger.error(message)
def critical(self, message):
def critical(self, message) -> None:
self.logger.critical(message)

View File

@ -2,7 +2,7 @@
class ModelInfoDialog(QDialog):
def __init__(self, model_info, parent=None):
def __init__(self, model_info, parent=None) -> None:
super().__init__(parent)
self.setWindowTitle("Model Information")
self.setGeometry(200, 200, 600, 400)
@ -21,7 +21,7 @@ def __init__(self, model_info, parent=None):
self.setLayout(layout)
def format_model_info(self, model_info):
def format_model_info(self, model_info) -> str:
html = "<h2>Model Information</h2>"
html += f"<p><b>Architecture:</b> {model_info.get('architecture', 'N/A')}</p>"
html += f"<p><b>Quantization Type:</b> {model_info.get('quantization_type', 'N/A')}</p>"

View File

@ -15,7 +15,7 @@ class QuantizationThread(QThread):
error_signal = Signal(str)
model_info_signal = Signal(dict)
def __init__(self, command, cwd, log_file):
def __init__(self, command, cwd, log_file) -> None:
super().__init__()
self.command = command
self.cwd = cwd
@ -23,7 +23,7 @@ def __init__(self, command, cwd, log_file):
self.process = None
self.model_info = {}
def run(self):
def run(self) -> None:
try:
# Start the subprocess
self.process = subprocess.Popen(
@ -56,7 +56,7 @@ def run(self):
except Exception as e:
self.error_signal.emit(str(e))
def parse_model_info(self, line):
def parse_model_info(self, line) -> None:
# Parse output for model information
if "llama_model_loader: loaded meta data with" in line:
parts = line.split()
@ -77,7 +77,7 @@ def parse_model_info(self, line):
f"{quant_type}: {tensors} tensors"
)
def terminate(self):
def terminate(self) -> None:
# Terminate the subprocess if it's still running
if self.process:
os.kill(self.process.pid, signal.SIGTERM)

View File

@ -3,7 +3,7 @@
class TaskListItem(QWidget):
def __init__(self, task_name, log_file, show_progress_bar=True, parent=None):
def __init__(self, task_name, log_file, show_progress_bar=True, parent=None) -> None:
super().__init__(parent)
self.task_name = task_name
self.log_file = log_file
@ -28,7 +28,7 @@ def __init__(self, task_name, log_file, show_progress_bar=True, parent=None):
self.progress_timer.timeout.connect(self.update_progress)
self.progress_value = 0
def update_status(self, status):
def update_status(self, status) -> None:
self.status = status
self.status_label.setText(status)
if status == "In Progress":
@ -43,14 +43,14 @@ def update_status(self, status):
self.progress_timer.stop()
self.progress_bar.setValue(0)
def set_error(self):
def set_error(self) -> None:
self.status = "Error"
self.status_label.setText("Error")
self.status_label.setStyleSheet("color: red;")
self.progress_bar.setRange(0, 100)
self.progress_timer.stop()
def update_progress(self, value=None):
def update_progress(self, value=None) -> None:
if value is not None:
# Update progress bar with specific value
self.progress_value = value

View File

@ -1,13 +1,13 @@
from PySide6.QtWidgets import QMessageBox
from Localizations import *
from Localizations import ERROR_MESSAGE, ERROR, TASK_ERROR
def show_error(logger, message):
def show_error(logger, message) -> None:
logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(None, ERROR, message)
def handle_error(logger, error_message, task_item):
def handle_error(logger, error_message, task_item) -> None:
logger.error(TASK_ERROR.format(error_message))
show_error(logger, error_message)
task_item.update_status(ERROR)

View File

@ -1,5 +1,7 @@
import os
import sys
from typing import LiteralString, TextIO, Union
import psutil
import subprocess
import time
@ -41,7 +43,7 @@
from Localizations import *
def show_about(self):
def show_about(self) -> None:
about_text = (
"AutoGGUF\n\n"
f"Version: {AUTOGGUF_VERSION}\n\n"
@ -50,12 +52,12 @@ def show_about(self):
QMessageBox.about(self, "About AutoGGUF", about_text)
def ensure_directory(path):
def ensure_directory(path) -> None:
if not os.path.exists(path):
os.makedirs(path)
def open_file_safe(file_path, mode="r"):
def open_file_safe(file_path, mode="r") -> TextIO:
encodings = ["utf-8", "latin-1", "ascii", "utf-16"]
for encoding in encodings:
try:
@ -67,7 +69,7 @@ def open_file_safe(file_path, mode="r"):
)
def resource_path(relative_path):
def resource_path(relative_path) -> Union[LiteralString, str, bytes]:
if hasattr(sys, "_MEIPASS"):
# PyInstaller path
base_path = sys._MEIPASS

View File

@ -16,7 +16,7 @@
from Localizations import *
def export_lora(self):
def export_lora(self) -> None:
self.logger.info(STARTING_LORA_EXPORT)
try:
model_path = self.export_lora_model.text()
@ -98,7 +98,7 @@ def export_lora(self):
show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e)))
def delete_lora_adapter_item(self, adapter_widget):
def delete_lora_adapter_item(self, adapter_widget) -> None:
self.logger.info(DELETING_LORA_ADAPTER)
# Find the QListWidgetItem containing the adapter_widget
for i in range(self.export_lora_adapters.count()):
@ -108,14 +108,14 @@ def delete_lora_adapter_item(self, adapter_widget):
break
def browse_export_lora_model(self):
def browse_export_lora_model(self) -> None:
self.logger.info(BROWSING_FOR_EXPORT_LORA_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(self, SELECT_MODEL_FILE, "", GGUF_FILES)
if model_file:
self.export_lora_model.setText(os.path.abspath(model_file))
def browse_export_lora_output(self):
def browse_export_lora_output(self) -> None:
self.logger.info(BROWSING_FOR_EXPORT_LORA_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES
@ -124,7 +124,7 @@ def browse_export_lora_output(self):
self.export_lora_output.setText(os.path.abspath(output_file))
def add_lora_adapter(self):
def add_lora_adapter(self) -> None:
self.logger.info(ADDING_LORA_ADAPTER)
adapter_path, _ = QFileDialog.getOpenFileName(
self, SELECT_LORA_ADAPTER_FILE, "", LORA_FILES
@ -154,7 +154,7 @@ def add_lora_adapter(self):
self.export_lora_adapters.setItemWidget(list_item, adapter_widget)
def convert_lora(self):
def convert_lora(self) -> None:
self.logger.info(STARTING_LORA_CONVERSION)
try:
lora_input_path = self.lora_input.text()

View File

@ -5,30 +5,30 @@
from PySide6.QtCore import QTimer
from PySide6.QtWidgets import QApplication
from AutoGGUF import AutoGGUF
from flask import Flask, jsonify
from flask import Flask, Response, jsonify
server = Flask(__name__)
def main():
def main() -> None:
@server.route("/v1/models", methods=["GET"])
def models():
def models() -> Response:
if window:
return jsonify({"models": window.get_models_data()})
return jsonify({"models": []})
@server.route("/v1/tasks", methods=["GET"])
def tasks():
def tasks() -> Response:
if window:
return jsonify({"tasks": window.get_tasks_data()})
return jsonify({"tasks": []})
@server.route("/v1/health", methods=["GET"])
def ping():
def ping() -> Response:
return jsonify({"status": "alive"})
@server.route("/v1/backends", methods=["GET"])
def get_backends():
def get_backends() -> Response:
backends = []
for i in range(window.backend_combo.count()):
backends.append(
@ -40,7 +40,7 @@ def get_backends():
return jsonify({"backends": backends})
@server.route("/v1/plugins", methods=["GET"])
def get_plugins():
def get_plugins() -> Response:
if window:
return jsonify(
{
@ -57,7 +57,7 @@ def get_plugins():
)
return jsonify({"plugins": []})
def run_flask():
def run_flask() -> None:
if os.environ.get("AUTOGGUF_SERVER", "").lower() == "enabled":
server.run(
host="0.0.0.0",

View File

@ -4,7 +4,7 @@
from Localizations import *
def save_preset(self):
def save_preset(self) -> None:
self.logger.info(SAVING_PRESET)
preset = {
"quant_types": [item.text() for item in self.quant_type.selectedItems()],
@ -33,7 +33,7 @@ def save_preset(self):
self.logger.info(PRESET_SAVED_TO.format(file_name))
def load_preset(self):
def load_preset(self) -> None:
self.logger.info(LOADING_PRESET)
file_name, _ = QFileDialog.getOpenFileName(self, LOAD_PRESET, "", JSON_FILES)
if file_name:

View File

@ -5,12 +5,12 @@
from error_handling import show_error
def update_model_info(logger, self, model_info):
def update_model_info(logger, self, model_info) -> None:
logger.debug(UPDATING_MODEL_INFO.format(model_info))
pass
def update_system_info(self):
def update_system_info(self) -> None:
ram = psutil.virtual_memory()
cpu = psutil.cpu_percent()
@ -28,7 +28,7 @@ def update_system_info(self):
self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu))
def animate_bar(self, bar, target_value):
def animate_bar(self, bar, target_value) -> None:
current_value = bar.value()
difference = target_value - current_value
@ -42,7 +42,7 @@ def animate_bar(self, bar, target_value):
timer.start(10) # Adjust the interval for animation speed
def _animate_step(bar, target_value, step, timer):
def _animate_step(bar, target_value, step, timer) -> None:
current_value = bar.value()
new_value = current_value + step
@ -55,11 +55,11 @@ def _animate_step(bar, target_value, step, timer):
bar.setValue(new_value)
def update_download_progress(self, progress):
def update_download_progress(self, progress) -> None:
self.download_progress.setValue(progress)
def update_cuda_backends(self):
def update_cuda_backends(self) -> None:
self.logger.debug(UPDATING_CUDA_BACKENDS)
self.backend_combo_cuda.clear()
llama_bin = os.path.abspath("llama_bin")
@ -77,23 +77,23 @@ def update_cuda_backends(self):
self.backend_combo_cuda.setEnabled(True)
def update_threads_spinbox(self, value):
def update_threads_spinbox(self, value) -> None:
self.threads_spinbox.setValue(value)
def update_threads_slider(self, value):
def update_threads_slider(self, value) -> None:
self.threads_slider.setValue(value)
def update_gpu_offload_spinbox(self, value):
def update_gpu_offload_spinbox(self, value) -> None:
self.gpu_offload_spinbox.setValue(value)
def update_gpu_offload_slider(self, value):
def update_gpu_offload_slider(self, value) -> None:
self.gpu_offload_slider.setValue(value)
def update_cuda_option(self):
def update_cuda_option(self) -> None:
self.logger.debug(UPDATING_CUDA_OPTIONS)
asset = self.asset_combo.currentData()
@ -113,7 +113,7 @@ def update_cuda_option(self):
self.update_cuda_backends()
def update_assets(self):
def update_assets(self) -> None:
self.logger.debug(UPDATING_ASSET_LIST)
self.asset_combo.clear()
release = self.release_combo.currentData()
@ -128,6 +128,6 @@ def update_assets(self):
self.update_cuda_option()
def update_base_model_visibility(self, index):
def update_base_model_visibility(self, index) -> None:
is_gguf = self.lora_output_type_combo.itemText(index) == "GGUF"
self.base_model_wrapper.setVisible(is_gguf)

View File

@ -1,3 +1,5 @@
from typing import Any, Dict, List, Union
from PySide6.QtCore import Qt
from PySide6.QtWidgets import QFileDialog
@ -9,7 +11,7 @@
from imports_and_globals import ensure_directory
def get_models_data(self):
def get_models_data(self) -> list[dict[str, Union[str, Any]]]:
models = []
root = self.model_tree.invisibleRootItem()
child_count = root.childCount()
@ -22,7 +24,7 @@ def get_models_data(self):
return models
def get_tasks_data(self):
def get_tasks_data(self) -> list[dict[str, Union[int, Any]]]:
tasks = []
for i in range(self.task_list.count()):
item = self.task_list.item(i)
@ -43,7 +45,7 @@ def get_tasks_data(self):
return tasks
def browse_models(self):
def browse_models(self) -> None:
self.logger.info(BROWSING_FOR_MODELS_DIRECTORY)
models_path = QFileDialog.getExistingDirectory(self, SELECT_MODELS_DIRECTORY)
if models_path:
@ -52,7 +54,7 @@ def browse_models(self):
self.load_models()
def browse_output(self):
def browse_output(self) -> None:
self.logger.info(BROWSING_FOR_OUTPUT_DIRECTORY)
output_path = QFileDialog.getExistingDirectory(self, SELECT_OUTPUT_DIRECTORY)
if output_path:
@ -60,7 +62,7 @@ def browse_output(self):
ensure_directory(output_path)
def browse_logs(self):
def browse_logs(self) -> None:
self.logger.info(BROWSING_FOR_LOGS_DIRECTORY)
logs_path = QFileDialog.getExistingDirectory(self, SELECT_LOGS_DIRECTORY)
if logs_path:
@ -68,7 +70,7 @@ def browse_logs(self):
ensure_directory(logs_path)
def browse_imatrix(self):
def browse_imatrix(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_FILE)
imatrix_file, _ = QFileDialog.getOpenFileName(
self, SELECT_IMATRIX_FILE, "", DAT_FILES
@ -77,7 +79,7 @@ def browse_imatrix(self):
self.imatrix.setText(os.path.abspath(imatrix_file))
def browse_lora_input(self):
def browse_lora_input(self) -> None:
self.logger.info(BROWSING_FOR_LORA_INPUT_DIRECTORY)
lora_input_path = QFileDialog.getExistingDirectory(
self, SELECT_LORA_INPUT_DIRECTORY
@ -87,7 +89,7 @@ def browse_lora_input(self):
ensure_directory(lora_input_path)
def browse_lora_output(self):
def browse_lora_output(self) -> None:
self.logger.info(BROWSING_FOR_LORA_OUTPUT_FILE)
lora_output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_LORA_OUTPUT_FILE, "", GGUF_AND_BIN_FILES
@ -96,7 +98,7 @@ def browse_lora_output(self):
self.lora_output.setText(os.path.abspath(lora_output_file))
def download_llama_cpp(self):
def download_llama_cpp(self) -> None:
self.logger.info(STARTING_LLAMACPP_DOWNLOAD)
asset = self.asset_combo.currentData()
if not asset:
@ -118,7 +120,7 @@ def download_llama_cpp(self):
self.download_progress.setValue(0)
def refresh_releases(self):
def refresh_releases(self) -> None:
self.logger.info(REFRESHING_LLAMACPP_RELEASES)
try:
response = requests.get(