import asyncio
import json
import logging
import os
import sqlite3
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta

import plotly.graph_objects as go
import plotly.offline as pyo
from flask import Flask, render_template, request, send_file
from plotly.subplots import make_subplots

from globals import load_env

load_env()

# Configuration
HOST = os.getenv("HOST", "0.0.0.0")
PORT = int(os.getenv("PORT", "4421"))
DATA_DIR = os.getenv("DATA_DIR", "./data")
UPDATE_INTERVAL = int(os.getenv("UPDATE_INTERVAL", "30"))
LOG_LEVEL = os.getenv("LOG_LEVEL", "WARNING")

# Set up logging
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL),
    format="%(asctime)s - %(levelname)s - %(message)s",
)

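# Example environment overrides, assuming `load_env` reads a `.env` file from
# the working directory (the file name and loader behavior are assumptions;
# only the variable names are taken from the configuration above):
#
#   HOST=127.0.0.1
#   PORT=4421
#   DATA_DIR=./data
#   UPDATE_INTERVAL=30
#   LOG_LEVEL=INFO
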
# Global variables to store the cached data
cached_data = {}
last_modified_times = {}

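# Assumed schema of each per-model SQLite file, inferred from the queries in
# this module (a sketch, not taken from a migration script):
#
#   CREATE TABLE json_data (
#       timestamp INTEGER,  -- Unix epoch seconds when the sample was taken
#       data TEXT           -- JSON object mapping metric name -> samples
#   );
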
async def load_data_from_db(filepath):
    """Load the last 24 hours of metrics from one SQLite file into cached_data."""
    global cached_data

    conn = None
    try:
        conn = sqlite3.connect(filepath)
        cursor = conn.cursor()

        twenty_four_hours_ago = datetime.now() - timedelta(hours=24)
        timestamp_24h_ago = int(twenty_four_hours_ago.timestamp())

        cursor.execute(
            "SELECT data, timestamp FROM json_data WHERE timestamp >= ?",
            (timestamp_24h_ago,),
        )
        rows = cursor.fetchall()

        model_name = os.path.splitext(os.path.basename(filepath))[0]

        # Build the new structure first, then merge it into the cache
        new_data = {}
        for row in rows:
            data = json.loads(row[0])
            timestamp = datetime.fromtimestamp(row[1])

            for metric_name, metric_data in data.items():
                if metric_name not in new_data:
                    new_data[metric_name] = {}
                if model_name not in new_data[metric_name]:
                    new_data[metric_name][model_name] = []
                new_data[metric_name][model_name].append((timestamp, metric_data))

        # Merge into cached_data without discarding other models' entries
        for metric_name, model_data in new_data.items():
            if metric_name not in cached_data:
                cached_data[metric_name] = model_data
            else:
                cached_data[metric_name].update(model_data)

    except sqlite3.Error as e:
        logging.error(f"SQLite error in {filepath}: {e}")
    except json.JSONDecodeError as e:
        logging.error(f"JSON decode error in {filepath}: {e}")
    except Exception as e:
        logging.error(f"Error processing file {filepath}: {e}")
    finally:
        # conn is None if sqlite3.connect() itself failed
        if conn:
            conn.close()

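# Shape of the cache produced above:
#   cached_data[metric_name][model_name] -> [(datetime, metric_data), ...]
# where metric_data is either a scalar value or, for labeled metrics, a list
# of {"labels": {...}, "value": "..."} entries.
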
async def load_data():
    """Refresh the cache from every .sqlite file in DATA_DIR."""
    tasks = []
    for filename in os.listdir(DATA_DIR):
        if filename.endswith(".sqlite"):
            filepath = os.path.join(DATA_DIR, filename)
            tasks.append(load_data_from_db(filepath))

    await asyncio.gather(*tasks)
    logging.info(f"Loaded data for {len(cached_data)} metrics")
    if len(cached_data) == 0:
        logging.warning(
            "No data was loaded. Check if SQLite files exist and contain recent data."
        )


async def background_data_loader():
    """Reload the cache every UPDATE_INTERVAL seconds, forever."""
    while True:
        await load_data()
        await asyncio.sleep(UPDATE_INTERVAL)

def start_background_loop(loop):
    """Schedule the background data loader on `loop` and run it forever.

    background_data_loader is a coroutine, so it must be scheduled on an event
    loop; passing it directly as a threading.Thread target would only create a
    coroutine object that is never awaited.
    """
    asyncio.set_event_loop(loop)
    loop.create_task(background_data_loader())
    loop.run_forever()

def create_trace(model_name, metric_name, data_points, row, col):
    # Return the trace together with its subplot position so the caller can
    # add it to the figure from the main thread.
    return (
        go.Scattergl(
            x=[point[0] for point in data_points],
            y=[point[1] for point in data_points],
            mode="lines",
            name=f"{model_name} - {metric_name}",
        ),
        row,
        col,
    )

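# Example call (illustrative values only); data_points must already be sorted
# by timestamp, as create_plots() guarantees below:
#
#   trace, row, col = create_trace(
#       "my-model", "vllm:num_requests_running",
#       [(datetime(2024, 1, 1, 12, 0), 2.0), (datetime(2024, 1, 1, 12, 1), 3.0)],
#       1, 1,
#   )
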
def create_plots(selected_model):
    global cached_data
    start_time = time.time()

    all_data = {}
    selected_models = selected_model.split(",")
    for metric, data in cached_data.items():
        all_data[metric] = {
            model: data[model] for model in selected_models if model in data
        }

    data_prep_time = time.time() - start_time
    print(f"Data preparation took {data_prep_time:.2f} seconds")

    num_metrics = len(all_data)
    if num_metrics == 0:
        logging.warning("No valid data found.")
        return None

    num_cols = 2
    num_rows = (num_metrics + num_cols - 1) // num_cols
    fig = make_subplots(
        rows=num_rows, cols=num_cols, subplot_titles=list(all_data.keys())
    )

    subplot_creation_time = time.time() - start_time - data_prep_time
    print(f"Subplot creation took {subplot_creation_time:.2f} seconds")

    now = datetime.now()
    twenty_four_hours_ago = now - timedelta(hours=24)

    trace_creation_start = time.time()

    with ThreadPoolExecutor() as executor:
        futures = []
        for index, (metric_name, model_data) in enumerate(all_data.items()):
            row = index // num_cols + 1
            col = index % num_cols + 1

            for model_name, metric_data_list in model_data.items():
                if isinstance(metric_data_list[0][1], list):
                    # Labeled metric: one trace per distinct label set
                    for label_set in metric_data_list[0][1]:
                        data_points = []
                        for timestamp, metric_data in metric_data_list:
                            if timestamp >= twenty_four_hours_ago:
                                for data_point in metric_data:
                                    if data_point["labels"] == label_set["labels"]:
                                        try:
                                            value = float(data_point["value"])
                                            data_points.append((timestamp, value))
                                        except ValueError:
                                            logging.warning(
                                                f"Invalid numeric value for {model_name} - {metric_name}: {data_point['value']}"
                                            )
                        if not data_points:
                            continue
                        data_points.sort(key=lambda x: x[0])
                        futures.append(
                            executor.submit(
                                create_trace,
                                model_name,
                                str(label_set["labels"]),
                                data_points,
                                row,
                                col,
                            )
                        )
                else:
                    # Scalar metric: one trace per model
                    data_points = []
                    for timestamp, metric_data in metric_data_list:
                        if timestamp >= twenty_four_hours_ago:
                            try:
                                value = float(metric_data)
                                data_points.append((timestamp, value))
                            except ValueError:
                                logging.warning(
                                    f"Invalid numeric value for {model_name} - {metric_name}: {metric_data}"
                                )
                    if not data_points:
                        continue
                    data_points.sort(key=lambda x: x[0])
                    futures.append(
                        executor.submit(
                            create_trace, model_name, metric_name, data_points, row, col
                        )
                    )

        for future in as_completed(futures):
            trace, row, col = future.result()
            fig.add_trace(trace, row=row, col=col)

    trace_creation_time = time.time() - trace_creation_start
    print(f"Trace creation took {trace_creation_time:.2f} seconds")

    layout_update_start = time.time()
    fig.update_layout(
        height=300 * num_rows,
        showlegend=True,
        template="plotly_dark",
        font=dict(family="Arial", size=10, color="white"),
        paper_bgcolor="rgb(30, 30, 30)",
        plot_bgcolor="rgb(30, 30, 30)",
    )
    fig.update_xaxes(title_text="Time", tickformat="%Y-%m-%d %H:%M:%S")
    fig.update_yaxes(title_text="Value")
    fig.update_traces(hovertemplate="%{x|%Y-%m-%d %H:%M:%S}<br>%{y:.3f}")

    layout_update_time = time.time() - layout_update_start
    print(f"Layout update took {layout_update_time:.2f} seconds")

    total_time = time.time() - start_time
    print(f"Total plot creation took {total_time:.2f} seconds")

    return fig

def get_latest_rows(db_file, hours=1):
    now = datetime.now()
    cutoff = now - timedelta(hours=hours)
    cutoff_timestamp = int(cutoff.timestamp())

    conn = sqlite3.connect(db_file)
    cursor = conn.cursor()

    cursor.execute(
        "SELECT timestamp, data FROM json_data WHERE timestamp >= ? ORDER BY timestamp DESC",
        (cutoff_timestamp,),
    )
    rows = cursor.fetchall()

    conn.close()

    if not rows:
        print(
            f"No rows found in the last {hours} hour(s). Showing info for last 5 rows:"
        )
        conn = sqlite3.connect(db_file)
        cursor = conn.cursor()
        cursor.execute(
            "SELECT timestamp, data FROM json_data ORDER BY timestamp DESC LIMIT 5"
        )
        rows = cursor.fetchall()
        conn.close()
        for timestamp, _ in rows:
            print(f"  {datetime.fromtimestamp(timestamp)}")

    return rows

def extract_stats(data_json):
    """Parse one stored JSON blob into a flat dict of vLLM stats, or None on failure."""
    try:
        data = json.loads(data_json)
        total_prompt_tokens = float(data["vllm:prompt_tokens_total"][0]["value"])
        total_generation_tokens = float(
            data["vllm:generation_tokens_total"][0]["value"]
        )
        total_requests = sum(
            float(item["value"]) for item in data["vllm:request_success_total"]
        )
        avg_prompt_throughput = float(
            data["vllm:avg_prompt_throughput_toks_per_s"][0]["value"]
        )
        avg_generation_throughput = float(
            data["vllm:avg_generation_throughput_toks_per_s"][0]["value"]
        )
        gpu_cache_usage_perc = float(data["vllm:gpu_cache_usage_perc"][0]["value"])
        num_requests_running = float(data["vllm:num_requests_running"][0]["value"])

    except (KeyError, IndexError, json.JSONDecodeError) as e:
        print(f"Error extracting stats from data: {str(e)}")
        return None

    return {
        "total_prompt_tokens": total_prompt_tokens,
        "total_generation_tokens": total_generation_tokens,
        "total_requests": total_requests,
        "avg_prompt_throughput": avg_prompt_throughput,
        "avg_generation_throughput": avg_generation_throughput,
        "gpu_cache_usage_perc": gpu_cache_usage_perc,
        "num_requests_running": num_requests_running,
    }

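# Illustrative shape of one stored `data` blob, reconstructed from the keys
# read above (the field values here are invented for the example):
#
#   {
#       "vllm:prompt_tokens_total": [{"labels": {...}, "value": "12345"}],
#       "vllm:generation_tokens_total": [{"labels": {...}, "value": "67890"}],
#       "vllm:request_success_total": [{"labels": {...}, "value": "42"}],
#       "vllm:avg_prompt_throughput_toks_per_s": [{"labels": {...}, "value": "120.5"}],
#       "vllm:avg_generation_throughput_toks_per_s": [{"labels": {...}, "value": "45.2"}],
#       "vllm:gpu_cache_usage_perc": [{"labels": {...}, "value": "0.37"}],
#       "vllm:num_requests_running": [{"labels": {...}, "value": "3"}]
#   }
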
def get_data(db_file, hours):
    latest_rows = get_latest_rows(db_file, hours)

    if not latest_rows:
        print(f"No rows found for the last {hours} hour(s).")
        return

    print(f"Processing {len(latest_rows)} rows.")

    # Extract stats once per row, dropping rows that fail to parse
    valid_stats = [
        stats
        for stats in (extract_stats(data) for _, data in latest_rows)
        if stats is not None
    ]

    if not valid_stats:
        print("No valid statistics could be extracted from the rows.")
        return

    first_stats = valid_stats[-1]  # Oldest row (rows are ordered newest first)
    last_stats = valid_stats[0]  # Newest row

    tokens_processed = (
        last_stats["total_prompt_tokens"]
        - first_stats["total_prompt_tokens"]
        + last_stats["total_generation_tokens"]
        - first_stats["total_generation_tokens"]
    )
    requests_processed = last_stats["total_requests"] - first_stats["total_requests"]

    avg_prompt_throughput = sum(
        stat["avg_prompt_throughput"] for stat in valid_stats
    ) / len(valid_stats)
    avg_generation_throughput = sum(
        stat["avg_generation_throughput"] for stat in valid_stats
    ) / len(valid_stats)
    avg_num_requests_running = sum(
        stat["num_requests_running"] for stat in valid_stats
    ) / len(valid_stats)
    avg_gpu_cache_usage_perc = sum(
        stat["gpu_cache_usage_perc"] for stat in valid_stats
    ) / len(valid_stats)

    # Guard against division by zero when no requests completed in the window
    avg_tokens_per_request = (
        tokens_processed / requests_processed if requests_processed else 0.0
    )

    return (
        f"\nStats for the last {hours} hour(s):\n"
        f"Tokens processed: {tokens_processed:,.0f}\n"
        f"Requests processed: {requests_processed:,.0f}\n"
        f"Average prompt throughput: {avg_prompt_throughput:.2f} tokens/s\n"
        f"Average generation throughput: {avg_generation_throughput:.2f} tokens/s\n"
        f"Average tokens per request: {avg_tokens_per_request:,.2f} tokens\n"
        f"Average number of requests running: {avg_num_requests_running:.2f} requests\n"
        f"Average GPU cache usage percent: {avg_gpu_cache_usage_perc * 100:.2f}%"
    )

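# Example usage (hypothetical database file):
#   print(get_data("./data/my-model.sqlite", 24))
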
app = Flask(__name__)


@app.route("/", methods=["GET", "POST"])
def index():
    model_names = [
        name[:-7]  # strip the ".sqlite" extension
        for name in os.listdir(DATA_DIR)
        if name.endswith(".sqlite") and os.path.isfile(os.path.join(DATA_DIR, name))
    ]
    valid_model_names = set(model_names)

    if request.method == "POST":
        selected_model = request.form.get("model_select")
    else:
        selected_model = model_names[0] if model_names else None

    plot_div = None
    error_message = None
    if selected_model in valid_model_names:
        try:
            fig = create_plots(selected_model)
            if fig is not None:
                fig.update_layout(showlegend=False)
                plot_div = pyo.plot(fig, output_type="div", include_plotlyjs=True)
            else:
                error_message = "No data available for the selected model."
        except Exception as e:
            logging.error(f"Error creating plot: {str(e)}")
            error_message = (
                "An error occurred while creating the plot. Please try again later."
            )

        result = get_data(os.path.join(DATA_DIR, f"{selected_model}.sqlite"), 24)
    else:
        logging.error(f"Invalid model selected: {selected_model}")
        result = None

    return render_template(
        "index.html",
        plot_div=plot_div,
        model_name=selected_model,
        model_names=model_names,
        result=result,
        error_message=error_message,
    )

@app.route("/favicon.ico")
|
|
def favicon():
|
|
return send_file("favicon.ico", mimetype="image/vnd.microsoft.icon")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
from asgiref.wsgi import WsgiToAsgi
|
|
|
|
# Initial data load
|
|
logging.info("Starting initial data load")
|
|
asyncio.run(load_data())
|
|
logging.info("Initial data load complete")
|
|
|
|
# Create a new event loop for the background task
|
|
loop = asyncio.new_event_loop()
|
|
|
|
def start_background_loop():
|
|
asyncio.set_event_loop(loop)
|
|
loop.create_task(background_data_loader())
|
|
loop.run_forever()
|
|
|
|
t = threading.Thread(target=start_background_loop, daemon=True)
|
|
t.start()
|
|
|
|
asgi_app = WsgiToAsgi(app)
|
|
uvicorn.run(asgi_app, host=HOST, port=PORT)
|
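# To run locally (assuming this file is saved as app.py and flask, plotly,
# uvicorn, and asgiref are installed):
#   python app.py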