refactor: move error handling to separate module

This commit is contained in:
BuildTools 2024-08-05 16:52:33 -07:00
parent 3ff9caabbf
commit 1feab011e4
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
2 changed files with 61 additions and 48 deletions

View File

@ -22,6 +22,7 @@
from KVOverrideEntry import KVOverrideEntry
from Logger import Logger
from localizations import *
from error_handling import show_error, handle_error
class AutoGGUF(QMainWindow):
@ -627,15 +628,15 @@ def __init__(self):
def refresh_backends(self):
self.logger.info(REFRESHING_BACKENDS)
llama_bin = os.path.abspath("llama_bin")
if not os.path.exists(llama_bin):
os.makedirs(llama_bin)
os.makedirs(llama_bin, exist_ok=True)
self.backend_combo.clear()
valid_backends = []
for item in os.listdir(llama_bin):
item_path = os.path.join(llama_bin, item)
if os.path.isdir(item_path) and "cudart-llama" not in item.lower():
valid_backends.append((item, item_path))
valid_backends = [
(item, os.path.join(llama_bin, item))
for item in os.listdir(llama_bin)
if os.path.isdir(os.path.join(llama_bin, item))
and "cudart-llama" not in item.lower()
]
if valid_backends:
for name, path in valid_backends:
@ -890,11 +891,13 @@ def convert_hf_to_gguf(self):
thread.finished_signal.connect(
lambda: self.task_finished(thread, task_item)
)
thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
thread.start()
except Exception as e:
self.show_error(ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED)
def export_lora(self):
@ -968,13 +971,15 @@ def export_lora(self):
thread.status_signal.connect(task_item.update_status)
thread.finished_signal.connect(lambda: self.task_finished(thread))
thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
thread.start()
self.logger.info(LORA_EXPORT_TASK_STARTED)
except ValueError as e:
self.show_error(str(e))
show_error(self.logger, str(e))
except Exception as e:
self.show_error(ERROR_STARTING_LORA_EXPORT.format(str(e)))
show_error(self.logger, ERROR_STARTING_LORA_EXPORT.format(str(e)))
def restart_task(self, task_item):
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
@ -989,7 +994,7 @@ def restart_task(self, task_item):
lambda: self.task_finished(new_thread)
)
new_thread.error_signal.connect(
lambda err: self.handle_error(err, task_item)
lambda err: handle_error(self.logger, err, task_item)
)
new_thread.model_info_signal.connect(self.update_model_info)
new_thread.start()
@ -1067,13 +1072,15 @@ def convert_lora(self):
thread, lora_input_path, lora_output_path
)
)
thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
thread.start()
self.logger.info(LORA_CONVERSION_TASK_STARTED)
except ValueError as e:
self.show_error(str(e))
show_error(self.logger, str(e))
except Exception as e:
self.show_error(ERROR_STARTING_LORA_CONVERSION.format(str(e)))
show_error(self.logger, ERROR_STARTING_LORA_CONVERSION.format(str(e)))
def lora_conversion_finished(self, thread, input_path, output_path):
self.logger.info(LORA_CONVERSION_FINISHED)
@ -1142,7 +1149,7 @@ def refresh_releases(self):
self.release_combo.currentIndexChanged.connect(self.update_assets)
self.update_assets()
except requests.exceptions.RequestException as e:
self.show_error(ERROR_FETCHING_RELEASES.format(str(e)))
show_error(self.logger, ERROR_FETCHING_RELEASES.format(str(e)))
def update_assets(self):
self.logger.debug(UPDATING_ASSET_LIST)
@ -1153,19 +1160,20 @@ def update_assets(self):
for asset in release["assets"]:
self.asset_combo.addItem(asset["name"], userData=asset)
else:
self.show_error(NO_ASSETS_FOUND_FOR_RELEASE.format(release["tag_name"]))
show_error(
self.logger, NO_ASSETS_FOUND_FOR_RELEASE.format(release["tag_name"])
)
self.update_cuda_option()
def download_llama_cpp(self):
self.logger.info(STARTING_LLAMACPP_DOWNLOAD)
asset = self.asset_combo.currentData()
if not asset:
self.show_error(NO_ASSET_SELECTED)
show_error(self.logger, NO_ASSET_SELECTED)
return
llama_bin = os.path.abspath("llama_bin")
if not os.path.exists(llama_bin):
os.makedirs(llama_bin)
os.makedirs(llama_bin, exist_ok=True)
save_path = os.path.join(llama_bin, asset["name"])
@ -1259,7 +1267,7 @@ def download_error(self, error_message):
self.logger.error(DOWNLOAD_ERROR.format(error_message))
self.download_button.setEnabled(True)
self.download_progress.setValue(0)
self.show_error(DOWNLOAD_FAILED.format(error_message))
show_error(self.logger, DOWNLOAD_FAILED.format(error_message))
# Clean up any partially downloaded files
asset = self.asset_combo.currentData()
@ -1450,17 +1458,6 @@ def browse_imatrix(self):
if imatrix_file:
self.imatrix.setText(os.path.abspath(imatrix_file))
def update_system_info(self):
ram = psutil.virtual_memory()
cpu = psutil.cpu_percent()
self.ram_bar.setValue(int(ram.percent))
self.ram_bar.setFormat(
RAM_USAGE_FORMAT.format(
ram.percent, ram.used // 1024 // 1024, ram.total // 1024 // 1024
)
)
self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu))
def validate_quantization_inputs(self):
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
errors = []
@ -1478,6 +1475,17 @@ def validate_quantization_inputs(self):
if errors:
raise ValueError("\n".join(errors))
def update_system_info(self):
ram = psutil.virtual_memory()
cpu = psutil.cpu_percent()
self.ram_bar.setValue(int(ram.percent))
self.ram_bar.setFormat(
RAM_USAGE_FORMAT.format(
ram.percent, ram.used // 1024 // 1024, ram.total // 1024 // 1024
)
)
self.cpu_label.setText(CPU_USAGE_FORMAT.format(cpu))
def add_kv_override(self, override_string=None):
entry = KVOverrideEntry()
entry.deleted.connect(self.remove_kv_override)
@ -1644,7 +1652,7 @@ def quantize_model(self):
lambda t=thread, ti=task_item: self.task_finished(t, ti)
)
thread.error_signal.connect(
lambda err, ti=task_item: self.handle_error(err, ti)
lambda err, ti=task_item: handle_error(self.logger, err, ti)
)
thread.model_info_signal.connect(self.update_model_info)
@ -1654,11 +1662,11 @@ def quantize_model(self):
self.logger.info(QUANTIZATION_TASK_STARTED.format(model_name))
except ValueError as e:
self.show_error(str(e))
show_error(self.logger, str(e))
except FileNotFoundError as e:
self.show_error(str(e))
show_error(self.logger, str(e))
except Exception as e:
self.show_error(ERROR_STARTING_QUANTIZATION.format(str(e)))
show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e)))
def update_model_info(self, model_info):
self.logger.debug(UPDATING_MODEL_INFO.format(model_info))
@ -1794,21 +1802,14 @@ def generate_imatrix(self):
thread.status_signal.connect(task_item.update_status)
thread.finished_signal.connect(lambda: self.task_finished(thread))
thread.error_signal.connect(lambda err: self.handle_error(err, task_item))
thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
thread.start()
except Exception as e:
self.show_error(ERROR_STARTING_IMATRIX_GENERATION.format(str(e)))
show_error(self.logger, ERROR_STARTING_IMATRIX_GENERATION.format(str(e)))
self.logger.info(IMATRIX_GENERATION_TASK_STARTED)
def show_error(self, message):
self.logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(self, ERROR, message)
def handle_error(self, error_message, task_item):
self.logger.error(TASK_ERROR.format(error_message))
self.show_error(error_message)
task_item.update_status(ERROR)
def closeEvent(self, event: QCloseEvent):
self.logger.info(APPLICATION_CLOSING)
if self.quant_threads:

12
src/error_handling.py Normal file
View File

@ -0,0 +1,12 @@
from PyQt6.QtWidgets import QMessageBox
from localizations import *
def show_error(logger, message):
logger.error(ERROR_MESSAGE.format(message))
QMessageBox.critical(None, ERROR, message)
def handle_error(logger, error_message, task_item):
logger.error(TASK_ERROR.format(error_message))
show_error(logger, error_message)
task_item.update_status(ERROR)