feat(hf): add support for repository types

- add support for repository types in HF Transfer utility
- add dequantize_gguf.py script
- improve layout of HF Upload window
This commit is contained in:
BuildTools 2024-09-22 09:48:48 -07:00
parent ac0f011784
commit c831622d6b
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
2 changed files with 143 additions and 12 deletions

View File

@ -313,38 +313,56 @@ def __init__(self, args: List[str]) -> None:
self.hf_upload_dialog.setFixedWidth(500) self.hf_upload_dialog.setFixedWidth(500)
self.hf_upload_layout = QVBoxLayout() self.hf_upload_layout = QVBoxLayout()
# Form layout for inputs
form_layout = QFormLayout()
# Repo input # Repo input
repo_layout = QHBoxLayout()
self.hf_repo_input = QLineEdit() self.hf_repo_input = QLineEdit()
repo_layout.addWidget(QLabel("Repository:")) form_layout.addRow("Repository:", self.hf_repo_input)
repo_layout.addWidget(self.hf_repo_input)
self.hf_upload_layout.addLayout(repo_layout)
# Remote path input # Remote path input
remote_path_layout = QHBoxLayout()
self.hf_remote_path_input = QLineEdit() self.hf_remote_path_input = QLineEdit()
remote_path_layout.addWidget(QLabel("Remote Path:")) form_layout.addRow("Remote Path:", self.hf_remote_path_input)
remote_path_layout.addWidget(self.hf_remote_path_input)
self.hf_upload_layout.addLayout(remote_path_layout)
# Local file/folder input # Local file/folder input
local_path_layout = QHBoxLayout() local_path_layout = QHBoxLayout()
self.hf_local_path_input = QLineEdit() self.hf_local_path_input = QLineEdit()
local_path_button = QPushButton("Browse") local_path_button = QPushButton("Browse")
local_path_button.clicked.connect(self.browse_local_path) local_path_button.clicked.connect(self.browse_local_path)
local_path_layout.addWidget(QLabel("Local Path:"))
local_path_layout.addWidget(self.hf_local_path_input) local_path_layout.addWidget(self.hf_local_path_input)
local_path_layout.addWidget(local_path_button) local_path_layout.addWidget(local_path_button)
self.hf_upload_layout.addLayout(local_path_layout) form_layout.addRow("Local Path:", local_path_layout)
self.hf_upload_layout.addLayout(form_layout)
# Upload type (file or folder) # Upload type (file or folder)
upload_type_group = QGroupBox("Upload Type")
upload_type_layout = QHBoxLayout()
self.upload_type_group = QButtonGroup() self.upload_type_group = QButtonGroup()
self.upload_type_file = QRadioButton("File") self.upload_type_file = QRadioButton("File")
self.upload_type_folder = QRadioButton("Folder") self.upload_type_folder = QRadioButton("Folder")
self.upload_type_group.addButton(self.upload_type_file) self.upload_type_group.addButton(self.upload_type_file)
self.upload_type_group.addButton(self.upload_type_folder) self.upload_type_group.addButton(self.upload_type_folder)
self.hf_upload_layout.addWidget(self.upload_type_file) upload_type_layout.addWidget(self.upload_type_file)
self.hf_upload_layout.addWidget(self.upload_type_folder) upload_type_layout.addWidget(self.upload_type_folder)
upload_type_group.setLayout(upload_type_layout)
self.hf_upload_layout.addWidget(upload_type_group)
# Repo type (dataset/space/model)
repo_type_group = QGroupBox("Repository Type")
repo_type_layout = QHBoxLayout()
self.repo_type_group = QButtonGroup()
self.repo_type_model = QRadioButton("Model")
self.repo_type_dataset = QRadioButton("Dataset")
self.repo_type_space = QRadioButton("Space")
self.repo_type_group.addButton(self.repo_type_model)
self.repo_type_group.addButton(self.repo_type_dataset)
self.repo_type_group.addButton(self.repo_type_space)
repo_type_layout.addWidget(self.repo_type_model)
repo_type_layout.addWidget(self.repo_type_dataset)
repo_type_layout.addWidget(self.repo_type_space)
repo_type_group.setLayout(repo_type_layout)
self.hf_upload_layout.addWidget(repo_type_group)
# Upload button # Upload button
upload_button = QPushButton("Upload") upload_button = QPushButton("Upload")
@ -1438,6 +1456,14 @@ def transfer_to_hf(self) -> None:
if remote_path: if remote_path:
command.append(remote_path) command.append(remote_path)
# Add repo type argument if selected
if self.repo_type_model.isChecked():
command.append("--repo-type=model")
elif self.repo_type_dataset.isChecked():
command.append("--repo-type=dataset")
elif self.repo_type_space.isChecked():
command.append("--repo-type=space")
logs_path = self.logs_input.text() logs_path = self.logs_input.text()
ensure_directory(logs_path) ensure_directory(logs_path)

105
src/dequantize_gguf.py Normal file
View File

