DoubleCloud 即将关闭。利用限时免费迁移服务迁移到 ClickHouse。立即联系我们 ->->

博客 / 工程

使用 ClickHouse 机器学习函数进行预测

author avatar
Ensemble
2023 年 10 月 26 日

这篇文章最初由 Ensemble Analytics 发布,他们已友好地允许我们重新发布此内容。我们欢迎来自社区的文章,并感谢他们的贡献。

简介

在进行统计分析或数据科学工作时,我们通常会第一时间转向使用 Python 或 R 等编程语言。

但是,当我们使用 ClickHouse 时,我们更倾向于尽可能地使用数据库本身。通过这样做,我们可以依靠 ClickHouse 的强大功能来快速处理数据,并减少甚至完全避免我们需要编写的代码量。这也意味着我们可以处理客户端上更小的内存数据集,并避免对分布式计算的需求。

一个很好的例子就是预测。ClickHouse 实现了两个机器学习函数:随机线性回归 (stochasticLinearRegression) 可用于拟合模型,另一个函数 (evalMLMethod) 可用于直接在数据库中进行后续推理。

当然,一旦你从 SQL 转向成熟的编程语言,就会有更多复杂的预测模型和更大的灵活性,但这种技术无疑有其用途,并且在我们这里演示的场景中表现良好。

数据集

为了演示,我们将使用一个简单的航班起飞数据集,其中包含来自不同机场使用不同航空公司起飞的乘客数量的月度时间序列。

我们的目标是使用此数据并将其用于预测未来相同数据。

我们将致力于使用 2008 年至 2015 年的数据构建模型,然后测试模型在 2015 年至 2018 年之间的性能。最后,我们将预测超出该时间段直至 2021 年的数据。

我们的源数据具有以下结构

SELECT *
FROM flight_data
LIMIT 10

┌─AIRLINE─┬─DEPARTURE_AIRPORT─┬──────MONTH─┬─PASSENGERS─┐
│ Delta   │ DIA               │ 2008-01-01434 │
│ Delta   │ DIA               │ 2008-02-01475 │
│ Delta   │ DIA               │ 2008-03-01531 │
│ Delta   │ DIA               │ 2008-04-01509 │
│ Delta   │ DIA               │ 2008-05-01472 │
│ Delta   │ DIA               │ 2008-06-01562 │
│ Delta   │ DIA               │ 2008-07-01642 │
│ Delta   │ DIA               │ 2008-08-01642 │
│ Delta   │ DIA               │ 2008-09-01596 │
│ Delta   │ DIA               │ 2008-10-01503 │
└─────────┴───────────────────┴────────────┴────────────┘

10 rows in set. Elapsed: 0.002 sec. Processed 4.62 thousand rows, 151.54 KB (2.16 million rows/s., 70.86 MB/s.)
Peak memory usage: 229.15 KiB.

绘制出来后,数据看起来像这样,显示了所有航空公司随着时间的推移都运送了越来越多的乘客,同时还存在明显的季节性影响。

hex01.png

数据准备

我们的预测模型使用 13 个确定性特征:线性时间趋势和 12 个表示一年中 12 个月的虚拟(或独热编码)变量。我们排除常数项(或截距)以避免“虚拟变量陷阱”。

该模型预测乘客数量的对数。对数变换使我们能够更好地捕获季节性波动的随时间变化的幅度。

CREATE VIEW
    data
AS WITH
    (select toDate(min(MONTH)) from flight_data) as start_date,
    (select toDate(max(MONTH)) from flight_data) as end_date
