refactor: allow more tensor types

This commit is contained in:
BuildTools 2024-08-05 17:51:54 -07:00
parent eca2ecc785
commit c0635936cc
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
3 changed files with 26 additions and 31 deletions

View File

@ -237,11 +237,30 @@ def __init__(self):
self.exclude_weights, self.exclude_weights,
) )
tensor_types = [
"Q2_K",
"Q2_K_S",
"Q3_K_S",
"Q3_K_M",
"Q3_K_L",
"Q4_K_S",
"Q4_K_M",
"Q5_K_S",
"Q5_K_M",
"Q6_K",
"Q8_0",
"Q4_0",
"Q4_1",
"Q5_0",
"Q5_1",
"BF16",
"F16",
"F32",
]
self.use_output_tensor_type = QCheckBox(USE_OUTPUT_TENSOR_TYPE) self.use_output_tensor_type = QCheckBox(USE_OUTPUT_TENSOR_TYPE)
self.output_tensor_type = QComboBox() self.output_tensor_type = QComboBox()
self.output_tensor_type.addItems( self.output_tensor_type.addItems(tensor_types)
["F32", "F16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"]
)
self.output_tensor_type.setEnabled(False) self.output_tensor_type.setEnabled(False)
self.use_output_tensor_type.toggled.connect( self.use_output_tensor_type.toggled.connect(
lambda checked: self.output_tensor_type.setEnabled(checked) lambda checked: self.output_tensor_type.setEnabled(checked)
@ -256,9 +275,7 @@ def __init__(self):
self.use_token_embedding_type = QCheckBox(USE_TOKEN_EMBEDDING_TYPE) self.use_token_embedding_type = QCheckBox(USE_TOKEN_EMBEDDING_TYPE)
self.token_embedding_type = QComboBox() self.token_embedding_type = QComboBox()
self.token_embedding_type.addItems( self.token_embedding_type.addItems(tensor_types)
["F32", "F16", "Q4_0", "Q4_1", "Q5_0", "Q5_1", "Q8_0"]
)
self.token_embedding_type.setEnabled(False) self.token_embedding_type.setEnabled(False)
self.use_token_embedding_type.toggled.connect( self.use_token_embedding_type.toggled.connect(
lambda checked: self.token_embedding_type.setEnabled(checked) lambda checked: self.token_embedding_type.setEnabled(checked)

View File

@ -1,17 +1,8 @@
from PyQt6.QtWidgets import *
from PyQt6.QtCore import *
from PyQt6.QtGui import *
import os import os
import sys
import psutil
import subprocess
import time
import signal
import json
import platform
import requests
import zipfile import zipfile
from datetime import datetime
import requests
from PyQt6.QtCore import *
class DownloadThread(QThread): class DownloadThread(QThread):

View File

@ -1,17 +1,4 @@
from PyQt6.QtWidgets import * from PyQt6.QtWidgets import *
from PyQt6.QtCore import *
from PyQt6.QtGui import *
import os
import sys
import psutil
import subprocess
import time
import signal
import json
import platform
import requests
import zipfile
from datetime import datetime
class ModelInfoDialog(QDialog): class ModelInfoDialog(QDialog):