feat(server): replace Flask with FastAPI and Uvicorn

- replace Flask with FastAPI and Uvicorn
- fix web page not found error
- port now defaults to 7001
- bind to localhost (127.0.0.1) instead of 0.0.0.0
- improve performance by using Uvicorn
- add OpenAPI docs for endpoints
This commit is contained in:
BuildTools 2024-08-31 15:34:47 -07:00
parent db1733b4ed
commit 22bd74b399
No known key found for this signature in database
GPG Key ID: 3270C066C15D530B
2 changed files with 162 additions and 63 deletions

View File

@ -6,9 +6,10 @@ sentencepiece~=0.2.0
PyYAML~=6.0.2
pynvml~=11.5.3
PySide6~=6.7.2
flask~=3.0.3
python-dotenv~=1.0.1
safetensors~=0.4.4
setuptools~=68.2.0
huggingface-hub~=0.24.6
transformers~=4.44.2
fastapi~=0.112.2
uvicorn~=0.30.6

View File

@ -1,80 +1,178 @@
import os
import sys
import threading
from enum import Enum
from typing import List, Optional
from PySide6.QtCore import QTimer
from PySide6.QtWidgets import QApplication
from AutoGGUF import AutoGGUF
from flask import Flask, Response, jsonify
from fastapi import FastAPI, Query
from pydantic import BaseModel, Field
from uvicorn import Config, Server
server = Flask(__name__)
from AutoGGUF import AutoGGUF
from Localizations import AUTOGGUF_VERSION
# FastAPI application object. The metadata below is passed straight to the
# FastAPI constructor; the version string comes from Localizations.
app = FastAPI(
    title="AutoGGUF",
    version=AUTOGGUF_VERSION,
    description="API for AutoGGUF - automatically quant GGUF models",
    license_info={
        "name": "Apache 2.0",
        "url": "https://raw.githubusercontent.com/leafspark/AutoGGUF/main/LICENSE",
    },
)

# Module-level reference to the main window, read by the endpoint handlers.
# It stays None until main() creates the window.
window = None
class ModelType(str, Enum):
    # Allowed values for the `type` query filter of the /v1/models endpoint.
    # Mixing in `str` lets members compare equal to their plain string values.
    single = "single"
    sharded = "sharded"
class Model(BaseModel):
    # Response schema for one entry of the /v1/models endpoint.
    # `size` is the only optional field; it defaults to None when absent.
    name: str = Field(..., description="Name of the model")
    type: str = Field(..., description="Type of the model")
    path: str = Field(..., description="Path to the model file")
    size: Optional[int] = Field(None, description="Size of the model in bytes")

    class Config:
        # Example payload attached to the generated JSON schema.
        json_schema_extra = {
            "example": {
                "name": "Llama-3.1-8B-Instruct.fp16.gguf",
                "type": "single",
                "path": "Llama-3.1-8B-Instruct.fp16.gguf",
                "size": 13000000000,
            }
        }
class Task(BaseModel):
    # Response schema for one entry of the /v1/tasks endpoint.
    # All three fields are required.
    id: str = Field(..., description="Unique identifier for the task")
    status: str = Field(..., description="Current status of the task")
    progress: float = Field(..., description="Progress of the task as a percentage")

    class Config:
        # The option is spelled `json_schema_extra`. The previous
        # `json_json_schema_extra` was a typo and is not a recognised schema
        # option, so the example below never reached the generated schema.
        json_schema_extra = {
            "example": {"id": "task_123", "status": "running", "progress": 75.5}
        }
class Backend(BaseModel):
    # Response schema for one entry of the /v1/backends endpoint.
    # Both fields are required.
    name: str = Field(..., description="Name of the backend")
    path: str = Field(..., description="Path to the backend executable")
class Plugin(BaseModel):
    # Response schema for one entry of the /v1/plugins endpoint.
    # Every field is a required string.
    name: str = Field(..., description="Name of the plugin")
    version: str = Field(..., description="Version of the plugin")
    description: str = Field(..., description="Description of the plugin")
    author: str = Field(..., description="Author of the plugin")
