fix: use proper status in TaskListItem

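- use the shared status constants (IN_PROGRESS, COMPLETED, CANCELED, ERROR) instead of hard-coded status strings in TaskListItem.update_status and set_error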
- make sure to pass quant_threads and Logger to TaskListItem
- remove unnecessary logging in quantize_to_fp8_dynamic.py and optimize imports
This commit is contained in:
BuildTools 2024-09-02 18:43:22 -07:00
parent a7f2dec852
commit a91f804ec1
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
3 changed files with 51 additions and 19 deletions

View File

@ -96,7 +96,6 @@ def __init__(self, args: List[str]) -> None:
self.delete_task = partial(TaskListItem.delete_task, self)
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
@ -1036,7 +1035,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
self.quant_threads.append(thread)
task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
@ -1152,7 +1157,13 @@ def convert_hf_to_gguf(self) -> None:
self.quant_threads.append(thread)
task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)
@ -1516,7 +1527,10 @@ def quantize_model(self) -> None:
self.quant_threads.append(thread)
task_item = TaskListItem(
QUANTIZING_MODEL_TO.format(model_name, quant_type), log_file
QUANTIZING_MODEL_TO.format(model_name, quant_type),
log_file,
show_properties=True,
logger=self.logger,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
@ -1687,7 +1701,13 @@ def generate_imatrix(self) -> None:
task_name = GENERATING_IMATRIX_FOR.format(
os.path.basename(self.imatrix_model.text())
)
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
task_item = TaskListItem(
task_name,
log_file,
show_progress_bar=False,
logger=self.logger,
quant_threads=self.quant_threads,
)
list_item = QListWidgetItem(self.task_list)
list_item.setSizeHint(task_item.sizeHint())
self.task_list.addItem(list_item)

View File

@ -1,3 +1,5 @@
from typing import List
from PySide6.QtCore import *
from PySide6.QtGui import QAction
from PySide6.QtWidgets import *
@ -15,19 +17,34 @@
SHOWING_PROPERTIES_FOR_TASK,
DELETE,
RESTART,
IN_PROGRESS,
ERROR,
)
from ModelInfoDialog import ModelInfoDialog
from QuantizationThread import QuantizationThread
from Logger import Logger
class TaskListItem(QWidget):
def __init__(
self, task_name, log_file, show_progress_bar=True, parent=None
self,
task_name,
log_file,
show_progress_bar=True,
parent=None,
show_properties=False,
logger=Logger,
quant_threads=List[QuantizationThread],
) -> None:
super().__init__(parent)
self.quant_threads = quant_threads
self.task_name = task_name
self.log_file = log_file
self.logger = logger
self.show_properties = show_properties
self.status = "Pending"
layout = QHBoxLayout(self)
self.task_label = QLabel(task_name)
self.progress_bar = QProgressBar()
self.progress_bar.setRange(0, 100)
@ -84,7 +101,8 @@ def show_task_properties(self, item) -> None:
model_info_dialog.exec()
break
def cancel_task_by_item(self, item) -> None:
def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text()))
task_item = self.task_list.itemWidget(item)
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
@ -93,15 +111,11 @@ def cancel_task_by_item(self, item) -> None:
self.quant_threads.remove(thread)
break
def cancel_task(self, item) -> None:
self.logger.info(CANCELLING_TASK.format(item.text()))
self.cancel_task_by_item(item)
def delete_task(self, item) -> None:
self.logger.info(DELETING_TASK.format(item.text()))
# Cancel the task first
self.cancel_task_by_item(item)
self.cancel_task(item)
reply = QMessageBox.question(
self,
@ -121,21 +135,21 @@ def delete_task(self, item) -> None:
def update_status(self, status) -> None:
self.status = status
self.status_label.setText(status)
if status == "In Progress":
if status == IN_PROGRESS:
# Only start timer if showing percentage progress
if self.progress_bar.isVisible():
self.progress_bar.setRange(0, 100)
self.progress_timer.start(100)
elif status == "Completed":
elif status == COMPLETED:
self.progress_timer.stop()
self.progress_bar.setValue(100)
elif status == "Canceled":
elif status == CANCELED:
self.progress_timer.stop()
self.progress_bar.setValue(0)
def set_error(self) -> None:
self.status = "Error"
self.status_label.setText("Error")
self.status = ERROR
self.status_label.setText(ERROR)
self.status_label.setStyleSheet("color: red;")
self.progress_bar.setRange(0, 100)
self.progress_timer.stop()

View File

@ -9,7 +9,6 @@
import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from Logger import Logger
# https://github.com/neuralmagic/AutoFP8
@ -544,7 +543,6 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
print("Starting fp8 dynamic quantization")
# Define quantization config with static activation scales
quantize_config = BaseQuantizeConfig(
quant_method="fp8", activation_scheme="dynamic"