@ -0,0 +1,105 @@
import argparse
import json
import sys
from pathlib import Path
import numpy as np
import torch
from safetensors.torch import save_file
import gguf
def dequantize_tensor(tensor):
    """Return the values of a GGUF reader tensor as a NumPy array.

    Args:
        tensor: A tensor entry as yielded by ``gguf.GGUFReader.tensors``.
            Only its ``tensor_type`` and ``data`` attributes are read.

    Returns:
        numpy.ndarray: For F32/F16/BF16 tensors whose payload is already a
        typed array, a copy of ``tensor.data`` in its stored dtype. For
        payloads held as raw ``uint8`` bytes, the array produced by
        ``gguf.quants.dequantize``. For any other typed array, the data
        cast to ``float32``.

    Raises:
        Exception: Anything raised by ``gguf.quants.dequantize`` is
            propagated unchanged (presumably ``NotImplementedError`` for
            quantization types the installed ``gguf`` cannot decode —
            TODO confirm against the installed version).
    """
    data = tensor.data

    # NOTE(review): this assumes the reader exposes block-quantized (and
    # BF16) payloads as raw uint8 bytes — confirm for the installed gguf.
    # Casting such bytes with ``astype(np.float32)`` would only convert each
    # byte to a float; it would not decode the quantization blocks, so the
    # values and the element count would both be wrong.
    if data.dtype == np.uint8:
        return gguf.quants.dequantize(data, tensor.tensor_type)

    passthrough_types = (
        gguf.GGMLQuantizationType.F32,
        gguf.GGMLQuantizationType.F16,
        gguf.GGMLQuantizationType.BF16,
    )
    if tensor.tensor_type in passthrough_types:
        # Already a typed floating-point array: copy it in its stored dtype.
        return np.array(data)

    # Remaining typed arrays are converted to float32.
    return data.astype(np.float32)
def _decode_metadata_field(field):
    """Turn one GGUF reader field into a JSON-serialisable value.

    String fields are decoded as UTF-8 text. Integer (and boolean) values are
    rendered as their decimal/``str`` representation, and anything else (for
    example floats) is returned as produced by ``ndarray.tolist()``.
    Only the part referenced by ``field.data[0]`` is used.
    """
    part = field.parts[field.data[0]]

    # Decide by the declared value type instead of guessing from the content:
    # a small integer such as 32 must not be rendered as the character " ".
    if field.types and field.types[-1] == gguf.GGUFValueType.STRING:
        # errors="replace" keeps a malformed string from aborting the export.
        return part.tobytes().decode("utf-8", errors="replace")

    value = part.tolist()
    if isinstance(value, list) and all(isinstance(item, int) for item in value):
        return "".join(str(item) for item in value)
    return value


def gguf_to_safetensors(gguf_path, safetensors_path, metadata_path=None):
    """Convert a GGUF file to SafeTensors, optionally exporting its metadata.

    Every tensor read from the GGUF file is passed through
    ``dequantize_tensor``, reshaped to the reverse of its GGUF shape and
    stored under its own name. A tensor that fails to convert is reported on
    stderr and skipped; the remaining tensors are still written.

    Args:
        gguf_path: Path of the GGUF file to read.
        safetensors_path: Path of the SafeTensors file to write.
        metadata_path: Optional path of a JSON file that receives the GGUF
            key/value metadata. Nothing is written when it is falsy.

    Raises:
        SystemExit: With status 1 when the GGUF file cannot be read or the
            SafeTensors file cannot be saved. A failure to write the
            metadata file is only reported on stderr.
    """
    try:
        reader = gguf.GGUFReader(gguf_path)
    except Exception as e:
        print(f"Error reading GGUF file: {e}", file=sys.stderr)
        sys.exit(1)

    tensors = {}
    for tensor in reader.tensors:
        try:
            dequantized_data = dequantize_tensor(tensor)
            # The stored shape is reversed before use; plain ints keep
            # reshape independent of the NumPy scalar type of the entries.
            target_shape = tuple(int(dim) for dim in reversed(tensor.shape))
            tensors[tensor.name] = torch.from_numpy(
                dequantized_data.reshape(target_shape)
            )
        except Exception as e:
            print(f"Error processing tensor {tensor.name}: {e}", file=sys.stderr)
            continue

    try:
        save_file(tensors, safetensors_path)
    except Exception as e:
        print(f"Error saving SafeTensors file: {e}", file=sys.stderr)
        sys.exit(1)

    # The metadata is only needed for the optional JSON export.
    if not metadata_path:
        return

    try:
        decoded_metadata = {
            field_name: _decode_metadata_field(field)
            for field_name, field in reader.fields.items()
            if field.data
        }
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(decoded_metadata, f, indent=4)
    except Exception as e:
        # Best effort: the tensors are already saved at this point.
        print(f"Error saving metadata file: {e}", file=sys.stderr)
def main():
    """Command-line entry point: parse arguments and run the conversion.

    Takes two positional arguments (input GGUF path, output SafeTensors
    path) and an optional ``--metadata_path``. Exits with status 1 when the
    input path does not exist; otherwise calls ``gguf_to_safetensors``.
    """
    cli = argparse.ArgumentParser(description="Convert GGUF to SafeTensors format")
    cli.add_argument("gguf_path", type=str, help="Path to the input GGUF file")
    cli.add_argument(
        "safetensors_path", type=str, help="Path to save the SafeTensors file"
    )
    cli.add_argument(
        "--metadata_path",
        type=str,
        help="Optional path to save metadata as a JSON file",
    )
    options = cli.parse_args()

    source = Path(options.gguf_path)
    destination = Path(options.safetensors_path)
    sidecar = None
    if options.metadata_path:
        sidecar = Path(options.metadata_path)

    # Guard clause: stop before any conversion work if the input is missing.
    if not source.exists():
        print(f"Error: GGUF file '{source}' does not exist.", file=sys.stderr)
        sys.exit(1)

    print(f"Converting {source} to {destination}")
    gguf_to_safetensors(source, destination, sidecar)
    print("Conversion complete.")


if __name__ == "__main__":
    main()