SELECT
    AIRLINE,
    DEPARTURE_AIRPORT,
    MONTH,
    toFloat64(log(PASSENGERS)) as Target,
    assumeNotNull(dateDiff('month', start_date, MONTH) / dateDiff('month', start_date, end_date)) as Trend,
    if(toMonth(toDate(MONTH)) = 1, 1, 0) as Dummy1,
    if(toMonth(toDate(MONTH)) = 2, 1, 0) as Dummy2,
    if(toMonth(toDate(MONTH)) = 3, 1, 0) as Dummy3,
    if(toMonth(toDate(MONTH)) = 4, 1, 0) as Dummy4,
    if(toMonth(toDate(MONTH)) = 5, 1, 0) as Dummy5,
    if(toMonth(toDate(MONTH)) = 6, 1, 0) as Dummy6,
    if(toMonth(toDate(MONTH)) = 7, 1, 0) as Dummy7,
    if(toMonth(toDate(MONTH)) = 8, 1, 0) as Dummy8,
    if(toMonth(toDate(MONTH)) = 9, 1, 0) as Dummy9,
    if(toMonth(toDate(MONTH)) = 10, 1, 0) as Dummy10,
    if(toMonth(toDate(MONTH)) = 11, 1, 0) as Dummy11,
    if(toMonth(toDate(MONTH)) = 12, 1, 0) as Dummy12
FROM
    flight_data
ORDER BY AIRLINE, DEPARTURE_AIRPORT, MONTH

这将创建一个以下视图,它总结了我们的因变量和自变量

SELECT *
FROM data
LIMIT 10

┌─AIRLINE─┬─DEPARTURE_AIRPORT─┬──────MONTH─┬─────────────Target─┬────────────────Trend─┬─Dummy1─┬─Dummy2─┬─Dummy3─┬─Dummy4─┬─Dummy5─┬─Dummy6─┬─Dummy7─┬─Dummy8─┬─Dummy9─┬─Dummy10─┬─Dummy11─┬─Dummy12─┐
│ Delta   │ DIA               │ 2008-01-016.07304453333358650100000000000 │
│ Delta   │ DIA               │ 2008-02-016.1633148043360030.007633587786259542010000000000 │
│ Delta   │ DIA               │ 2008-03-016.2747620213889250.015267175572519083001000000000 │
│ Delta   │ DIA               │ 2008-04-016.2324480165547820.022900763358778626000100000000 │
│ Delta   │ DIA               │ 2008-05-016.1569789858738250.030534351145038167000010000000 │
│ Delta   │ DIA               │ 2008-06-016.33150185006186650.03816793893129771000001000000 │
│ Delta   │ DIA               │ 2008-07-016.4645883046242930.04580152671755725000000100000 │
│ Delta   │ DIA               │ 2008-08-016.4645883046242930.05343511450381679000000010000 │
│ Delta   │ DIA               │ 2008-09-016.3902406663626440.061068702290076333000000001000 │
│ Delta   │ DIA               │ 2008-10-016.2205901701385750.06870229007633588000000000100 │
└─────────┴───────────────────┴────────────┴────────────────────┴──────────────────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴────────┴─────────┴─────────┴─────────┘

10 rows in set. Elapsed: 0.010 sec. Processed 13.86 thousand rows, 170.02 KB (1.37 million rows/s., 16.81 MB/s.)
Peak memory usage: 420.28 KiB.

模型训练

我们使用 ClickHouse 的 stochasticLinearRegression 算法,该算法使用梯度下降法训练线性回归。我们同时构建 35 个不同的模型,每个模型对应于一个航空公司-机场组合。

CREATE VIEW model as SELECT
    AIRLINE,
    DEPARTURE_AIRPORT,
    stochasticLinearRegressionState(0.5, 0.01, 4, 'SGD')(
        Target, Trend, Dummy1, Dummy2, Dummy3, Dummy4, Dummy5, Dummy6, Dummy7, Dummy8, Dummy9, Dummy10, Dummy11, Dummy12
    ) as state
FROM train_data
GROUP BY AIRLINE, DEPARTURE_AIRPORT

由于数据量很小,因此该模型仅被定义为视图。对于更大的数据集,我们可能会选择将其物化成表或视图。

模型评估

现在,我们可以使用训练好的模型生成测试集上的预测,并将它们与实际值进行比较。在此阶段,我们还可以通过取指数将数据和预测转换回原始比例。

