mirror of https://github.com/leafspark/AutoGGUF
fix: use proper status in TaskListItem
- use proper status in TaskListItem - make sure to pass quant_threads and Logger to TaskListItem - remove unnecessary logging in quantize_to_fp8_dynamic.py and optimize imports
This commit is contained in:
parent
a7f2dec852
commit
a91f804ec1
|
@ -96,7 +96,6 @@ def __init__(self, args: List[str]) -> None:
|
|||
self.delete_task = partial(TaskListItem.delete_task, self)
|
||||
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
|
||||
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
|
||||
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
|
||||
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
|
||||
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
|
||||
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
|
||||
|
@ -1036,7 +1035,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
|
|||
self.quant_threads.append(thread)
|
||||
|
||||
task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
|
||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
||||
task_item = TaskListItem(
|
||||
task_name,
|
||||
log_file,
|
||||
show_progress_bar=False,
|
||||
logger=self.logger,
|
||||
quant_threads=self.quant_threads,
|
||||
)
|
||||
list_item = QListWidgetItem(self.task_list)
|
||||
list_item.setSizeHint(task_item.sizeHint())
|
||||
self.task_list.addItem(list_item)
|
||||
|
@ -1152,7 +1157,13 @@ def convert_hf_to_gguf(self) -> None:
|
|||
self.quant_threads.append(thread)
|
||||
|
||||
task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
|
||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
||||
task_item = TaskListItem(
|
||||
task_name,
|
||||
log_file,
|
||||
show_progress_bar=False,
|
||||
logger=self.logger,
|
||||
quant_threads=self.quant_threads,
|
||||
)
|
||||
list_item = QListWidgetItem(self.task_list)
|
||||
list_item.setSizeHint(task_item.sizeHint())
|
||||
self.task_list.addItem(list_item)
|
||||
|
@ -1516,7 +1527,10 @@ def quantize_model(self) -> None:
|
|||
self.quant_threads.append(thread)
|
||||
|
||||
task_item = TaskListItem(
|
||||
QUANTIZING_MODEL_TO.format(model_name, quant_type), log_file
|
||||
QUANTIZING_MODEL_TO.format(model_name, quant_type),
|
||||
log_file,
|
||||
show_properties=True,
|
||||
logger=self.logger,
|
||||
)
|
||||
list_item = QListWidgetItem(self.task_list)
|
||||
list_item.setSizeHint(task_item.sizeHint())
|
||||
|
@ -1687,7 +1701,13 @@ def generate_imatrix(self) -> None:
|
|||
task_name = GENERATING_IMATRIX_FOR.format(
|
||||
os.path.basename(self.imatrix_model.text())
|
||||
)
|
||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
||||
task_item = TaskListItem(
|
||||
task_name,
|
||||
log_file,
|
||||
show_progress_bar=False,
|
||||
logger=self.logger,
|
||||
quant_threads=self.quant_threads,
|
||||
)
|
||||
list_item = QListWidgetItem(self.task_list)
|
||||
list_item.setSizeHint(task_item.sizeHint())
|
||||
self.task_list.addItem(list_item)
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
from typing import List
|
||||
|
||||
from PySide6.QtCore import *
|
||||
from PySide6.QtGui import QAction
|
||||
from PySide6.QtWidgets import *
|
||||
|
@ -15,19 +17,34 @@
|
|||
SHOWING_PROPERTIES_FOR_TASK,
|
||||
DELETE,
|
||||
RESTART,
|
||||
IN_PROGRESS,
|
||||
ERROR,
|
||||
)
|
||||
from ModelInfoDialog import ModelInfoDialog
|
||||
from QuantizationThread import QuantizationThread
|
||||
from Logger import Logger
|
||||
|
||||
|
||||
class TaskListItem(QWidget):
|
||||
def __init__(
|
||||
self, task_name, log_file, show_progress_bar=True, parent=None
|
||||
self,
|
||||
task_name,
|
||||
log_file,
|
||||
show_progress_bar=True,
|
||||
parent=None,
|
||||
show_properties=False,
|
||||
logger=Logger,
|
||||
quant_threads=List[QuantizationThread],
|
||||
) -> None:
|
||||
super().__init__(parent)
|
||||
self.quant_threads = quant_threads
|
||||
self.task_name = task_name
|
||||
self.log_file = log_file
|
||||
self.logger = logger
|
||||
self.show_properties = show_properties
|
||||
self.status = "Pending"
|
||||
layout = QHBoxLayout(self)
|
||||
|
||||
self.task_label = QLabel(task_name)
|
||||
self.progress_bar = QProgressBar()
|
||||
self.progress_bar.setRange(0, 100)
|
||||
|
@ -84,7 +101,8 @@ def show_task_properties(self, item) -> None:
|
|||
model_info_dialog.exec()
|
||||
break
|
||||
|
||||
def cancel_task_by_item(self, item) -> None:
|
||||
def cancel_task(self, item) -> None:
|
||||
self.logger.info(CANCELLING_TASK.format(item.text()))
|
||||
task_item = self.task_list.itemWidget(item)
|
||||
for thread in self.quant_threads:
|
||||
if thread.log_file == task_item.log_file:
|
||||
|
@ -93,15 +111,11 @@ def cancel_task_by_item(self, item) -> None:
|
|||
self.quant_threads.remove(thread)
|
||||
break
|
||||
|
||||
def cancel_task(self, item) -> None:
|
||||
self.logger.info(CANCELLING_TASK.format(item.text()))
|
||||
self.cancel_task_by_item(item)
|
||||
|
||||
def delete_task(self, item) -> None:
|
||||
self.logger.info(DELETING_TASK.format(item.text()))
|
||||
|
||||
# Cancel the task first
|
||||
self.cancel_task_by_item(item)
|
||||
self.cancel_task(item)
|
||||
|
||||
reply = QMessageBox.question(
|
||||
self,
|
||||
|
@ -121,21 +135,21 @@ def delete_task(self, item) -> None:
|
|||
def update_status(self, status) -> None:
|
||||
self.status = status
|
||||
self.status_label.setText(status)
|
||||
if status == "In Progress":
|
||||
if status == IN_PROGRESS:
|
||||
# Only start timer if showing percentage progress
|
||||
if self.progress_bar.isVisible():
|
||||
self.progress_bar.setRange(0, 100)
|
||||
self.progress_timer.start(100)
|
||||
elif status == "Completed":
|
||||
elif status == COMPLETED:
|
||||
self.progress_timer.stop()
|
||||
self.progress_bar.setValue(100)
|
||||
elif status == "Canceled":
|
||||
elif status == CANCELED:
|
||||
self.progress_timer.stop()
|
||||
self.progress_bar.setValue(0)
|
||||
|
||||
def set_error(self) -> None:
|
||||
self.status = "Error"
|
||||
self.status_label.setText("Error")
|
||||
self.status = ERROR
|
||||
self.status_label.setText(ERROR)
|
||||
self.status_label.setStyleSheet("color: red;")
|
||||
self.progress_bar.setRange(0, 100)
|
||||
self.progress_timer.stop()
|
||||
|
|
|
@ -9,7 +9,6 @@
|
|||
import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from Logger import Logger
|
||||
|
||||
# https://github.com/neuralmagic/AutoFP8
|
||||
|
||||
|
@ -544,7 +543,6 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
|
|||
|
||||
|
||||
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
|
||||
print("Starting fp8 dynamic quantization")
|
||||
# Define quantization config with static activation scales
|
||||
quantize_config = BaseQuantizeConfig(
|
||||
quant_method="fp8", activation_scheme="dynamic"
|
||||
|
|
Loading…
Reference in New Issue