feat(server): add read only flask server

- added the following endpoints:
/v1/backends (lists all backends and path)
/v1/health (heartbeat)
/v1/tasks (gets current task info, includes name, status, progress, and log file)
/v1/models (gets name, model type, path, and shard status)
This commit is contained in:
BuildTools 2024-08-15 17:15:21 -07:00
parent 79eeb02694
commit 2e90c91eb8
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
3 changed files with 86 additions and 17 deletions

View File

@ -6,3 +6,4 @@ sentencepiece~=0.2.0
PyYAML~=6.0.2 PyYAML~=6.0.2
pynvml~=11.5.3 pynvml~=11.5.3
PySide6~=6.7.2 PySide6~=6.7.2
flask~=3.0.3

View File

@ -1,7 +1,6 @@
import json import json
import re import re
import shutil import shutil
import sys
from datetime import datetime from datetime import datetime
import psutil import psutil
@ -9,15 +8,15 @@
from PySide6.QtCore import * from PySide6.QtCore import *
from PySide6.QtGui import * from PySide6.QtGui import *
from PySide6.QtWidgets import * from PySide6.QtWidgets import *
from flask import Flask, jsonify
from DownloadThread import DownloadThread from DownloadThread import DownloadThread
from GPUMonitor import GPUMonitor
from KVOverrideEntry import KVOverrideEntry from KVOverrideEntry import KVOverrideEntry
from Logger import Logger from Logger import Logger
from ModelInfoDialog import ModelInfoDialog from ModelInfoDialog import ModelInfoDialog
from QuantizationThread import QuantizationThread from QuantizationThread import QuantizationThread
from TaskListItem import TaskListItem from TaskListItem import TaskListItem
from GPUMonitor import GPUMonitor
from error_handling import show_error, handle_error from error_handling import show_error, handle_error
from imports_and_globals import ensure_directory, open_file_safe, resource_path from imports_and_globals import ensure_directory, open_file_safe, resource_path
from localizations import * from localizations import *
@ -655,7 +654,6 @@ def __init__(self):
# Load models # Load models
self.load_models() self.load_models()
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE) self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
def refresh_backends(self): def refresh_backends(self):
@ -1751,6 +1749,38 @@ def browse_imatrix_output(self):
if output_file: if output_file:
self.imatrix_output.setText(os.path.abspath(output_file)) self.imatrix_output.setText(os.path.abspath(output_file))
def get_models_data(self):
models = []
root = self.model_tree.invisibleRootItem()
child_count = root.childCount()
for i in range(child_count):
item = root.child(i)
model_name = item.text(0)
model_type = "sharded" if "sharded" in model_name.lower() else "single"
model_path = item.data(0, Qt.ItemDataRole.UserRole)
models.append({"name": model_name, "type": model_type, "path": model_path})
return models
def get_tasks_data(self):
tasks = []
for i in range(self.task_list.count()):
item = self.task_list.item(i)
task_widget = self.task_list.itemWidget(item)
if task_widget:
tasks.append(
{
"name": task_widget.task_name,
"status": task_widget.status,
"progress": (
task_widget.progress_bar.value()
if hasattr(task_widget, "progress_bar")
else 0
),
"log_file": task_widget.log_file,
}
)
return tasks
def generate_imatrix(self): def generate_imatrix(self):
self.logger.info(STARTING_IMATRIX_GENERATION) self.logger.info(STARTING_IMATRIX_GENERATION)
try: try:
@ -1832,10 +1862,3 @@ def closeEvent(self, event: QCloseEvent):
else: else:
event.accept() event.accept()
self.logger.info(APPLICATION_CLOSED) self.logger.info(APPLICATION_CLOSED)
if __name__ == "__main__":
app = QApplication(sys.argv)
window = AutoGGUF()
window.show()
sys.exit(app.exec())

View File

@ -1,9 +1,54 @@
import sys import sys
import threading
from PySide6.QtCore import QTimer
from PySide6.QtWidgets import QApplication from PySide6.QtWidgets import QApplication
from AutoGGUF import AutoGGUF from AutoGGUF import AutoGGUF
from flask import Flask, jsonify
server = Flask(__name__)
@server.route("/v1/models", methods=["GET"])
def models():
if window:
return jsonify({"models": window.get_models_data()})
return jsonify({"models": []})
@server.route("/v1/tasks", methods=["GET"])
def tasks():
if window:
return jsonify({"tasks": window.get_tasks_data()})
return jsonify({"tasks": []})
@server.route("/v1/health", methods=["GET"])
def ping():
return jsonify({"status": "alive"})
@server.route("/v1/backends", methods=["GET"])
def get_backends():
backends = []
for i in range(window.backend_combo.count()):
backends.append(
{
"name": window.backend_combo.itemText(i),
"path": window.backend_combo.itemData(i),
}
)
return jsonify({"backends": backends})
def run_flask():
server.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False)
if __name__ == "__main__":
app = QApplication(sys.argv) app = QApplication(sys.argv)
window = AutoGGUF() window = AutoGGUF()
window.show() window.show()
# Start Flask in a separate thread after a short delay
timer = QTimer()
timer.singleShot(100, lambda: threading.Thread(target=run_flask, daemon=True).start())
sys.exit(app.exec()) sys.exit(app.exec())