SELECT
    a.MONTH as MONTH,
    a.AIRLINE as AIRLINE,
    a.DEPARTURE_AIRPORT as DEPARTURE_AIRPORT,
    toInt32(exp(a.Target)) as ACTUAL,
    toInt32(exp(evalMLMethod(b.state, Trend, Dummy1, Dummy2, Dummy3, Dummy4, Dummy5, Dummy6, Dummy7,
    Dummy8, Dummy9, Dummy10, Dummy11, Dummy12))) as FORECAST
FROM test_data as a
LEFT JOIN model as b
on a.AIRLINE = b.AIRLINE and a.DEPARTURE_AIRPORT = b.DEPARTURE_AIRPORT

如果我们比较预测和实际值,我们可以看到预测表现良好

hex02.png

我们可以通过计算每个航空公司-机场组合的预测的平均绝对误差 (MAE) 和均方根误差 (RMSE) 来验证这一点。

SELECT
    AIRLINE,
    DEPARTURE_AIRPORT,
    avg(abs(ERROR)) AS MAE,
    sqrt(avg(pow(ERROR, 2))) AS RMSE
FROM
(
    SELECT
        a.AIRLINE AS AIRLINE,
        a.DEPARTURE_AIRPORT AS DEPARTURE_AIRPORT,
        toInt32(exp(a.Target)) - toInt32(exp(evalMLMethod(b.state, Trend, Dummy1, Dummy2, Dummy3, Dummy4,
        Dummy5, Dummy6, Dummy7, Dummy8, Dummy9, Dummy10, Dummy11, Dummy12))) AS ERROR
    FROM test_data AS a
    LEFT JOIN model AS b ON (a.AIRLINE = b.AIRLINE) AND (a.DEPARTURE_AIRPORT = b.DEPARTURE_AIRPORT)
)
GROUP BY
    AIRLINE,
    DEPARTURE_AIRPORT

Query id: 320cad46-bb31-4248-bd25-19d98d5d2d15

┌─AIRLINE──┬─DEPARTURE_AIRPORT─┬────────────────MAE─┬───────────────RMSE─┐
│ JetBlue  │ SFO               │  86.38888888888889110.96671172523367 │
│ KLM      │ PDX               │ 167.97222222222223213.4134615143936 │
│ Delta    │ SJC               │ 141.80555555555554180.9452802491528 │
│ United   │ PDX               │ 115.19444444444444147.7711255812703 │
│ JetBlue  │ ORL               │  97.77777777777777125.28611699271038 │
│ KLM      │ JAX               │ 121.27777777777777155.41414207064798 │
│ Delta    │ JFK               │              168.5214.1754213515433 │
│ United   │ JAX               │ 153.88888888888889195.9098432102549 │
│ Delta    │ SFO               │ 184.66666666666666234.34068267280344 │
│ KLM      │ DIA               │ 148.94444444444446189.77618396416344 │
│ United   │ JFK               │ 178.02777777777777226.086205289536 │
│ Frontier │ ORL               │ 206.38888888888889261.27720485679146 │
│ United   │ SJC               │ 119.91666666666667153.72332650288018 │
│ KLM      │ SJC               │ 218.13888888888889275.90532796595284 │
│ KLM      │ JFK               │  70.3055555555555690.43244869944515 │
│ Delta    │ JAX               │ 186.55555555555554236.69213477990067 │
│ Delta    │ ORL               │  74.4444444444444495.50887102486577 │
│ Frontier │ SFO               │  63.0277777777777880.91748197323548 │
│ Frontier │ PDX               │                 81103.99278821149089 │
│ United   │ ORL               │              111.5142.90031490518138 │
│ Frontier │ JAX               │  98.11111111111111125.86147588166568 │
│ Frontier │ DIA               │  95.91666666666667122.96758832219886 │
│ Delta    │ PDX               │  72.4166666666666792.89046715830904 │
│ JetBlue  │ JFK               │ 141.91666666666666181.17877911057906 │
│ JetBlue  │ SJC               │              209.5265.1057441013973 │
│ JetBlue  │ JAX               │ 107.30555555555556137.61893845769274 │
│ KLM      │ ORL               │ 156.77777777777777199.51287900506296 │
│ JetBlue  │ DIA               │  76.8333333333333398.60076628054729 │
│ Frontier │ SJC               │  97.22222222222223124.6602048236191 │
│ Frontier │ JFK               │ 156.33333333333334199.04550010264265 │
│ Delta    │ DIA               │                114146.3065655092454 │
│ KLM      │ SFO               │ 119.97222222222223153.7722883573847 │
│ United   │ DIA               │  72.6388888888888993.25666493905706 │
│ JetBlue  │ PDX               │ 147.83333333333334188.4872527372725 │
│ United   │ SFO               │ 186.83333333333334237.06668072740865 │
└──────────┴───────────────────┴────────────────────┴────────────────────┘

