mirror of https://github.com/leafspark/AutoGGUF
feat(core): add manual model import
- allow importing models from any directory on the system - add args as AutoGGUF class parameter
This commit is contained in:
parent
89d3762317
commit
88875e3d67
|
@ -32,7 +32,7 @@
|
||||||
|
|
||||||
|
|
||||||
class AutoGGUF(QMainWindow):
|
class AutoGGUF(QMainWindow):
|
||||||
def __init__(self):
|
def __init__(self, args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.logger = Logger("AutoGGUF", "logs")
|
self.logger = Logger("AutoGGUF", "logs")
|
||||||
|
|
||||||
|
@ -272,6 +272,11 @@ def __init__(self):
|
||||||
refresh_models_button.clicked.connect(self.load_models)
|
refresh_models_button.clicked.connect(self.load_models)
|
||||||
left_layout.addWidget(refresh_models_button)
|
left_layout.addWidget(refresh_models_button)
|
||||||
|
|
||||||
|
# Import Model button
|
||||||
|
import_model_button = QPushButton(IMPORT_MODEL)
|
||||||
|
import_model_button.clicked.connect(self.import_model)
|
||||||
|
left_layout.addWidget(import_model_button)
|
||||||
|
|
||||||
# Quantization options
|
# Quantization options
|
||||||
quant_options_scroll = QScrollArea()
|
quant_options_scroll = QScrollArea()
|
||||||
quant_options_widget = QWidget()
|
quant_options_widget = QWidget()
|
||||||
|
@ -1163,6 +1168,7 @@ def load_models(self):
|
||||||
# Regex pattern to match sharded model filenames
|
# Regex pattern to match sharded model filenames
|
||||||
shard_pattern = re.compile(r"(.*)-(\d+)-of-(\d+)\.gguf$")
|
shard_pattern = re.compile(r"(.*)-(\d+)-of-(\d+)\.gguf$")
|
||||||
|
|
||||||
|
# Load models from the models directory
|
||||||
for file in os.listdir(models_dir):
|
for file in os.listdir(models_dir):
|
||||||
if file.endswith(".gguf"):
|
if file.endswith(".gguf"):
|
||||||
match = shard_pattern.match(file)
|
match = shard_pattern.match(file)
|
||||||
|
@ -1175,11 +1181,17 @@ def load_models(self):
|
||||||
else:
|
else:
|
||||||
single_models.append(file)
|
single_models.append(file)
|
||||||
|
|
||||||
# Add sharded models
|
# Add imported models
|
||||||
|
if hasattr(self, "imported_models"):
|
||||||
|
for imported_model in self.imported_models:
|
||||||
|
file_name = os.path.basename(imported_model)
|
||||||
|
if file_name not in single_models:
|
||||||
|
single_models.append(file_name)
|
||||||
|
|
||||||
|
# Add sharded models to the tree
|
||||||
for base_name, shards in sharded_models.items():
|
for base_name, shards in sharded_models.items():
|
||||||
parent_item = QTreeWidgetItem(self.model_tree)
|
parent_item = QTreeWidgetItem(self.model_tree)
|
||||||
parent_item.setText(0, f"{base_name} ({SHARDED})")
|
parent_item.setText(0, f"{base_name} ({SHARDED})")
|
||||||
# Sort shards by shard number and get the first one
|
|
||||||
first_shard = sorted(shards, key=lambda x: x[0])[0][1]
|
first_shard = sorted(shards, key=lambda x: x[0])[0][1]
|
||||||
parent_item.setData(0, Qt.ItemDataRole.UserRole, first_shard)
|
parent_item.setData(0, Qt.ItemDataRole.UserRole, first_shard)
|
||||||
for _, shard_file in sorted(shards):
|
for _, shard_file in sorted(shards):
|
||||||
|
@ -1187,10 +1199,19 @@ def load_models(self):
|
||||||
child_item.setText(0, shard_file)
|
child_item.setText(0, shard_file)
|
||||||
child_item.setData(0, Qt.ItemDataRole.UserRole, shard_file)
|
child_item.setData(0, Qt.ItemDataRole.UserRole, shard_file)
|
||||||
|
|
||||||
# Add single models
|
# Add single models to the tree
|
||||||
for model in sorted(single_models):
|
for model in sorted(single_models):
|
||||||
item = QTreeWidgetItem(self.model_tree)
|
item = QTreeWidgetItem(self.model_tree)
|
||||||
item.setText(0, model)
|
item.setText(0, model)
|
||||||
|
if hasattr(self, "imported_models") and model in [
|
||||||
|
os.path.basename(m) for m in self.imported_models
|
||||||
|
]:
|
||||||
|
full_path = next(
|
||||||
|
m for m in self.imported_models if os.path.basename(m) == model
|
||||||
|
)
|
||||||
|
item.setData(0, Qt.ItemDataRole.UserRole, full_path)
|
||||||
|
item.setToolTip(0, IMPORTED_MODEL_TOOLTIP.format(full_path))
|
||||||
|
else:
|
||||||
item.setData(0, Qt.ItemDataRole.UserRole, model)
|
item.setData(0, Qt.ItemDataRole.UserRole, model)
|
||||||
|
|
||||||
self.model_tree.expandAll()
|
self.model_tree.expandAll()
|
||||||
|
@ -1441,6 +1462,27 @@ def show_task_details(self, item):
|
||||||
|
|
||||||
log_dialog.exec()
|
log_dialog.exec()
|
||||||
|
|
||||||
|
def import_model(self):
|
||||||
|
self.logger.info(IMPORTING_MODEL)
|
||||||
|
file_path, _ = QFileDialog.getOpenFileName(
|
||||||
|
self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES
|
||||||
|
)
|
||||||
|
if file_path:
|
||||||
|
file_name = os.path.basename(file_path)
|
||||||
|
reply = QMessageBox.question(
|
||||||
|
self,
|
||||||
|
CONFIRM_IMPORT,
|
||||||
|
IMPORT_MODEL_CONFIRMATION.format(file_name),
|
||||||
|
QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
|
||||||
|
QMessageBox.StandardButton.No,
|
||||||
|
)
|
||||||
|
if reply == QMessageBox.StandardButton.Yes:
|
||||||
|
if not hasattr(self, "imported_models"):
|
||||||
|
self.imported_models = []
|
||||||
|
self.imported_models.append(file_path)
|
||||||
|
self.load_models()
|
||||||
|
self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))
|
||||||
|
|
||||||
def browse_imatrix_datafile(self):
|
def browse_imatrix_datafile(self):
|
||||||
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
|
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
|
||||||
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
|
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
|
||||||
|
|
|
@ -25,6 +25,15 @@ def __init__(self):
|
||||||
self.AVAILABLE_MODELS = "Available Models:"
|
self.AVAILABLE_MODELS = "Available Models:"
|
||||||
self.REFRESH_MODELS = "Refresh Models"
|
self.REFRESH_MODELS = "Refresh Models"
|
||||||
|
|
||||||
|
# Model Import
|
||||||
|
self.IMPORT_MODEL = "Import Model"
|
||||||
|
self.SELECT_MODEL_TO_IMPORT = "Select Model to Import"
|
||||||
|
self.CONFIRM_IMPORT = "Confirm Import"
|
||||||
|
self.IMPORT_MODEL_CONFIRMATION = "Do you want to import the model {}?"
|
||||||
|
self.MODEL_IMPORTED_SUCCESSFULLY = "Model {} imported successfully"
|
||||||
|
self.IMPORTING_MODEL = "Importing model"
|
||||||
|
self.IMPORTED_MODEL_TOOLTIP = "Imported model: {}"
|
||||||
|
|
||||||
# GPU Monitoring
|
# GPU Monitoring
|
||||||
self.GPU_USAGE = "GPU Usage:"
|
self.GPU_USAGE = "GPU Usage:"
|
||||||
self.GPU_USAGE_FORMAT = "GPU: {:.1f}% | VRAM: {:.1f}% ({} MB / {} MB)"
|
self.GPU_USAGE_FORMAT = "GPU: {:.1f}% | VRAM: {:.1f}% ({} MB / {} MB)"
|
||||||
|
|
|
@ -49,7 +49,7 @@ def run_flask():
|
||||||
)
|
)
|
||||||
|
|
||||||
app = QApplication(sys.argv)
|
app = QApplication(sys.argv)
|
||||||
window = AutoGGUF()
|
window = AutoGGUF(sys.argv)
|
||||||
window.show()
|
window.show()
|
||||||
# Start Flask in a separate thread after a short delay
|
# Start Flask in a separate thread after a short delay
|
||||||
timer = QTimer()
|
timer = QTimer()
|
||||||
|
|
Loading…
Reference in New Issue