mirror of https://github.com/leafspark/AutoGGUF
fix: use proper status in TaskListItem
- use proper status in TaskListItem - make sure to pass quant_threads and Logger to TaskListItem - remove unnecessary logging in quantize_to_fp8_dynamic.py and optimize imports
This commit is contained in:
parent
a7f2dec852
commit
a91f804ec1
|
@ -96,7 +96,6 @@ def __init__(self, args: List[str]) -> None:
|
||||||
self.delete_task = partial(TaskListItem.delete_task, self)
|
self.delete_task = partial(TaskListItem.delete_task, self)
|
||||||
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
|
self.show_task_context_menu = partial(TaskListItem.show_task_context_menu, self)
|
||||||
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
|
self.show_task_properties = partial(TaskListItem.show_task_properties, self)
|
||||||
self.cancel_task_by_item = partial(TaskListItem.cancel_task_by_item, self)
|
|
||||||
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
|
self.toggle_gpu_offload_auto = partial(ui_update.toggle_gpu_offload_auto, self)
|
||||||
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
|
self.update_threads_spinbox = partial(ui_update.update_threads_spinbox, self)
|
||||||
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
|
self.update_threads_slider = partial(ui_update.update_threads_slider, self)
|
||||||
|
@ -1036,7 +1035,13 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
|
||||||
self.quant_threads.append(thread)
|
self.quant_threads.append(thread)
|
||||||
|
|
||||||
task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
|
task_name = f"Quantizing {os.path.basename(model_dir)} with AutoFP8"
|
||||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
task_item = TaskListItem(
|
||||||
|
task_name,
|
||||||
|
log_file,
|
||||||
|
show_progress_bar=False,
|
||||||
|
logger=self.logger,
|
||||||
|
quant_threads=self.quant_threads,
|
||||||
|
)
|
||||||
list_item = QListWidgetItem(self.task_list)
|
list_item = QListWidgetItem(self.task_list)
|
||||||
list_item.setSizeHint(task_item.sizeHint())
|
list_item.setSizeHint(task_item.sizeHint())
|
||||||
self.task_list.addItem(list_item)
|
self.task_list.addItem(list_item)
|
||||||
|
@ -1152,7 +1157,13 @@ def convert_hf_to_gguf(self) -> None:
|
||||||
self.quant_threads.append(thread)
|
self.quant_threads.append(thread)
|
||||||
|
|
||||||
task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
|
task_name = CONVERTING_TO_GGUF.format(os.path.basename(model_dir))
|
||||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
task_item = TaskListItem(
|
||||||
|
task_name,
|
||||||
|
log_file,
|
||||||
|
show_progress_bar=False,
|
||||||
|
logger=self.logger,
|
||||||
|
quant_threads=self.quant_threads,
|
||||||
|
)
|
||||||
list_item = QListWidgetItem(self.task_list)
|
list_item = QListWidgetItem(self.task_list)
|
||||||
list_item.setSizeHint(task_item.sizeHint())
|
list_item.setSizeHint(task_item.sizeHint())
|
||||||
self.task_list.addItem(list_item)
|
self.task_list.addItem(list_item)
|
||||||
|
@ -1516,7 +1527,10 @@ def quantize_model(self) -> None:
|
||||||
self.quant_threads.append(thread)
|
self.quant_threads.append(thread)
|
||||||
|
|
||||||
task_item = TaskListItem(
|
task_item = TaskListItem(
|
||||||
QUANTIZING_MODEL_TO.format(model_name, quant_type), log_file
|
QUANTIZING_MODEL_TO.format(model_name, quant_type),
|
||||||
|
log_file,
|
||||||
|
show_properties=True,
|
||||||
|
logger=self.logger,
|
||||||
)
|
)
|
||||||
list_item = QListWidgetItem(self.task_list)
|
list_item = QListWidgetItem(self.task_list)
|
||||||
list_item.setSizeHint(task_item.sizeHint())
|
list_item.setSizeHint(task_item.sizeHint())
|
||||||
|
@ -1687,7 +1701,13 @@ def generate_imatrix(self) -> None:
|
||||||
task_name = GENERATING_IMATRIX_FOR.format(
|
task_name = GENERATING_IMATRIX_FOR.format(
|
||||||
os.path.basename(self.imatrix_model.text())
|
os.path.basename(self.imatrix_model.text())
|
||||||
)
|
)
|
||||||
task_item = TaskListItem(task_name, log_file, show_progress_bar=False)
|
task_item = TaskListItem(
|
||||||
|
task_name,
|
||||||
|
log_file,
|
||||||
|
show_progress_bar=False,
|
||||||
|
logger=self.logger,
|
||||||
|
quant_threads=self.quant_threads,
|
||||||
|
)
|
||||||
list_item = QListWidgetItem(self.task_list)
|
list_item = QListWidgetItem(self.task_list)
|
||||||
list_item.setSizeHint(task_item.sizeHint())
|
list_item.setSizeHint(task_item.sizeHint())
|
||||||
self.task_list.addItem(list_item)
|
self.task_list.addItem(list_item)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from PySide6.QtCore import *
|
from PySide6.QtCore import *
|
||||||
from PySide6.QtGui import QAction
|
from PySide6.QtGui import QAction
|
||||||
from PySide6.QtWidgets import *
|
from PySide6.QtWidgets import *
|
||||||
|
@ -15,19 +17,34 @@
|
||||||
SHOWING_PROPERTIES_FOR_TASK,
|
SHOWING_PROPERTIES_FOR_TASK,
|
||||||
DELETE,
|
DELETE,
|
||||||
RESTART,
|
RESTART,
|
||||||
|
IN_PROGRESS,
|
||||||
|
ERROR,
|
||||||
)
|
)
|
||||||
from ModelInfoDialog import ModelInfoDialog
|
from ModelInfoDialog import ModelInfoDialog
|
||||||
|
from QuantizationThread import QuantizationThread
|
||||||
|
from Logger import Logger
|
||||||
|
|
||||||
|
|
||||||
class TaskListItem(QWidget):
|
class TaskListItem(QWidget):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, task_name, log_file, show_progress_bar=True, parent=None
|
self,
|
||||||
|
task_name,
|
||||||
|
log_file,
|
||||||
|
show_progress_bar=True,
|
||||||
|
parent=None,
|
||||||
|
show_properties=False,
|
||||||
|
logger=Logger,
|
||||||
|
quant_threads=List[QuantizationThread],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(parent)
|
super().__init__(parent)
|
||||||
|
self.quant_threads = quant_threads
|
||||||
self.task_name = task_name
|
self.task_name = task_name
|
||||||
self.log_file = log_file
|
self.log_file = log_file
|
||||||
|
self.logger = logger
|
||||||
|
self.show_properties = show_properties
|
||||||
self.status = "Pending"
|
self.status = "Pending"
|
||||||
layout = QHBoxLayout(self)
|
layout = QHBoxLayout(self)
|
||||||
|
|
||||||
self.task_label = QLabel(task_name)
|
self.task_label = QLabel(task_name)
|
||||||
self.progress_bar = QProgressBar()
|
self.progress_bar = QProgressBar()
|
||||||
self.progress_bar.setRange(0, 100)
|
self.progress_bar.setRange(0, 100)
|
||||||
|
@ -84,7 +101,8 @@ def show_task_properties(self, item) -> None:
|
||||||
model_info_dialog.exec()
|
model_info_dialog.exec()
|
||||||
break
|
break
|
||||||
|
|
||||||
def cancel_task_by_item(self, item) -> None:
|
def cancel_task(self, item) -> None:
|
||||||
|
self.logger.info(CANCELLING_TASK.format(item.text()))
|
||||||
task_item = self.task_list.itemWidget(item)
|
task_item = self.task_list.itemWidget(item)
|
||||||
for thread in self.quant_threads:
|
for thread in self.quant_threads:
|
||||||
if thread.log_file == task_item.log_file:
|
if thread.log_file == task_item.log_file:
|
||||||
|
@ -93,15 +111,11 @@ def cancel_task_by_item(self, item) -> None:
|
||||||
self.quant_threads.remove(thread)
|
self.quant_threads.remove(thread)
|
||||||
break
|
break
|
||||||
|
|
||||||
def cancel_task(self, item) -> None:
|
|
||||||
self.logger.info(CANCELLING_TASK.format(item.text()))
|
|
||||||
self.cancel_task_by_item(item)
|
|
||||||
|
|
||||||
def delete_task(self, item) -> None:
|
def delete_task(self, item) -> None:
|
||||||
self.logger.info(DELETING_TASK.format(item.text()))
|
self.logger.info(DELETING_TASK.format(item.text()))
|
||||||
|
|
||||||
# Cancel the task first
|
# Cancel the task first
|
||||||
self.cancel_task_by_item(item)
|
self.cancel_task(item)
|
||||||
|
|
||||||
reply = QMessageBox.question(
|
reply = QMessageBox.question(
|
||||||
self,
|
self,
|
||||||
|
@ -121,21 +135,21 @@ def delete_task(self, item) -> None:
|
||||||
def update_status(self, status) -> None:
|
def update_status(self, status) -> None:
|
||||||
self.status = status
|
self.status = status
|
||||||
self.status_label.setText(status)
|
self.status_label.setText(status)
|
||||||
if status == "In Progress":
|
if status == IN_PROGRESS:
|
||||||
# Only start timer if showing percentage progress
|
# Only start timer if showing percentage progress
|
||||||
if self.progress_bar.isVisible():
|
if self.progress_bar.isVisible():
|
||||||
self.progress_bar.setRange(0, 100)
|
self.progress_bar.setRange(0, 100)
|
||||||
self.progress_timer.start(100)
|
self.progress_timer.start(100)
|
||||||
elif status == "Completed":
|
elif status == COMPLETED:
|
||||||
self.progress_timer.stop()
|
self.progress_timer.stop()
|
||||||
self.progress_bar.setValue(100)
|
self.progress_bar.setValue(100)
|
||||||
elif status == "Canceled":
|
elif status == CANCELED:
|
||||||
self.progress_timer.stop()
|
self.progress_timer.stop()
|
||||||
self.progress_bar.setValue(0)
|
self.progress_bar.setValue(0)
|
||||||
|
|
||||||
def set_error(self) -> None:
|
def set_error(self) -> None:
|
||||||
self.status = "Error"
|
self.status = ERROR
|
||||||
self.status_label.setText("Error")
|
self.status_label.setText(ERROR)
|
||||||
self.status_label.setStyleSheet("color: red;")
|
self.status_label.setStyleSheet("color: red;")
|
||||||
self.progress_bar.setRange(0, 100)
|
self.progress_bar.setRange(0, 100)
|
||||||
self.progress_timer.stop()
|
self.progress_timer.stop()
|
||||||
|
|
|
@ -9,7 +9,6 @@
|
||||||
import tqdm
|
import tqdm
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
|
||||||
from Logger import Logger
|
|
||||||
|
|
||||||
# https://github.com/neuralmagic/AutoFP8
|
# https://github.com/neuralmagic/AutoFP8
|
||||||
|
|
||||||
|
@ -544,7 +543,6 @@ def get_kv_cache_quant_layers(model, kv_cache_quant_targets: Tuple[str]) -> List
|
||||||
|
|
||||||
|
|
||||||
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
|
def quantize_to_fp8_dynamic(input_model_dir: str, output_model_dir: str) -> None:
|
||||||
print("Starting fp8 dynamic quantization")
|
|
||||||
# Define quantization config with static activation scales
|
# Define quantization config with static activation scales
|
||||||
quantize_config = BaseQuantizeConfig(
|
quantize_config = BaseQuantizeConfig(
|
||||||
quant_method="fp8", activation_scheme="dynamic"
|
quant_method="fp8", activation_scheme="dynamic"
|
||||||
|
|
Loading…
Reference in New Issue