@app.get("/v1/models", response_model=List[Model], tags=["Models"])
async def get_models(
    type: Optional[ModelType] = Query(None, description="Filter models by type")
) -> List[Model]:
    """
    Get a list of all available models.
    - **type**: Optional filter for model type
    Returns a list of Model objects containing name, type, path, and optional size.
    """
    # No window yet: nothing to report.
    if not window:
        return []

    entries = window.get_models_data()
    if type:
        # ModelType is a str-based enum, so it compares equal to the plain
        # string stored under the "type" key.
        entries = [entry for entry in entries if entry["type"] == type]

    # Build Pydantic models; a missing "size" key falls back to the field default.
    return [Model(**entry) for entry in entries]
@app.get("/v1/tasks", response_model=List[Task], tags=["Tasks"])
async def get_tasks() -> List[Task]:
    """
    Get a list of all current tasks.
    Returns a list of Task objects containing id, status, and progress.
    """
    # Delegate to the window when it exists; otherwise report no tasks.
    return window.get_tasks_data() if window else []
@app.get("/v1/health", tags=["System"])
async def health_check() -> dict:
    """
    Check the health status of the API.
    Returns a simple status message indicating the API is alive.
    """
    # Constant payload; does not depend on the window.
    payload = {"status": "alive"}
    return payload
@app.get("/v1/backends", response_model=List[Backend], tags=["System"])
async def get_backends() -> List[Backend]:
    """
    Get a list of all available llama.cpp backends.
    Returns a list of Backend objects containing name and path.
    """
    # Without a window there is no combo box to read from.
    if not window:
        return []

    combo = window.backend_combo
    # One Backend per combo-box row: the item text is the name and the item
    # data is the path.
    return [
        Backend(name=combo.itemText(row), path=combo.itemData(row))
        for row in range(combo.count())
    ]
@app.get("/v1/plugins", response_model=List[Plugin], tags=["System"])
async def get_plugins() -> List[Plugin]:
    """
    Get a list of all installed plugins.
    Returns a list of Plugin objects containing name, version, description, and author.
    """
    # No window yet: no plugins to report.
    if not window:
        return []

    plugins: List[Plugin] = []
    for entry in window.plugins.values():
        # Each registry entry keeps the plugin metadata under its "data" key.
        plugins.append(Plugin(**entry["data"]))
    return plugins
def run_uvicorn() -> None:
    """Run the Uvicorn server for the API if it is enabled via the environment.

    The server only starts when AUTOGGUF_SERVER equals "enabled"
    (case-insensitive). It binds to 127.0.0.1 on the port given by
    AUTOGGUF_SERVER_PORT, defaulting to 7001. The call blocks while the
    server is running.
    """
    server_flag = os.environ.get("AUTOGGUF_SERVER", "").lower()
    if server_flag != "enabled":
        return

    listen_port = int(os.environ.get("AUTOGGUF_SERVER_PORT", 7001))
    uvicorn_config = Config(
        app=app,
        host="127.0.0.1",
        port=listen_port,
        log_level="info",
    )
    Server(uvicorn_config).run()
def main() -> None:
    """Start the Qt application and, shortly afterwards, the API server.

    Creates the QApplication and the AutoGGUF main window, publishes the
    window through the module-level ``window`` variable so the endpoint
    handlers can read it, schedules ``run_uvicorn`` on a background thread,
    and then enters the Qt event loop. The process exits with the event
    loop's return code.

    The superseded Flask routes and ``run_flask`` helper are gone; the
    endpoints are served by the module-level FastAPI ``app`` instead.
    Exactly one QApplication is created and one server thread is scheduled.
    """
    # Assign to the module-level name so the endpoint handlers see the window.
    global window

    # Named qt_app so it does not shadow the module-level FastAPI `app`
    # that run_uvicorn serves.
    qt_app = QApplication(sys.argv)
    window = AutoGGUF(sys.argv)
    window.show()

    # Start Uvicorn in a separate thread after a short delay.
    # The thread is a daemon so it does not keep the process alive after the
    # Qt event loop ends.
    QTimer.singleShot(
        100, lambda: threading.Thread(target=run_uvicorn, daemon=True).start()
    )

    sys.exit(qt_app.exec())
if __name__ == "__main__":