35 rows in set. Elapsed: 0.024 sec. Processed 18.48 thousand rows, 321.55 KB (785.99 thousand rows/s., 13.68 MB/s.)
Peak memory usage: 766.46 KiB.

模型推理

最后,我们现在可以使用该模型生成数据集最后日期之后的预测。为此,我们在随后 3 年内创建了一个新表,其中包含日期及其对应的变换(时间趋势和虚拟变量)。

CREATE VIEW
    future_data
AS WITH
    (select toDate(min(MONTH)) from flight_data) as start_date,
    (select toDate(max(MONTH)) from flight_data) as end_date
SELECT
    AIRLINE,
    DEPARTURE_AIRPORT,
    MONTH + INTERVAL 3 YEAR as MONTH,
    assumeNotNull(dateDiff('month', start_date, MONTH) / dateDiff('month', start_date, end_date)) as Trend,
    if(toMonth(toDate(MONTH)) = 1, 1, 0) as Dummy1,
    if(toMonth(toDate(MONTH)) = 2, 1, 0) as Dummy2,
    if(toMonth(toDate(MONTH)) = 3, 1, 0) as Dummy3,
    if(toMonth(toDate(MONTH)) = 4, 1, 0) as Dummy4,
    if(toMonth(toDate(MONTH)) = 5, 1, 0) as Dummy5,
    if(toMonth(toDate(MONTH)) = 6, 1, 0) as Dummy6,
    if(toMonth(toDate(MONTH)) = 7, 1, 0) as Dummy7,
    if(toMonth(toDate(MONTH)) = 8, 1, 0) as Dummy8,
    if(toMonth(toDate(MONTH)) = 9, 1, 0) as Dummy9,
    if(toMonth(toDate(MONTH)) = 10, 1, 0) as Dummy10,
    if(toMonth(toDate(MONTH)) = 11, 1, 0) as Dummy11,
    if(toMonth(toDate(MONTH)) = 12, 1, 0) as Dummy12
FROM
    test_data
ORDER BY AIRLINE, DEPARTURE_AIRPORT, MONTH

为我们提供了一种端到端的可视化。从视觉上看,我们可以看到乘客数量的增加和季节性已经被超出范围的预测所捕获。

hex03.png

结论

在本文中,我们演示了如何使用 ClickHouse 中可用的 ML 函数 (stochasticLinearRegression 和 evalMLMethod) 来实现简单的预测技术。

原则上,将此类指标和分析工作卸载到数据库中是一件好事。像 ClickHouse 这样的分析数据库通常会优于单机处理数据集的性能,并且允许我们处理比单机能够处理的数据集更大的数据集,同时还可以减少需要进行的脚本工作量。

在 ClickHouse 中,这也可以内置到物化视图中,这意味着当捕获到新数据时,模型会不断更新和重新训练,从而打开了实时可能性。

我们相信这种模式在未来会发展,更多的数据科学和机器学习算法将直接在数据库中实现。

可以在 此 URL 找到描述完整工作示例的笔记本。

分享这篇文章

订阅我们的时事通讯

随时了解功能发布、产品路线图、支持和云产品!
正在加载表单...
关注我们
Twitter imageSlack imageGitHub image
Telegram imageMeetup imageRss image