MLflow 是一个用于机器学习实验追踪与模型版本管理的工具集,旨在解决实验可复现性差、参数和模型难以管理的问题。其核心概念包括 MLflow Tracking、MLflow Projects 和 MLflow Model Registry。MLflow Tracking 记录实验参数、指标和 artifact,MLflow Projects 允许代码在任何环境中运行,而 MLflow Model Registry 则管理模型的版本和生命周期。读者将学会如何使用 MLflow 的自动记录功能(autolog)来简化实验追踪,并通过 UI 界面比较不同实验的结果。此外,读者还将了解如何利用 Model Registry 管理模型的阶段流转,以及如何结合 DVC 进行数据集版本管理。最终,读者能够实现机器学习实验的可复现性,高效管理模型生命周期,并结合数据版本控制工具确保整个流程的完整性和可追溯性。
实验追踪与模型版本管理
每个 ML 工程师都遇到过这个噩梦:
"我上周训练的那个 92% 准确率的模型去哪儿了?超参记不清了,数据好像也不是这个版本..."
MLflow 解决这一切: 把实验参数、代码、数据、模型统一管理,让 ML 实验可复现。
MLflow 三件套
- MLflow Tracking: 记录实验参数 + 指标 + artifact
- MLflow Projects: 打包代码, 任何环境能跑
- MLflow Model Registry: 模型版本管理 + 阶段流转
安装: pip install mlflow scikit-learn
5 行开启 Tracking
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# 加载数据
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 启动一个 experiment
mlflow.set_experiment("iris-classifier")
with mlflow.start_run():
# 记录超参数
n_estimators = 100
max_depth = 5
mlflow.log_param("n_estimators", n_estimators)
mlflow.log_param("max_depth", max_depth)
# 训练
model = RandomForestClassifier(n_estimators=n_estimators, max_depth=max_depth)
model.fit(X_train, y_train)
# 评估
acc = accuracy_score(y_test, model.predict(X_test))
mlflow.log_metric("accuracy", acc)
# 保存模型
mlflow.sklearn.log_model(model, "model")
print(f"Run ID: {mlflow.active_run().info.run_id}, Accuracy: {acc:.2%}")
运行后: mlflow ui 打开 http://localhost:5000 看所有实验。
自动 log 一切 (autolog)
不用手动 log_param / log_metric, MLflow 一键全记:
import mlflow
mlflow.sklearn.autolog() # 训练 sklearn 时自动 log
with mlflow.start_run():
model = RandomForestClassifier(n_estimators=100, max_depth=5)
model.fit(X_train, y_train)
# 全部参数、指标、模型文件都自动保存!
支持的库: sklearn / XGBoost / LightGBM / PyTorch / TensorFlow / Keras / Spark MLlib / Fastai
一次训 50 个模型 + 找最优
import mlflow
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
mlflow.sklearn.autolog()
# 网格搜索 + 自动 log
param_grid = {
"n_estimators": [50, 100, 200],
"max_depth": [3, 5, 10, None],
}
grid = GridSearchCV(
RandomForestClassifier(),
param_grid, cv=5, scoring="accuracy"
)
grid.fit(X_train, y_train)
print(f"Best: {grid.best_params_}, {grid.best_score_:.2%}")
autolog 会记录每组参数的结果, 之后在 UI 里按 metric 排序, 找最优。
Model Registry: 模型从实验到生产
光有 tracking 还不够, 还要管理模型生命周期 (Staging → Production → Archived):
from mlflow import MlflowClient
client = MlflowClient()
# 1. 注册模型 (从某个 run)
model_uri = f"runs:/{run_id}/model"
mlflow.register_model(model_uri, "iris-classifier")
# 2. 提升到生产
client.transition_model_version_stage(
name="iris-classifier",
version=1,
stage="Production"
)
# 3. 加载生产模型
prod_model = mlflow.pyfunc.load_model("models:/iris-classifier/Production")
pred = prod_model.predict(X_test)
完整流程图
[开发阶段] [上线阶段]
Train v1 → log Register → Stage: None
Train v2 → log Register → Stage: Staging
Train v3 → log Test OK → Stage: Production
Bad → Stage: Archived
模型 Registry 让 "哪个模型在线上" 一目了然, 回滚也只是改 stage。
配合 DVC: 数据集版本管理
模型是数据的派生, 还要管数据。DVC (Data Version Control) 是 Git 的搭档:
# 1. 初始化
dvc init
# 2. 把 data.csv 加到 DVC
dvc add data/training.csv
git add data/training.csv.dvc data/.gitignore
git commit -m "Add training data v1"
# 3. 数据变了
dvc add data/training.csv
git commit -m "Update to training data v2"
# 4. 切回老数据
git checkout HEAD~1 -- data/training.csv.dvc
dvc checkout
DVC 把大文件存到 S3 / OSS, .dvc 文件存指针到 Git。
5 个工程实践
- 每次跑都写
mlflow.start_run(), 不跑就丢失 - 代码 commit hash 记到 run:
mlflow.set_tag("git_sha", sha) - 数据 hash 也记:
mlflow.set_tag("data_hash", md5(file)) - 模型 schema 记: 输入输出列名、类型
- 别用默认 SQLite: 生产用
mlflow server --backend-store-uri postgresql://...
与其他工具对比
| 工具 | 优势 | 劣势 |
|---|---|---|
| MLflow | 通用、API 简单、本地优先 | 大规模需要自己部署 |
| Weights & Biases | UI 漂亮、协作强 | SaaS, 收费 |
| Neptune.ai | 团队协作、对比视图 | SaaS, 收费 |
| TensorBoard | 深度学习集成 | 主打可视化, 不是 tracking |
| DVC | 数据 + 模型 + pipeline | 学习曲线陡 |
小结
- MLflow = Tracking (实验) + Projects (代码) + Registry (模型生命周期)
autolog()一键记录 sklearn / XGBoost / PyTorch- Model Registry 用 stage (Staging / Production) 管理线上模型
- 配合 DVC 管数据, Git 管代码
- 可复现 = 代码 hash + 数据 hash + 依赖 hash + 随机种子
练习思考
- 跑 5 个不同超参的 RandomForest, 在 MLflow UI 里比较 accuracy, 哪组最好?
- 注册 2 个模型版本, 把 v1 提升到 Production, 加载并预测, 跟 v2 的预测结果对比。
- MLflow autolog 跟手动 log_param 有什么区别? 什么时候必须用手动?
章末小测验
检验你对《实验追踪与模型版本管理》的掌握程度。
MLflow Tracking 主要记录什么?
Model Registry 中模型的生命周期阶段是?
学完这章, 你可能想看
这门课在以下学习路径中
当前课程出现在 1 条系统化路径里, 你可以一键生成完整学习计划, 自动跳过已完成章节。
还有疑问? 问问 AI (v19.5)
基于全站 19 门课 68 章内容检索 + LLM 总结, 会引用具体章节作为出处