diff --git a/src/AutoGGUF.py b/src/AutoGGUF.py index 783c6d6..95a7e81 100644 --- a/src/AutoGGUF.py +++ b/src/AutoGGUF.py @@ -32,7 +32,7 @@ class AutoGGUF(QMainWindow): - def __init__(self): + def __init__(self, args): super().__init__() self.logger = Logger("AutoGGUF", "logs") @@ -272,6 +272,11 @@ def __init__(self): refresh_models_button.clicked.connect(self.load_models) left_layout.addWidget(refresh_models_button) + # Import Model button + import_model_button = QPushButton(IMPORT_MODEL) + import_model_button.clicked.connect(self.import_model) + left_layout.addWidget(import_model_button) + # Quantization options quant_options_scroll = QScrollArea() quant_options_widget = QWidget() @@ -1163,6 +1168,7 @@ def load_models(self): # Regex pattern to match sharded model filenames shard_pattern = re.compile(r"(.*)-(\d+)-of-(\d+)\.gguf$") + # Load models from the models directory for file in os.listdir(models_dir): if file.endswith(".gguf"): match = shard_pattern.match(file) @@ -1175,11 +1181,17 @@ def load_models(self): else: single_models.append(file) - # Add sharded models + # Add imported models + if hasattr(self, "imported_models"): + for imported_model in self.imported_models: + file_name = os.path.basename(imported_model) + if file_name not in single_models: + single_models.append(file_name) + + # Add sharded models to the tree for base_name, shards in sharded_models.items(): parent_item = QTreeWidgetItem(self.model_tree) parent_item.setText(0, f"{base_name} ({SHARDED})") - # Sort shards by shard number and get the first one first_shard = sorted(shards, key=lambda x: x[0])[0][1] parent_item.setData(0, Qt.ItemDataRole.UserRole, first_shard) for _, shard_file in sorted(shards): @@ -1187,11 +1199,20 @@ def load_models(self): child_item.setText(0, shard_file) child_item.setData(0, Qt.ItemDataRole.UserRole, shard_file) - # Add single models + # Add single models to the tree for model in sorted(single_models): item = QTreeWidgetItem(self.model_tree) item.setText(0, model) - item.setData(0, Qt.ItemDataRole.UserRole, model) + if hasattr(self, "imported_models") and model in [ + os.path.basename(m) for m in self.imported_models + ]: + full_path = next( + m for m in self.imported_models if os.path.basename(m) == model + ) + item.setData(0, Qt.ItemDataRole.UserRole, full_path) + item.setToolTip(0, IMPORTED_MODEL_TOOLTIP.format(full_path)) + else: + item.setData(0, Qt.ItemDataRole.UserRole, model) self.model_tree.expandAll() self.logger.info(LOADED_MODELS.format(len(single_models) + len(sharded_models))) @@ -1441,6 +1462,27 @@ def show_task_details(self, item): log_dialog.exec() + def import_model(self): + self.logger.info(IMPORTING_MODEL) + file_path, _ = QFileDialog.getOpenFileName( + self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES + ) + if file_path: + file_name = os.path.basename(file_path) + reply = QMessageBox.question( + self, + CONFIRM_IMPORT, + IMPORT_MODEL_CONFIRMATION.format(file_name), + QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No, + QMessageBox.StandardButton.No, + ) + if reply == QMessageBox.StandardButton.Yes: + if not hasattr(self, "imported_models"): + self.imported_models = [] + self.imported_models.append(file_path) + self.load_models() + self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name)) + def browse_imatrix_datafile(self): self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE) datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES) diff --git a/src/Localizations.py b/src/Localizations.py index 752b11e..df415d1 100644 --- a/src/Localizations.py +++ b/src/Localizations.py @@ -25,6 +25,15 @@ def __init__(self): self.AVAILABLE_MODELS = "Available Models:" self.REFRESH_MODELS = "Refresh Models" + # Model Import + self.IMPORT_MODEL = "Import Model" + self.SELECT_MODEL_TO_IMPORT = "Select Model to Import" + self.CONFIRM_IMPORT = "Confirm Import" + self.IMPORT_MODEL_CONFIRMATION = "Do you want to import the model {}?" + self.MODEL_IMPORTED_SUCCESSFULLY = "Model {} imported successfully" + self.IMPORTING_MODEL = "Importing model" + self.IMPORTED_MODEL_TOOLTIP = "Imported model: {}" + # GPU Monitoring self.GPU_USAGE = "GPU Usage:" self.GPU_USAGE_FORMAT = "GPU: {:.1f}% | VRAM: {:.1f}% ({} MB / {} MB)" diff --git a/src/main.py b/src/main.py index 0226cd2..e7eee9b 100644 --- a/src/main.py +++ b/src/main.py @@ -49,7 +49,7 @@ def run_flask(): ) app = QApplication(sys.argv) - window = AutoGGUF() + window = AutoGGUF(sys.argv) window.show() # Start Flask in a separate thread after a short delay timer = QTimer()