feat(core): add manual model import

- allow importing models from any directory on the system
- add args as an AutoGGUF constructor parameter
This commit is contained in:
BuildTools 2024-08-22 15:39:08 -07:00
parent 89d3762317
commit 88875e3d67
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
3 changed files with 57 additions and 6 deletions

View File

@ -32,7 +32,7 @@
class AutoGGUF(QMainWindow):
def __init__(self):
def __init__(self, args):
super().__init__()
self.logger = Logger("AutoGGUF", "logs")
@ -272,6 +272,11 @@ def __init__(self):
refresh_models_button.clicked.connect(self.load_models)
left_layout.addWidget(refresh_models_button)
# Import Model button
import_model_button = QPushButton(IMPORT_MODEL)
import_model_button.clicked.connect(self.import_model)
left_layout.addWidget(import_model_button)
# Quantization options
quant_options_scroll = QScrollArea()
quant_options_widget = QWidget()
@ -1163,6 +1168,7 @@ def load_models(self):
# Regex pattern to match sharded model filenames
shard_pattern = re.compile(r"(.*)-(\d+)-of-(\d+)\.gguf$")
# Load models from the models directory
for file in os.listdir(models_dir):
if file.endswith(".gguf"):
match = shard_pattern.match(file)
@ -1175,11 +1181,17 @@ def load_models(self):
else:
single_models.append(file)
# Add sharded models
# Add imported models
if hasattr(self, "imported_models"):
for imported_model in self.imported_models:
file_name = os.path.basename(imported_model)
if file_name not in single_models:
single_models.append(file_name)
# Add sharded models to the tree
for base_name, shards in sharded_models.items():
parent_item = QTreeWidgetItem(self.model_tree)
parent_item.setText(0, f"{base_name} ({SHARDED})")
# Sort shards by shard number and get the first one
first_shard = sorted(shards, key=lambda x: x[0])[0][1]
parent_item.setData(0, Qt.ItemDataRole.UserRole, first_shard)
for _, shard_file in sorted(shards):
@ -1187,11 +1199,20 @@ def load_models(self):
child_item.setText(0, shard_file)
child_item.setData(0, Qt.ItemDataRole.UserRole, shard_file)
# Add single models
# Add single models to the tree
for model in sorted(single_models):
item = QTreeWidgetItem(self.model_tree)
item.setText(0, model)
item.setData(0, Qt.ItemDataRole.UserRole, model)
if hasattr(self, "imported_models") and model in [
os.path.basename(m) for m in self.imported_models
]:
full_path = next(
m for m in self.imported_models if os.path.basename(m) == model
)
item.setData(0, Qt.ItemDataRole.UserRole, full_path)
item.setToolTip(0, IMPORTED_MODEL_TOOLTIP.format(full_path))
else:
item.setData(0, Qt.ItemDataRole.UserRole, model)
self.model_tree.expandAll()
self.logger.info(LOADED_MODELS.format(len(single_models) + len(sharded_models)))
@ -1441,6 +1462,27 @@ def show_task_details(self, item):
log_dialog.exec()
def import_model(self):
    """Prompt for a model file and register it as an imported model.

    Opens a file dialog, asks the user to confirm the selection, records
    the chosen path in ``self.imported_models`` and refreshes the model
    tree via ``self.load_models()``. Nothing is recorded if the dialog
    yields no path or the user does not answer "Yes".

    The same path is stored at most once, so importing a file repeatedly
    does not grow ``self.imported_models`` with duplicate entries.
    """
    self.logger.info(IMPORTING_MODEL)
    file_path, _ = QFileDialog.getOpenFileName(
        self, SELECT_MODEL_TO_IMPORT, "", GGUF_FILES
    )
    if not file_path:
        # No file was chosen; nothing to import.
        return

    file_name = os.path.basename(file_path)
    reply = QMessageBox.question(
        self,
        CONFIRM_IMPORT,
        IMPORT_MODEL_CONFIRMATION.format(file_name),
        QMessageBox.StandardButton.Yes | QMessageBox.StandardButton.No,
        QMessageBox.StandardButton.No,
    )
    if reply != QMessageBox.StandardButton.Yes:
        return

    # The list is created lazily on first import; other code checks for it
    # with hasattr(), so keep that contract.
    if not hasattr(self, "imported_models"):
        self.imported_models = []
    # Avoid accumulating duplicates when the same file is imported again.
    if file_path not in self.imported_models:
        self.imported_models.append(file_path)

    self.load_models()
    self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))
def browse_imatrix_datafile(self):
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)

View File

@ -25,6 +25,15 @@ def __init__(self):
self.AVAILABLE_MODELS = "Available Models:"
self.REFRESH_MODELS = "Refresh Models"
# Model Import
self.IMPORT_MODEL = "Import Model"
self.SELECT_MODEL_TO_IMPORT = "Select Model to Import"
self.CONFIRM_IMPORT = "Confirm Import"
self.IMPORT_MODEL_CONFIRMATION = "Do you want to import the model {}?"
self.MODEL_IMPORTED_SUCCESSFULLY = "Model {} imported successfully"
self.IMPORTING_MODEL = "Importing model"
self.IMPORTED_MODEL_TOOLTIP = "Imported model: {}"
# GPU Monitoring
self.GPU_USAGE = "GPU Usage:"
self.GPU_USAGE_FORMAT = "GPU: {:.1f}% | VRAM: {:.1f}% ({} MB / {} MB)"

View File

@ -49,7 +49,7 @@ def run_flask():
)
app = QApplication(sys.argv)
window = AutoGGUF()
window = AutoGGUF(sys.argv)
window.show()
# Start Flask in a separate thread after a short delay
timer = QTimer()